// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"sync"
)

// server pub keys registry
var (
	// serverPubKeyLock guards every access to serverPubKeyRegistry.
	serverPubKeyLock sync.RWMutex
	// serverPubKeyRegistry maps a registered name to its RSA public key.
	// It stays nil until the first key is registered.
	serverPubKeyRegistry map[string]*rsa.PublicKey
)

// RegisterServerPubKey registers a server RSA public key which can be used to
// send data in a secure manner to the server without receiving the public key
// in a potentially insecure way from the server first.
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
// Registering a name that is already present replaces the stored key.
//
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified.
//
//	data, err := ioutil.ReadFile("mykey.pem")
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	block, _ := pem.Decode(data)
//	if block == nil || block.Type != "PUBLIC KEY" {
//		log.Fatal("failed to decode PEM block containing public key")
//	}
//
//	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
//		mysql.RegisterServerPubKey("mykey", rsaPubKey)
//	} else {
//		log.Fatal("not a RSA public key")
//	}
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
	serverPubKeyLock.Lock()
	defer serverPubKeyLock.Unlock()

	// The registry is allocated lazily; assigning into a nil map would panic.
	if serverPubKeyRegistry == nil {
		serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
	}
	serverPubKeyRegistry[name] = pubKey
}

// DeregisterServerPubKey removes the public key registered with the given name.
// It is a no-op when no key is registered under that name.
func DeregisterServerPubKey(name string) {
	serverPubKeyLock.Lock()
	defer serverPubKeyLock.Unlock()

	if serverPubKeyRegistry != nil {
		delete(serverPubKeyRegistry, name)
	}
}

// getServerPubKey returns the public key registered under name, or nil when
// no such key exists.
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
	serverPubKeyLock.RLock()
	defer serverPubKeyLock.RUnlock()

	// Indexing a nil or empty map yields the zero value, i.e. a nil key.
	return serverPubKeyRegistry[name]
}

// myRnd is the pseudo random number generator used by the pre 4.1
// (old password) hashing method.
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
	seed1, seed2 uint32
}

// myRndMaxVal is the modulus both generator seeds are reduced by.
const myRndMaxVal = 0x3FFFFFFF

// newMyRnd returns a pseudo random number generator whose two seeds are the
// given values reduced modulo myRndMaxVal.
func newMyRnd(seed1, seed2 uint32) *myRnd {
	return &myRnd{
		seed1: seed1 % myRndMaxVal,
		seed2: seed2 % myRndMaxVal,
	}
}

// NextByte advances the generator state and returns the next value, which is
// always in the range [0, 30].
//
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
	// Both seeds are below myRndMaxVal, so seed1*3 + seed2 fits in a uint32.
	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal

	// Widen to 64 bits: seed1 * 31 can exceed the uint32 range.
	return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}

// pwHash generates a binary hash (two 31-bit words) from a byte string using
// the insecure pre 4.1 method. Spaces and tabs in the input are ignored.
func pwHash(password []byte) (result [2]uint32) {
	var add uint32 = 7
	var tmp uint32

	result[0] = 1345345333
	result[1] = 0x12345671

	for _, c := range password {
		// skip spaces and tabs in password
		if c == ' ' || c == '\t' {
			continue
		}

		tmp = uint32(c)
		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
		result[1] += (result[1] << 8) ^ result[0]
		add += tmp
	}

	// Remove sign bit (1<<31)-1)
	result[0] &= 0x7FFFFFFF
	result[1] &= 0x7FFFFFFF

	return
}

// scrambleOldPassword hashes password with the insecure pre 4.1 method.
// Only the first 8 bytes of scramble are used; the slice must be at least
// that long, otherwise the re-slice panics. The result is always 8 bytes and
// the caller's scramble is not modified.
func scrambleOldPassword(scramble []byte, password string) []byte {
	scramble = scramble[:8]

	hashPw := pwHash([]byte(password))
	hashSc := pwHash(scramble)

	// Seed the generator from the password hash mixed with the scramble hash.
	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])

	var out [8]byte
	for i := range out {
		out[i] = r.NextByte() + 64
	}

	// One additional generator byte is XORed over the whole output.
	mask := r.NextByte()
	for i := range out {
		out[i] ^= mask
	}

	return out[:]
}

// scramblePassword hashes password using the 4.1+ method (SHA1):
//
//	token = SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// It returns nil for an empty password and a 20-byte token otherwise. The
// caller's scramble slice is read but never written.
func scramblePassword(scramble []byte, password string) []byte {
	if len(password) == 0 {
		return nil
	}

	// stage1Hash = SHA1(password)
	crypt := sha1.New()
	crypt.Write([]byte(password))
	stage1 := crypt.Sum(nil)

	// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
	// inner Hash
	crypt.Reset()
	crypt.Write(stage1)
	hash := crypt.Sum(nil)

	// outer Hash
	crypt.Reset()
	crypt.Write(scramble)
	crypt.Write(hash)
	// Sum(nil) allocates a fresh slice, so rebinding the parameter here does
	// not touch the caller's data.
	scramble = crypt.Sum(nil)

	// token = scrambleHash XOR stage1Hash
	for i := range scramble {
		scramble[i] ^= stage1[i]
	}
	return scramble
}

// scrambleSHA256Password hashes password using the MySQL 8+ method (SHA256):
//
//	XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + scramble))
//
// It returns nil for an empty password and a 32-byte token otherwise. The
// caller's scramble slice is read but never written.
func scrambleSHA256Password(scramble []byte, password string) []byte {
	if len(password) == 0 {
		return nil
	}

	// message1 = SHA256(password)
	crypt := sha256.New()
	crypt.Write([]byte(password))
	message1 := crypt.Sum(nil)

	// message1Hash = SHA256(message1)
	crypt.Reset()
	crypt.Write(message1)
	message1Hash := crypt.Sum(nil)

	// message2 = SHA256(message1Hash + scramble)
	crypt.Reset()
	crypt.Write(message1Hash)
	crypt.Write(scramble)
	message2 := crypt.Sum(nil)

	// token = message1 XOR message2, computed in place in message1
	for i := range message1 {
		message1[i] ^= message2[i]
	}

	return message1
}

// encryptPassword encrypts password for transmission to the server using
// RSA-OAEP (SHA1) under pub.
//
// The plaintext is the password followed by a single NUL byte, XORed with
// seed, which is repeated cyclically when shorter than the plaintext.
// An empty seed is rejected with an error; it would otherwise cause a
// division by zero while cycling through the seed. Any error from
// rsa.EncryptOAEP is returned unchanged.
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
	if len(seed) == 0 {
		return nil, fmt.Errorf("cannot encrypt password: empty seed")
	}

	// make zero-fills the buffer, so the extra byte is the NUL terminator.
	plain := make([]byte, len(password)+1)
	copy(plain, password)
	for i := range plain {
		plain[i] ^= seed[i%len(seed)]
	}
	return rsa.EncryptOAEP(sha1.New(), rand.Reader, pub, plain, nil)
}

func ( *mysqlConn) ( []byte,  *rsa.PublicKey) error {
	,  := encryptPassword(.cfg.Passwd, , )
	if  != nil {
		return 
	}
	return .writeAuthSwitchPacket()
}

func ( *mysqlConn) ( []byte,  string) ([]byte, error) {
	switch  {
	case "caching_sha2_password":
		 := scrambleSHA256Password(, .cfg.Passwd)
		return , nil

	case "mysql_old_password":
		if !.cfg.AllowOldPasswords {
			return nil, ErrOldPassword
		}
		if len(.cfg.Passwd) == 0 {
			return nil, nil
		}
		// Note: there are edge cases where this should work but doesn't;
		// this is currently "wontfix":
		// https://github.com/go-sql-driver/mysql/issues/184
		 := append(scrambleOldPassword([:8], .cfg.Passwd), 0)
		return , nil

	case "mysql_clear_password":
		if !.cfg.AllowCleartextPasswords {
			return nil, ErrCleartextPassword
		}
		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
		return append([]byte(.cfg.Passwd), 0), nil

	case "mysql_native_password":
		if !.cfg.AllowNativePasswords {
			return nil, ErrNativePassword
		}
		// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
		// Native password authentication only need and will need 20-byte challenge.
		 := scramblePassword([:20], .cfg.Passwd)
		return , nil

	case "sha256_password":
		if len(.cfg.Passwd) == 0 {
			return []byte{0}, nil
		}
		// unlike caching_sha2_password, sha256_password does not accept
		// cleartext password on unix transport.
		if .cfg.TLS != nil {
			// write cleartext auth packet
			return append([]byte(.cfg.Passwd), 0), nil
		}

		 := .cfg.pubKey
		if  == nil {
			// request public key from server
			return []byte{1}, nil
		}

		// encrypted password
		,  := encryptPassword(.cfg.Passwd, , )
		return , 

	default:
		errLog.Print("unknown auth plugin:", )
		return nil, ErrUnknownPlugin
	}
}

func ( *mysqlConn) ( []byte,  string) error {
	// Read Result Packet
	, ,  := .readAuthResult()
	if  != nil {
		return 
	}

	// handle auth plugin switch, if requested
	if  != "" {
		// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
		// sent and we have to keep using the cipher sent in the init packet.
		if  == nil {
			 = 
		} else {
			// copy data from read buffer to owned slice
			copy(, )
		}

		 = 

		,  := .auth(, )
		if  != nil {
			return 
		}
		if  = .writeAuthSwitchPacket();  != nil {
			return 
		}

		// Read Result Packet
		, ,  = .readAuthResult()
		if  != nil {
			return 
		}

		// Do not allow to change the auth plugin more than once
		if  != "" {
			return ErrMalformPkt
		}
	}

	switch  {

	// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
	case "caching_sha2_password":
		switch len() {
		case 0:
			return nil // auth successful
		case 1:
			switch [0] {
			case cachingSha2PasswordFastAuthSuccess:
				if  = .readResultOK();  == nil {
					return nil // auth successful
				}

			case cachingSha2PasswordPerformFullAuthentication:
				if .cfg.TLS != nil || .cfg.Net == "unix" {
					// write cleartext auth packet
					 = .writeAuthSwitchPacket(append([]byte(.cfg.Passwd), 0))
					if  != nil {
						return 
					}
				} else {
					 := .cfg.pubKey
					if  == nil {
						// request public key from server
						,  := .buf.takeSmallBuffer(4 + 1)
						if  != nil {
							return 
						}
						[4] = cachingSha2PasswordRequestPublicKey
						 = .writePacket()
						if  != nil {
							return 
						}

						if ,  = .readPacket();  != nil {
							return 
						}

						if [0] != iAuthMoreData {
							return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication")
						}

						// parse public key
						,  := pem.Decode([1:])
						if  == nil {
							return fmt.Errorf("No Pem data found, data: %s", )
						}
						,  := x509.ParsePKIXPublicKey(.Bytes)
						if  != nil {
							return 
						}
						 = .(*rsa.PublicKey)
					}

					// send encrypted password
					 = .sendEncryptedPassword(, )
					if  != nil {
						return 
					}
				}
				return .readResultOK()

			default:
				return ErrMalformPkt
			}
		default:
			return ErrMalformPkt
		}

	case "sha256_password":
		switch len() {
		case 0:
			return nil // auth successful
		default:
			,  := pem.Decode()
			if  == nil {
				return fmt.Errorf("no Pem data found, data: %s", )
			}

			,  := x509.ParsePKIXPublicKey(.Bytes)
			if  != nil {
				return 
			}

			// send encrypted password
			 = .sendEncryptedPassword(, .(*rsa.PublicKey))
			if  != nil {
				return 
			}
			return .readResultOK()
		}

	default:
		return nil // auth successful
	}

	return 
}