// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"bytes"
	"crypto/rsa"
	"crypto/tls"
	"errors"
	"fmt"
	"math/big"
	"net"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
)

var (
	// errInvalidDSNNoSlash is returned for a non-empty DSN that contains no
	// '/' separating the connection part from the database name.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNAddr is returned when a "net(addr" section is opened but
	// the character before the database-name slash is not ')'.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNUnescaped is returned when the address looks unterminated
	// but a ')' appears later in the DSN, which usually means a '/' inside a
	// parameter value was not escaped.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNUnsafeCollation is returned when interpolateParams is
	// combined with a collation listed in unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)

// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
//
// ParseDSN fills a Config from a DSN string and FormatDSN turns it back into
// one. ParseDSN also runs the unexported normalize step, which fills in
// derived values: a default Net/Addr, TLS from TLSConfig, and pubKey from
// ServerPubKey.
type Config struct {
	User             string            // Username
	Passwd           string            // Password (requires User; FormatDSN omits it when User is empty)
	Net              string            // Network type; normalize turns an empty value into "tcp"
	Addr             string            // Network address (requires Net); normalize supplies a default for "tcp" and "unix"
	DBName           string            // Database name
	Params           map[string]string // Connection parameters not recognised by the DSN parser (stored query-unescaped)
	Collation        string            // Connection collation (NewConfig sets defaultCollation)
	Loc              *time.Location    // Location for time.Time values (NewConfig sets time.UTC)
	MaxAllowedPacket int               // Max packet size allowed (NewConfig sets defaultMaxAllowedPacket)
	ServerPubKey     string            // Server public key name; normalize resolves it into pubKey
	pubKey           *rsa.PublicKey    // Server public key resolved from ServerPubKey
	TLSConfig        string            // TLS configuration name: "true", "false", "skip-verify", "preferred" or a custom config name
	TLS              *tls.Config       // TLS configuration, its priority is higher than TLSConfig (normalize only reads TLSConfig when TLS is nil)
	Timeout          time.Duration     // Dial timeout
	ReadTimeout      time.Duration     // I/O read timeout
	WriteTimeout     time.Duration     // I/O write timeout

	AllowAllFiles            bool // Allow all files to be used with LOAD DATA LOCAL INFILE
	AllowCleartextPasswords  bool // Allows the cleartext client side plugin
	AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS; normalize sets it for TLSConfig "preferred"
	AllowNativePasswords     bool // Allows the native password authentication method (NewConfig enables it)
	AllowOldPasswords        bool // Allows the old insecure password method
	CheckConnLiveness        bool // Check connections for liveness before using them (NewConfig enables it)
	ClientFoundRows          bool // Return number of matching rows instead of rows changed
	ColumnsWithAlias         bool // Prepend table alias to column names
	InterpolateParams        bool // Interpolate placeholders into query string; normalize rejects it with unsafe collations
	MultiStatements          bool // Allow multiple statements in one query
	ParseTime                bool // Parse time values to time.Time
	RejectReadOnly           bool // Reject read-only connections
}

// NewConfig creates a new Config and sets default values.
//
// The returned Config uses defaultCollation, time.UTC as the location for
// time.Time values, and defaultMaxAllowedPacket as the packet size limit. It
// also enables AllowNativePasswords and CheckConnLiveness. Every other field
// keeps its zero value; in particular, Net and Addr are left empty for
// normalize to fill in.
func NewConfig() *Config {
	return &Config{
		Collation:            defaultCollation,
		Loc:                  time.UTC,
		MaxAllowedPacket:     defaultMaxAllowedPacket,
		AllowNativePasswords: true,
		CheckConnLiveness:    true,
	}
}

// Clone returns a deep copy of cfg.
//
// On top of the plain struct copy, it duplicates:
//   - the TLS configuration, via (*tls.Config).Clone;
//   - the Params map;
//   - the resolved server public key.
//
// As a result, mutating any of these on the clone does not affect cfg, and
// vice versa.
func (cfg *Config) Clone() *Config {
	cp := *cfg
	if cp.TLS != nil {
		cp.TLS = cfg.TLS.Clone()
	}
	// Copy every non-nil map, including an empty one. Sharing an empty map
	// would let writes to the clone's Params leak into cfg.
	if cfg.Params != nil {
		cp.Params = make(map[string]string, len(cfg.Params))
		for k, v := range cfg.Params {
			cp.Params[k] = v
		}
	}
	if cfg.pubKey != nil {
		cp.pubKey = &rsa.PublicKey{
			N: new(big.Int).Set(cfg.pubKey.N),
			E: cfg.pubKey.E,
		}
	}
	return &cp
}

// normalize validates cfg and fills in derived and default values in place.
//
// It performs these steps in order:
//   - Rejects InterpolateParams combined with a collation listed in
//     unsafeCollations.
//   - Defaults Net to "tcp".
//   - Defaults Addr for the "tcp" and "unix" networks, or appends the default
//     port to a TCP address that lacks one (via ensureHavePort).
//   - Builds TLS from TLSConfig when TLS is not already set.
//   - Derives TLS.ServerName from the host part of Addr for verifying TLS
//     configs.
//   - Resolves ServerPubKey into pubKey.
//
// It returns an error for an unknown network without an address, an unknown
// TLS config name, or an unknown server public key name.
//
// NOTE(review): cfg.TLS is modified in place (ServerName), so a caller-owned
// *tls.Config should be cloned (e.g. via Clone) before normalizing.
func (cfg *Config) normalize() error {
	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
		return errInvalidDSNUnsafeCollation
	}

	// Set default network if empty
	if cfg.Net == "" {
		cfg.Net = "tcp"
	}

	// Set default address if empty
	if cfg.Addr == "" {
		switch cfg.Net {
		case "tcp":
			cfg.Addr = "127.0.0.1:3306"
		case "unix":
			cfg.Addr = "/tmp/mysql.sock"
		default:
			return errors.New("default addr for network '" + cfg.Net + "' unknown")
		}
	} else if cfg.Net == "tcp" {
		cfg.Addr = ensureHavePort(cfg.Addr)
	}

	// An explicitly provided TLS config takes precedence over TLSConfig.
	if cfg.TLS == nil {
		switch cfg.TLSConfig {
		case "false", "":
			// don't set anything
		case "true":
			cfg.TLS = &tls.Config{}
		case "skip-verify":
			cfg.TLS = &tls.Config{InsecureSkipVerify: true}
		case "preferred":
			// Encrypt if possible without verifying the certificate, and allow
			// falling back to plaintext.
			cfg.TLS = &tls.Config{InsecureSkipVerify: true}
			cfg.AllowFallbackToPlaintext = true
		default:
			// presumably looks up a custom config by name and returns a copy
			// of it; getTLSConfigClone is defined elsewhere.
			cfg.TLS = getTLSConfigClone(cfg.TLSConfig)
			if cfg.TLS == nil {
				return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
			}
		}
	}

	// Certificate verification needs a server name; derive it from Addr when
	// the config does not set one. Addresses without a host:port shape
	// (e.g. unix socket paths) are left alone.
	if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify {
		host, _, err := net.SplitHostPort(cfg.Addr)
		if err == nil {
			cfg.TLS.ServerName = host
		}
	}

	if cfg.ServerPubKey != "" {
		cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
		if cfg.pubKey == nil {
			return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
		}
	}

	return nil
}

// writeDSNParam appends "name=value" to buf. The pair is preceded by '?' if
// *hasParam is false (the first parameter) and by '&' otherwise. After the
// call *hasParam is true. name and value are written verbatim, so callers
// must escape the value themselves where needed.
func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
	buf.Grow(1 + len(name) + 1 + len(value))
	if !*hasParam {
		*hasParam = true
		buf.WriteByte('?')
	} else {
		buf.WriteByte('&')
	}
	buf.WriteString(name)
	buf.WriteByte('=')
	buf.WriteString(value)
}

// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
//
// The result has the form
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&...&paramN=valueN]
//
// Query parameters are only written when they differ from their zero value
// or from the defaults set by NewConfig. Known options come in a fixed order,
// followed by the Params entries sorted by key, so the output is
// deterministic.
//
// The loc, serverPubKey and tls values, as well as Params values, are URL
// query-escaped. The TLS field itself is not represented; only TLSConfig is
// written.
func (cfg *Config) FormatDSN() string {
	var buf bytes.Buffer

	// [username[:password]@]
	if len(cfg.User) > 0 {
		buf.WriteString(cfg.User)
		if len(cfg.Passwd) > 0 {
			buf.WriteByte(':')
			buf.WriteString(cfg.Passwd)
		}
		buf.WriteByte('@')
	}

	// [protocol[(address)]]
	if len(cfg.Net) > 0 {
		buf.WriteString(cfg.Net)
		if len(cfg.Addr) > 0 {
			buf.WriteByte('(')
			buf.WriteString(cfg.Addr)
			buf.WriteByte(')')
		}
	}

	// /dbname
	buf.WriteByte('/')
	buf.WriteString(cfg.DBName)

	// [?param1=value1&...&paramN=valueN]
	hasParam := false

	if cfg.AllowAllFiles {
		writeDSNParam(&buf, &hasParam, "allowAllFiles", "true")
	}

	if cfg.AllowCleartextPasswords {
		writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
	}

	if cfg.AllowFallbackToPlaintext {
		writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true")
	}

	if !cfg.AllowNativePasswords {
		writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
	}

	if cfg.AllowOldPasswords {
		writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
	}

	if !cfg.CheckConnLiveness {
		writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
	}

	if cfg.ClientFoundRows {
		writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
	}

	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
		writeDSNParam(&buf, &hasParam, "collation", col)
	}

	if cfg.ColumnsWithAlias {
		writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
	}

	if cfg.InterpolateParams {
		writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
	}

	if cfg.Loc != time.UTC && cfg.Loc != nil {
		writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
	}

	if cfg.MultiStatements {
		writeDSNParam(&buf, &hasParam, "multiStatements", "true")
	}

	if cfg.ParseTime {
		writeDSNParam(&buf, &hasParam, "parseTime", "true")
	}

	if cfg.ReadTimeout > 0 {
		writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
	}

	if cfg.RejectReadOnly {
		writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
	}

	if len(cfg.ServerPubKey) > 0 {
		writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
	}

	if cfg.Timeout > 0 {
		writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
	}

	if len(cfg.TLSConfig) > 0 {
		writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
	}

	if cfg.WriteTimeout > 0 {
		writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
	}

	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
		writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
	}

	// other params, sorted by key so the output is stable
	if len(cfg.Params) > 0 {
		keys := make([]string, 0, len(cfg.Params))
		for k := range cfg.Params {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			writeDSNParam(&buf, &hasParam, k, url.QueryEscape(cfg.Params[k]))
		}
	}

	return buf.String()
}

// ParseDSN parses the DSN string to a Config.
//
// The DSN has the form
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
//
// Parsing starts from NewConfig's defaults, fills in the values found in the
// DSN, and then normalizes the result. On any error the returned Config is
// nil. A non-empty DSN without a '/' is rejected with errInvalidDSNNoSlash;
// an empty DSN yields a normalized default Config.
func ParseDSN(dsn string) (cfg *Config, err error) {
	// New config with some default values
	cfg = NewConfig()

	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
	// Find the last '/' (since the password or the net addr might contain a '/')
	foundSlash := false
	for i := len(dsn) - 1; i >= 0; i-- {
		if dsn[i] == '/' {
			foundSlash = true
			var j, k int

			// left part is empty if i <= 0
			if i > 0 {
				// [username[:password]@][protocol[(address)]]
				// Find the last '@' in dsn[:i]
				for j = i; j >= 0; j-- {
					if dsn[j] == '@' {
						// username[:password]
						// Find the first ':' in dsn[:j]
						for k = 0; k < j; k++ {
							if dsn[k] == ':' {
								cfg.Passwd = dsn[k+1 : j]
								break
							}
						}
						cfg.User = dsn[:k]

						break
					}
				}

				// [protocol[(address)]]
				// Find the first '(' in dsn[j+1:i]
				for k = j + 1; k < i; k++ {
					if dsn[k] == '(' {
						// dsn[i-1] must be == ')' if an address is specified
						if dsn[i-1] != ')' {
							if strings.ContainsRune(dsn[k+1:], ')') {
								return nil, errInvalidDSNUnescaped
							}
							return nil, errInvalidDSNAddr
						}
						cfg.Addr = dsn[k+1 : i-1]
						break
					}
				}
				cfg.Net = dsn[j+1 : k]
			}

			// dbname[?param1=value1&...&paramN=valueN]
			// Find the first '?' in dsn[i+1:]
			for j = i + 1; j < len(dsn); j++ {
				if dsn[j] == '?' {
					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
						// Match every other error path: never hand back a
						// partially populated Config.
						return nil, err
					}
					break
				}
			}
			cfg.DBName = dsn[i+1 : j]

			break
		}
	}

	if !foundSlash && len(dsn) > 0 {
		return nil, errInvalidDSNNoSlash
	}

	if err = cfg.normalize(); err != nil {
		return nil, err
	}
	return cfg, nil
}

// parseDSNParams parses the DSN "query string" (the part after '?') into cfg.
//
// params is split on '&', and each element is split on its first '='.
// Elements without an '=' are skipped. Recognised keys set the matching
// Config field; any other key is stored in cfg.Params with its value
// query-unescaped.
//
// Values must be url.QueryEscape'ed. The loc, serverPubKey and custom tls
// values are unescaped here, as are the values of unknown keys.
//
// It returns the first error encountered. Fields assigned before the error
// keep their new values, so callers should discard cfg on error.
func parseDSNParams(cfg *Config, params string) (err error) {
	// setBool stores the readBool result for value into *dst and reports a
	// non-boolean value as an error. Like the former per-key code, *dst is
	// assigned even when the value is rejected.
	setBool := func(dst *bool, value string) error {
		var isBool bool
		*dst, isBool = readBool(value)
		if !isBool {
			return errors.New("invalid bool value: " + value)
		}
		return nil
	}

	for _, v := range strings.Split(params, "&") {
		param := strings.SplitN(v, "=", 2)
		if len(param) != 2 {
			continue
		}

		// cfg params
		switch value := param[1]; param[0] {
		// Disable INFILE allowlist / enable all files
		case "allowAllFiles":
			err = setBool(&cfg.AllowAllFiles, value)

		// Use cleartext authentication mode (MySQL 5.5.10+)
		case "allowCleartextPasswords":
			err = setBool(&cfg.AllowCleartextPasswords, value)

		// Allow fallback to unencrypted connection if server does not support TLS
		case "allowFallbackToPlaintext":
			err = setBool(&cfg.AllowFallbackToPlaintext, value)

		// Use native password authentication
		case "allowNativePasswords":
			err = setBool(&cfg.AllowNativePasswords, value)

		// Use old authentication mode (pre MySQL 4.1)
		case "allowOldPasswords":
			err = setBool(&cfg.AllowOldPasswords, value)

		// Check connections for liveness before using them
		case "checkConnLiveness":
			err = setBool(&cfg.CheckConnLiveness, value)

		// Switch "rowsAffected" mode
		case "clientFoundRows":
			err = setBool(&cfg.ClientFoundRows, value)

		// Collation (used verbatim, not unescaped)
		case "collation":
			cfg.Collation = value

		// Prepend table alias to column names
		case "columnsWithAlias":
			err = setBool(&cfg.ColumnsWithAlias, value)

		// Compression
		case "compress":
			err = errors.New("compression not implemented yet")

		// Enable client side placeholder substitution
		case "interpolateParams":
			err = setBool(&cfg.InterpolateParams, value)

		// Time Location
		case "loc":
			if value, err = url.QueryUnescape(value); err == nil {
				cfg.Loc, err = time.LoadLocation(value)
			}

		// multiple statements in one query
		case "multiStatements":
			err = setBool(&cfg.MultiStatements, value)

		// time.Time parsing
		case "parseTime":
			err = setBool(&cfg.ParseTime, value)

		// I/O read Timeout
		case "readTimeout":
			cfg.ReadTimeout, err = time.ParseDuration(value)

		// Reject read-only connections
		case "rejectReadOnly":
			err = setBool(&cfg.RejectReadOnly, value)

		// Server public key
		case "serverPubKey":
			name, uerr := url.QueryUnescape(value)
			if uerr != nil {
				return fmt.Errorf("invalid value for server pub key name: %v", uerr)
			}
			cfg.ServerPubKey = name

		// Strict mode has been removed. Report it as an error rather than
		// panicking, so that a stale DSN coming from configuration cannot
		// crash the host program.
		case "strict":
			err = errors.New("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")

		// Dial Timeout
		case "timeout":
			cfg.Timeout, err = time.ParseDuration(value)

		// TLS-Encryption
		case "tls":
			if boolValue, isBool := readBool(value); isBool {
				if boolValue {
					cfg.TLSConfig = "true"
				} else {
					cfg.TLSConfig = "false"
				}
			} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
				cfg.TLSConfig = vl
			} else {
				name, uerr := url.QueryUnescape(value)
				if uerr != nil {
					return fmt.Errorf("invalid value for TLS config name: %v", uerr)
				}
				cfg.TLSConfig = name
			}

		// I/O write Timeout
		case "writeTimeout":
			cfg.WriteTimeout, err = time.ParseDuration(value)

		// Max packet size allowed
		case "maxAllowedPacket":
			cfg.MaxAllowedPacket, err = strconv.Atoi(value)

		default:
			// lazy init
			if cfg.Params == nil {
				cfg.Params = make(map[string]string)
			}
			cfg.Params[param[0]], err = url.QueryUnescape(value)
		}

		if err != nil {
			return err
		}
	}

	return nil
}

// ensureHavePort returns addr unchanged if it already includes a port.
// Otherwise it appends the default port 3306. A bracketed IPv6 literal
// without a port, such as "[::1]", has its brackets removed first, so the
// result is "[::1]:3306" rather than "[[::1]]:3306".
func ensureHavePort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	// net.JoinHostPort adds its own brackets around hosts containing ':',
	// so strip any existing pair to avoid double-bracketing.
	if len(addr) >= 2 && addr[0] == '[' && addr[len(addr)-1] == ']' {
		addr = addr[1 : len(addr)-1]
	}
	return net.JoinHostPort(addr, "3306")
}