package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"io"
"net"
"strconv"
"strings"
"time"
)
// mysqlConn holds one client connection to a MySQL server together with the
// per-connection state used by the methods in this file.
type mysqlConn struct {
// buf is the connection's buffer; interpolateParams borrows it (via
// takeCompleteBuffer) as scratch space while building a statement.
buf buffer
// netConn is the network connection; cleanup closes it.
netConn net .Conn
// NOTE(review): rawConn is not used in this file; presumably the connection
// underneath netConn when netConn wraps it (e.g. TLS) — confirm.
rawConn net .Conn
// affectedRows and insertId are zeroed by Exec before running a statement and
// reported in the resulting mysqlResult; presumably they are filled in while
// exec reads the server's reply — confirm.
affectedRows uint64
insertId uint64
// cfg is the connection configuration (Params, InterpolateParams, Loc are
// read in this file).
cfg *Config
// maxAllowedPacket bounds the statements produced by interpolateParams.
maxAllowedPacket int
// maxWriteSize, writeTimeout and flags are not read in this file.
maxWriteSize int
writeTimeout time .Duration
flags clientFlag
// status holds server status flags; interpolateParams checks
// statusNoBackslashEscapes in it to choose the escaping style.
status statusFlag
// sequence and parseTime are not read in this file.
sequence uint8
parseTime bool
// reset is set by ResetSession; it is not read in this file.
reset bool
// watching is true while a context handed to the watcher goroutine by
// watchCancel has not yet been released by finish.
watching bool
// watcher is the send side of the channel feeding contexts to the goroutine
// started by startWatcher; nil if that goroutine was never started.
watcher chan <- context .Context
// closech is closed exactly once, by cleanup, which stops the watcher
// goroutine and unblocks finish.
closech chan struct {}
// finished is signalled by finish to tell the watcher goroutine that the
// current operation completed.
finished chan <- struct {}
// canceled records the error passed to cancel; error() reports it.
canceled atomicError
// closed is set by cleanup (before closech is closed) and is checked by the
// public entry points.
closed atomicBool
}
// handleParams applies the DSN session parameters in mc.cfg.Params once the
// connection is established.
//
// The "charset" parameter is a comma-separated list: each entry is tried in
// order with SET NAMES, and the first one the server accepts wins. If none is
// accepted, the error from the last attempt is returned.
//
// Every other parameter is collected into a single "SET k1 = v1, k2 = v2"
// statement so that all of them cost one round trip. Keys and values are
// spliced into the SQL text verbatim, so they must come from trusted
// configuration.
func (mc *mysqlConn) handleParams() (err error) {
	var setCmd strings.Builder
	for key, value := range mc.cfg.Params {
		if key == "charset" {
			for _, cs := range strings.Split(value, ",") {
				if err = mc.exec("SET NAMES " + cs); err == nil {
					break
				}
			}
			if err != nil {
				return err
			}
			continue
		}

		if setCmd.Len() > 0 {
			setCmd.WriteString(", ")
		} else {
			// Reserve room for this pair plus roughly 30 bytes per remaining
			// parameter to avoid repeated growth.
			setCmd.Grow(4 + len(key) + 1 + len(value) + 30*(len(mc.cfg.Params)-1))
			setCmd.WriteString("SET ")
		}
		setCmd.WriteString(key)
		setCmd.WriteString(" = ")
		setCmd.WriteString(value)
	}

	if setCmd.Len() == 0 {
		return nil
	}
	return mc.exec(setCmd.String())
}
// markBadConn translates errBadConnNoWrite into driver.ErrBadConn, which
// database/sql treats as a request to retry on another connection. Any other
// error, and any error seen through a nil receiver, is returned unchanged.
func (mc *mysqlConn) markBadConn(err error) error {
	if mc != nil && err == errBadConnNoWrite {
		return driver.ErrBadConn
	}
	return err
}
// Begin implements driver.Conn by starting a read-write transaction.
func (mc *mysqlConn) Begin() (driver.Tx, error) {
	const readOnly = false
	return mc.begin(readOnly)
}
// begin issues START TRANSACTION (with READ ONLY when readOnly is set) and
// returns a mysqlTx bound to this connection.
//
// A closed connection is logged and reported as driver.ErrBadConn. Errors from
// the statement itself are passed through markBadConn.
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
	if mc.closed.Load() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	stmt := "START TRANSACTION"
	if readOnly {
		stmt = "START TRANSACTION READ ONLY"
	}

	if err := mc.exec(stmt); err != nil {
		return nil, mc.markBadConn(err)
	}
	return &mysqlTx{mc}, nil
}
// Close implements driver.Conn.
//
// If the connection is still open, Close sends COM_QUIT. In every case it then
// releases the connection through cleanup. The returned error, if any, comes
// from sending COM_QUIT. Calling Close again is harmless: the quit packet is
// skipped and cleanup does nothing the second time.
func (mc *mysqlConn) Close() error {
	var quitErr error
	if !mc.closed.Load() {
		quitErr = mc.writeCommandPacket(comQuit)
	}
	mc.cleanup()
	return quitErr
}
// cleanup marks the connection closed, closes closech (which stops the watcher
// goroutine and unblocks finish), and closes the network connection, logging
// any error from that close. Only the first call has any effect; later calls
// return immediately.
func (mc *mysqlConn) cleanup() {
	if alreadyClosed := mc.closed.Swap(true); alreadyClosed {
		return
	}

	close(mc.closech)

	if conn := mc.netConn; conn != nil {
		if err := conn.Close(); err != nil {
			errLog.Print(err)
		}
	}
}
// error reports why the connection is unusable. It returns nil while the
// connection is open. Once the connection is closed, it returns the error
// recorded by cancel if there is one, and ErrInvalidConn otherwise.
func (mc *mysqlConn) error() error {
	if !mc.closed.Load() {
		return nil
	}
	if cause := mc.canceled.Value(); cause != nil {
		return cause
	}
	return ErrInvalidConn
}
// Prepare implements driver.Conn by sending COM_STMT_PREPARE for query and
// reading the server's prepare response.
//
// A closed connection, or a failure while sending the command, is logged and
// reported as driver.ErrBadConn. Nothing has been read back at that point, so
// database/sql may retry on another connection. Errors while reading the
// response, or while skipping the parameter and column definitions that
// follow it, are returned as-is.
//
// On any error the returned driver.Stmt is nil. A half-initialised statement
// is never handed back alongside an error.
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
	if mc.closed.Load() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	if err := mc.writeCommandPacketStr(comStmtPrepare, query); err != nil {
		errLog.Print(err)
		return nil, driver.ErrBadConn
	}

	stmt := &mysqlStmt{mc: mc}
	columnCount, err := stmt.readPrepareResultPacket()
	if err != nil {
		return nil, err
	}

	// Skip the parameter definitions, if any.
	if stmt.paramCount > 0 {
		if err := mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}
	// Skip the column definitions, if any.
	if columnCount > 0 {
		if err := mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}
	return stmt, nil
}
// interpolateParams builds a literal SQL statement by replacing each '?' in
// query with the matching value from args, rendered as SQL text. Strings and
// byte slices are escaped according to whether statusNoBackslashEscapes is set
// in mc.status.
//
// It returns driver.ErrSkip, telling the caller to fall back to a different
// execution path, when:
//   - the number of '?' characters in query differs from len(args);
//   - an argument has a type not handled in the switch below;
//   - the statement built so far plus 4 bytes exceeds mc.maxAllowedPacket.
//
// Every '?' is treated as a placeholder, including one inside a string
// literal or comment in query.
//
// The statement is assembled in the connection's buffer and copied out by the
// final string conversion. If the buffer cannot be taken, the error is logged
// and ErrInvalidConn is returned. Errors from appendDateTime are returned as-is.
func (mc *mysqlConn ) interpolateParams (query string , args []driver .Value ) (string , error ) {
// Placeholder count must match the argument count exactly.
if strings .Count (query , "?" ) != len (args ) {
return "" , driver .ErrSkip
}
buf , err := mc .buf .takeCompleteBuffer ()
if err != nil {
errLog .Print (err )
return "" , ErrInvalidConn
}
buf = buf [:0 ]
argPos := 0
// Copy query text up to each '?', then append the rendered argument; i is
// advanced to the '?' itself and the loop's i++ steps past it.
for i := 0 ; i < len (query ); i ++ {
q := strings .IndexByte (query [i :], '?' )
if q == -1 {
// No placeholders left: copy the tail and stop.
buf = append (buf , query [i :]...)
break
}
buf = append (buf , query [i :i +q ]...)
i += q
arg := args [argPos ]
argPos ++
if arg == nil {
buf = append (buf , "NULL" ...)
continue
}
switch v := arg .(type ) {
case int64 :
buf = strconv .AppendInt (buf , v , 10 )
case uint64 :
buf = strconv .AppendUint (buf , v , 10 )
case float64 :
// Shortest representation that round-trips the float64.
buf = strconv .AppendFloat (buf , v , 'g' , -1 , 64 )
case bool :
if v {
buf = append (buf , '1' )
} else {
buf = append (buf , '0' )
}
case time .Time :
// The zero time becomes the zero date; other times are converted to
// mc.cfg.Loc and formatted by appendDateTime inside single quotes.
if v .IsZero () {
buf = append (buf , "'0000-00-00'" ...)
} else {
buf = append (buf , '\'' )
buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ))
if err != nil {
return "" , err
}
buf = append (buf , '\'' )
}
case json .RawMessage :
// Sent as an ordinary quoted string, not as a _binary literal.
buf = append (buf , '\'' )
if mc .status &statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash (buf , v )
} else {
buf = escapeBytesQuotes (buf , v )
}
buf = append (buf , '\'' )
case []byte :
// A nil slice is NULL; anything else is a _binary'...' literal.
if v == nil {
buf = append (buf , "NULL" ...)
} else {
buf = append (buf , "_binary'" ...)
if mc .status &statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash (buf , v )
} else {
buf = escapeBytesQuotes (buf , v )
}
buf = append (buf , '\'' )
}
case string :
buf = append (buf , '\'' )
if mc .status &statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash (buf , v )
} else {
buf = escapeStringQuotes (buf , v )
}
buf = append (buf , '\'' )
default :
// Unsupported argument type: let the caller use another path.
return "" , driver .ErrSkip
}
// Give up as soon as the statement would not fit in one packet
// (NOTE(review): the +4 is presumably the packet header — confirm).
if len (buf )+4 > mc .maxAllowedPacket {
return "" , driver .ErrSkip
}
}
if argPos != len (args ) {
return "" , driver .ErrSkip
}
return string (buf ), nil
}
// Exec runs query without preparing it and reports the affected-row count and
// last insert id.
//
// Arguments are only supported when client-side interpolation is enabled
// (cfg.InterpolateParams). Otherwise, or whenever interpolateParams cannot
// handle them, driver.ErrSkip is returned so the caller can fall back to a
// prepared statement. A closed connection is logged and reported as
// driver.ErrBadConn.
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	if mc.closed.Load() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	if len(args) > 0 {
		if !mc.cfg.InterpolateParams {
			return nil, driver.ErrSkip
		}
		interpolated, err := mc.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		query = interpolated
	}

	mc.affectedRows, mc.insertId = 0, 0
	if err := mc.exec(query); err != nil {
		return nil, mc.markBadConn(err)
	}
	return &mysqlResult{
		affectedRows: int64(mc.affectedRows),
		insertId:     int64(mc.insertId),
	}, nil
}
// exec sends query as COM_QUERY and consumes the complete reply, discarding any
// result set. A send failure is passed through markBadConn. Read errors are
// returned unchanged.
func (mc *mysqlConn) exec(query string) error {
	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
		return mc.markBadConn(err)
	}

	resLen, err := mc.readResultSetHeaderPacket()
	switch {
	case err != nil:
		return err
	case resLen > 0:
		// A result set came back: skip its column definitions, then its rows,
		// each section ending in an EOF packet.
		for section := 0; section < 2; section++ {
			if err := mc.readUntilEOF(); err != nil {
				return err
			}
		}
	}

	return mc.discardResults()
}
// Query implements driver.Queryer by delegating to query.
//
// On error it returns a literal nil driver.Rows. Returning query's *textRows
// directly would wrap a nil pointer in a non-nil interface value. It would
// also hand back a half-built rows object when reading the column
// definitions fails.
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	rows, err := mc.query(query, args)
	if err != nil {
		return nil, err
	}
	return rows, nil
}
// query sends query as COM_QUERY and returns a textRows reader over the
// result.
//
// Argument handling matches Exec: args require cfg.InterpolateParams, and
// driver.ErrSkip is returned when interpolation is disabled or unsupported.
// When the first reply carries no columns, the rows are marked done and
// NextResultSet is called to look for a following result set; io.EOF from it
// is not an error. Send and header-read errors are passed through
// markBadConn. If reading the column definitions fails, the rows value is
// returned together with that error.
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
	if mc.closed.Load() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	if len(args) > 0 {
		if !mc.cfg.InterpolateParams {
			return nil, driver.ErrSkip
		}
		interpolated, err := mc.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		query = interpolated
	}

	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
		return nil, mc.markBadConn(err)
	}
	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, mc.markBadConn(err)
	}

	rows := &textRows{}
	rows.mc = mc

	if resLen == 0 {
		rows.rs.done = true
		if nextErr := rows.NextResultSet(); nextErr != nil && nextErr != io.EOF {
			return nil, nextErr
		}
		return rows, nil
	}

	rows.rs.columns, err = mc.readColumns(resLen)
	return rows, err
}
// getSystemVar runs "SELECT @@name" and returns the value in the first column
// of the first row. name is spliced into the SQL text unescaped, so it must be
// a trusted identifier.
//
// A value that is not []byte (e.g. a NULL system variable), or a reply with no
// columns, yields a nil slice instead of panicking. The caller can then report
// the bad value as an ordinary error.
//
// NOTE(review): the returned slice presumably references the connection's read
// buffer, so callers should copy it before issuing further reads — confirm.
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
		return nil, err
	}

	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	rows := new(textRows)
	rows.mc = mc
	rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

	// Skip the column definitions.
	if resLen > 0 {
		if err := mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}

	dest := make([]driver.Value, resLen)
	if err := rows.readRow(dest); err != nil {
		return nil, err
	}

	var val []byte
	if len(dest) > 0 {
		// Comma-ok: a NULL (nil) or otherwise unexpected value must not panic.
		val, _ = dest[0].([]byte)
	}
	return val, mc.readUntilEOF()
}
// cancel records cause as the reason the connection was aborted, so that
// error() reports it. It then tears the connection down via cleanup. The
// cause is recorded first so it is in place by the time closed is set.
func (mc *mysqlConn) cancel(cause error) {
	mc.canceled.Set(cause)
	mc.cleanup()
}
// finish tells the watcher goroutine that the operation begun by watchCancel
// has completed, so it stops watching that context. It does nothing when no
// context is being watched or no watcher was started. If the connection is
// closed first (closech), it returns without clearing watching.
func (mc *mysqlConn) finish() {
	if mc.watching && mc.finished != nil {
		select {
		case <-mc.closech:
		case mc.finished <- struct{}{}:
			mc.watching = false
		}
	}
}
// Ping implements driver.Pinger. It sends COM_PING and waits for the OK
// reply while ctx is being watched for cancellation.
//
// A closed connection is logged and reported as driver.ErrBadConn. A send
// failure is passed through markBadConn.
func (mc *mysqlConn) Ping(ctx context.Context) error {
	if mc.closed.Load() {
		errLog.Print(ErrInvalidConn)
		return driver.ErrBadConn
	}

	if err := mc.watchCancel(ctx); err != nil {
		return err
	}
	defer mc.finish()

	if err := mc.writeCommandPacket(comPing); err != nil {
		return mc.markBadConn(err)
	}
	return mc.readResultOK()
}
// BeginTx implements driver.ConnBeginTx.
//
// A non-default isolation level is applied first with SET TRANSACTION
// ISOLATION LEVEL, using the name produced by mapIsolationLevel. The
// transaction is then started through begin, honouring opts.ReadOnly. ctx is
// watched for cancellation until BeginTx returns.
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	if mc.closed.Load() {
		return nil, driver.ErrBadConn
	}

	if err := mc.watchCancel(ctx); err != nil {
		return nil, err
	}
	defer mc.finish()

	if sql.IsolationLevel(opts.Isolation) == sql.LevelDefault {
		return mc.begin(opts.ReadOnly)
	}

	level, err := mapIsolationLevel(opts.Isolation)
	if err != nil {
		return nil, err
	}
	if err := mc.exec("SET TRANSACTION ISOLATION LEVEL " + level); err != nil {
		return nil, err
	}
	return mc.begin(opts.ReadOnly)
}
// QueryContext implements driver.QueryerContext.
//
// ctx stays watched until the returned rows call finish. If the query fails
// there are no rows to do that, so finish is called here instead.
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	values, err := namedValueToValue(args)
	if err != nil {
		return nil, err
	}

	if err = mc.watchCancel(ctx); err != nil {
		return nil, err
	}

	rows, err := mc.query(query, values)
	if err != nil {
		mc.finish()
		return nil, err
	}
	rows.finish = mc.finish
	return rows, nil
}
// ExecContext implements driver.ExecerContext by converting the named
// arguments and delegating to Exec while ctx is watched for cancellation.
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	values, err := namedValueToValue(args)
	if err != nil {
		return nil, err
	}

	if err = mc.watchCancel(ctx); err != nil {
		return nil, err
	}
	defer mc.finish()

	return mc.Exec(query, values)
}
// PrepareContext implements driver.ConnPrepareContext.
//
// The prepare runs while ctx is watched. If ctx turns out to be done once the
// prepare has succeeded, the new statement is closed and ctx.Err() is
// returned.
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	if err := mc.watchCancel(ctx); err != nil {
		return nil, err
	}

	stmt, err := mc.Prepare(query)
	mc.finish()
	if err != nil {
		return nil, err
	}

	select {
	case <-ctx.Done():
		// Best effort: the context error is what the caller needs to see.
		_ = stmt.Close()
		return nil, ctx.Err()
	default:
		return stmt, nil
	}
}
// QueryContext implements driver.StmtQueryContext for a prepared statement.
//
// ctx stays watched on the statement's connection until the returned rows
// call finish. On failure, finish is called here because no rows exist to do
// it.
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	values, err := namedValueToValue(args)
	if err != nil {
		return nil, err
	}

	if err = stmt.mc.watchCancel(ctx); err != nil {
		return nil, err
	}

	rows, err := stmt.query(values)
	if err != nil {
		stmt.mc.finish()
		return nil, err
	}
	rows.finish = stmt.mc.finish
	return rows, nil
}
// ExecContext implements driver.StmtExecContext by converting the named
// arguments and delegating to Exec while ctx is watched on the statement's
// connection.
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	values, err := namedValueToValue(args)
	if err != nil {
		return nil, err
	}

	if err = stmt.mc.watchCancel(ctx); err != nil {
		return nil, err
	}
	defer stmt.mc.finish()

	return stmt.Exec(values)
}
// watchCancel hands ctx to the watcher goroutine for the duration of the next
// operation. That operation must call finish when it completes.
//
// Cases handled:
//   - A previous watch is still active: the connection is torn down via
//     cleanup and nil is returned. NOTE(review): presumably this only happens
//     after a prior operation was canceled mid-flight — confirm.
//   - ctx is already done: its error is returned and nothing is watched.
//   - ctx cannot be canceled (Done() == nil) or no watcher goroutine exists:
//     there is nothing to watch.
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
	if mc.watching {
		mc.cleanup()
		return nil
	}

	if err := ctx.Err(); err != nil {
		return err
	}

	if ctx.Done() == nil || mc.watcher == nil {
		return nil
	}

	mc.watching = true
	mc.watcher <- ctx
	return nil
}
// startWatcher creates the watcher/finished channels and launches the
// goroutine behind watchCancel and finish.
//
// The goroutine repeatedly receives a context from mc.watcher, then waits for
// one of three events:
//   - the context is done, in which case it calls cancel with ctx.Err();
//   - finish signals on mc.finished;
//   - closech is closed.
//
// It exits once closech is closed. The watcher channel has capacity 1
// (NOTE(review): presumably so watchCancel's send never blocks — confirm).
func (mc *mysqlConn) startWatcher() {
	ctxCh := make(chan context.Context, 1)
	doneCh := make(chan struct{})
	mc.watcher = ctxCh
	mc.finished = doneCh

	go func() {
		for {
			var ctx context.Context
			select {
			case <-mc.closech:
				return
			case ctx = <-ctxCh:
			}

			select {
			case <-mc.closech:
				return
			case <-doneCh:
			case <-ctx.Done():
				mc.cancel(ctx.Err())
			}
		}
	}()
}
// CheckNamedValue implements driver.NamedValueChecker. It replaces nv.Value
// with the driver's converted form and reports any conversion error. nv.Value
// is overwritten even when an error is returned.
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) error {
	converted, err := converter{}.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
// ResetSession implements driver.SessionResetter.
//
// A closed connection is reported as driver.ErrBadConn. Otherwise ResetSession
// only sets mc.reset. NOTE(review): the actual session reset presumably
// happens later, in code not shown here — confirm.
func (mc *mysqlConn) ResetSession(_ context.Context) error {
	if closed := mc.closed.Load(); closed {
		return driver.ErrBadConn
	}
	mc.reset = true
	return nil
}
// IsValid implements driver.Validator. It reports whether the connection has
// not yet been closed.
func (mc *mysqlConn) IsValid() bool {
	closed := mc.closed.Load()
	return !closed
}
// These pages were generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.