package stdlib
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"math"
"math/rand"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
// databaseSQLResultFormats lists the column type OIDs whose results are
// requested with format code 1 (binary, per pgx.QueryResultFormatsByOID) when
// queries run through database/sql. Conn.QueryContext passes it as the first
// query argument.
var databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
	pgtype.BoolOID:        1,
	pgtype.ByteaOID:       1,
	pgtype.CIDOID:         1,
	pgtype.DateOID:        1,
	pgtype.Float4OID:      1,
	pgtype.Float8OID:      1,
	pgtype.Int2OID:        1,
	pgtype.Int4OID:        1,
	pgtype.Int8OID:        1,
	pgtype.OIDOID:         1,
	pgtype.TimestampOID:   1,
	pgtype.TimestamptzOID: 1,
	pgtype.XIDOID:         1,
}

// pgxDriver is the single Driver instance this package registers with
// database/sql and hands out from GetConnector / GetDefaultDriver.
var pgxDriver = &Driver{
	configs: make(map[string]*pgx.ConnConfig),
}

// init registers pgxDriver under "pgx/v5". It also registers it under the
// plain name "pgx" unless some other driver has already claimed that name.
func init() {
	if !contains(sql.Drivers(), "pgx") {
		sql.Register("pgx", pgxDriver)
	}
	sql.Register("pgx/v5", pgxDriver)
}
// contains reports whether y appears in list.
func contains(list []string, y string) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i] == y
	}
	return found
}
// OptionOpenDB customizes the connector built by GetConnector and OpenDB.
type OptionOpenDB func(*connector)

// OptionBeforeConnect installs a hook that connector.Connect calls with a
// per-connection copy of the ConnConfig before dialing. The hook may modify
// that copy; a non-nil error aborts the connection attempt.
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
	return func(dc *connector) { dc.BeforeConnect = bc }
}

// OptionAfterConnect installs a hook that connector.Connect calls with each
// newly established *pgx.Conn. A non-nil error aborts the connection attempt.
func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
	return func(dc *connector) { dc.AfterConnect = ac }
}

// OptionResetSession installs a hook that Conn.ResetSession runs (after its
// own liveness handling) every time database/sql resets the session.
func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB {
	return func(dc *connector) { dc.ResetSession = rs }
}
// RandomizeHostOrderFunc is a BeforeConnect hook. It randomly permutes the
// primary host together with all fallback hosts, then writes the result back
// into connConfig: the last shuffled entry becomes the primary (Host, Port,
// TLSConfig) and the rest become Fallbacks. It does nothing when there are no
// fallbacks, and it always returns nil.
func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error {
	if len(connConfig.Fallbacks) == 0 {
		return nil
	}

	primary := &pgconn.FallbackConfig{
		Host:      connConfig.Host,
		Port:      connConfig.Port,
		TLSConfig: connConfig.TLSConfig,
	}
	hosts := make([]*pgconn.FallbackConfig, 0, len(connConfig.Fallbacks)+1)
	hosts = append(hosts, primary)
	hosts = append(hosts, connConfig.Fallbacks...)

	rand.Shuffle(len(hosts), func(i, j int) {
		hosts[i], hosts[j] = hosts[j], hosts[i]
	})

	last := len(hosts) - 1
	chosen := hosts[last]
	connConfig.Host = chosen.Host
	connConfig.Port = chosen.Port
	connConfig.TLSConfig = chosen.TLSConfig
	connConfig.Fallbacks = hosts[:last]
	return nil
}
// GetConnector returns a driver.Connector that opens connections with config.
// BeforeConnect, AfterConnect and ResetSession default to no-op hooks and can
// be replaced by opts, which are applied in order. The connector reports this
// package's shared Driver.
func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector {
	noopConfigHook := func(context.Context, *pgx.ConnConfig) error { return nil }
	noopConnHook := func(context.Context, *pgx.Conn) error { return nil }

	c := connector{
		ConnConfig:    config,
		BeforeConnect: noopConfigHook,
		AfterConnect:  noopConnHook,
		ResetSession:  noopConnHook,
		driver:        pgxDriver,
	}
	for i := range opts {
		opts[i](&c)
	}
	return c
}

// OpenDB returns a *sql.DB whose connections come from
// GetConnector(config, opts...).
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
	return sql.OpenDB(GetConnector(config, opts...))
}
// connector implements driver.Connector for a fixed pgx.ConnConfig plus
// optional per-connection hooks. Instances are built by GetConnector.
type connector struct {
	pgx.ConnConfig
	// BeforeConnect may modify a per-connection copy of ConnConfig before dialing.
	BeforeConnect func(context.Context, *pgx.ConnConfig) error
	// AfterConnect runs against every newly established connection.
	AfterConnect func(context.Context, *pgx.Conn) error
	// ResetSession is handed to each Conn and run from Conn.ResetSession.
	ResetSession func(context.Context, *pgx.Conn) error
	driver       *Driver
}

// Connect implements driver.Connector.
//
// BeforeConnect receives a (shallow) copy of the connector's config, so field
// assignments it makes do not affect later connections. If AfterConnect
// fails, the freshly opened connection is closed before the error is
// returned; otherwise it would be leaked.
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
	connConfig := c.ConnConfig
	if err := c.BeforeConnect(ctx, &connConfig); err != nil {
		return nil, err
	}

	conn, err := pgx.ConnectConfig(ctx, &connConfig)
	if err != nil {
		return nil, err
	}

	if err := c.AfterConnect(ctx, conn); err != nil {
		// Best-effort cleanup: the AfterConnect error is the one worth
		// reporting, so the Close error is intentionally dropped. Per the
		// pgconn docs, Close always closes the underlying net.Conn even if
		// ctx is already done.
		_ = conn.Close(ctx)
		return nil, err
	}

	return &Conn{
		conn:             conn,
		driver:           c.driver,
		connConfig:       connConfig,
		resetSessionFunc: c.ResetSession,
	}, nil
}

// Driver implements driver.Connector and returns the Driver this connector
// was created with.
func (c connector) Driver() driver.Driver {
	return c.driver
}
// GetDefaultDriver returns the Driver instance this package registers with
// database/sql.
func GetDefaultDriver() driver.Driver {
	return pgxDriver
}

// Driver is the database/sql driver. A connection name is either a key
// produced by RegisterConnConfig or a connection string understood by
// pgx.ParseConfig.
type Driver struct {
	configMutex sync.Mutex                 // guards configs and sequence
	configs     map[string]*pgx.ConnConfig // registered name -> config
	sequence    int                        // suffix for the next registered name
}

// Open implements driver.Driver. Because this entry point has no context,
// the connection attempt is bounded by a 60 second timeout.
func (d *Driver) Open(name string) (driver.Conn, error) {
	c, err := d.OpenConnector(name)
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	return c.Connect(ctx)
}

// OpenConnector implements driver.DriverContext. It never fails; name is
// only resolved when the returned connector connects.
func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
	return &driverConnector{driver: d, name: name}, nil
}

// registerConnConfig stores c under a freshly generated name of the form
// "registeredConnConfigN" and returns that name.
func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
	d.configMutex.Lock()
	defer d.configMutex.Unlock()

	connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence)
	d.sequence++
	d.configs[connStr] = c
	return connStr
}

// unregisterConnConfig removes the config registered under connStr, if any.
func (d *Driver) unregisterConnConfig(connStr string) {
	d.configMutex.Lock()
	defer d.configMutex.Unlock()
	delete(d.configs, connStr)
}
// driverConnector is the driver.Connector returned by Driver.OpenConnector.
type driverConnector struct {
	driver *Driver
	name   string // registered config key or connection string
}

// Connect implements driver.Connector. A non-nil config registered under
// dc.name wins. Otherwise dc.name is parsed as a connection string. The
// resulting Conn uses a no-op session reset hook.
func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
	dc.driver.configMutex.Lock()
	connConfig := dc.driver.configs[dc.name]
	dc.driver.configMutex.Unlock()

	if connConfig == nil {
		parsed, err := pgx.ParseConfig(dc.name)
		if err != nil {
			return nil, err
		}
		connConfig = parsed
	}

	conn, err := pgx.ConnectConfig(ctx, connConfig)
	if err != nil {
		return nil, err
	}

	return &Conn{
		conn:             conn,
		driver:           dc.driver,
		connConfig:       *connConfig,
		resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
	}, nil
}

// Driver implements driver.Connector.
func (dc *driverConnector) Driver() driver.Driver {
	return dc.driver
}

// RegisterConnConfig stores c in the package's Driver and returns a generated
// name. Passing that name to sql.Open with this driver makes it connect with
// c instead of parsing the name as a connection string. The config stays
// registered until UnregisterConnConfig is called with the same name.
func RegisterConnConfig(c *pgx.ConnConfig) string {
	return pgxDriver.registerConnConfig(c)
}

// UnregisterConnConfig removes a config previously added by
// RegisterConnConfig.
func UnregisterConnConfig(connStr string) {
	pgxDriver.unregisterConnConfig(connStr)
}
// Conn is the driver.Conn given to database/sql. Each Conn wraps exactly
// one *pgx.Conn.
type Conn struct {
	conn                 *pgx.Conn
	psCount              int64 // next suffix for "pgx_N" prepared statement names
	driver               *Driver
	connConfig           pgx.ConnConfig // config the connection was opened with
	resetSessionFunc     func(context.Context, *pgx.Conn) error
	lastResetSessionTime time.Time // when ResetSession last ran
}

// Conn returns the underlying *pgx.Conn.
func (c *Conn) Conn() *pgx.Conn {
	return c.conn
}

// Prepare implements driver.Conn using a background context.
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
	return c.PrepareContext(context.Background(), query)
}

// PrepareContext implements driver.ConnPrepareContext. It creates a
// server-side prepared statement named "pgx_N", where N comes from a
// per-connection counter.
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	if c.conn.IsClosed() {
		return nil, driver.ErrBadConn
	}

	// The counter advances before preparing, so it moves forward even when
	// Prepare fails.
	seq := c.psCount
	c.psCount++

	sd, err := c.conn.Prepare(ctx, fmt.Sprintf("pgx_%d", seq), query)
	if err != nil {
		return nil, err
	}
	return &Stmt{sd: sd, conn: c}, nil
}
// Close implements driver.Conn by closing the underlying pgx connection.
// The close is bounded by a 5 second timeout.
func (c *Conn) Close() error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return c.conn.Close(ctx)
}

// Begin implements driver.Conn. It starts a transaction with default
// options and a background context.
func (c *Conn) Begin() (driver.Tx, error) {
	var opts driver.TxOptions
	return c.BeginTx(context.Background(), opts)
}
// BeginTx implements driver.ConnBeginTx. database/sql isolation levels are
// translated to pgx levels, and LevelSnapshot is treated as REPEATABLE READ.
// Any other level is rejected. ReadOnly maps to pgx.ReadOnly access mode.
// The returned transaction commits or rolls back using ctx.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	if c.conn.IsClosed() {
		return nil, driver.ErrBadConn
	}

	txOptions := pgx.TxOptions{}
	if opts.ReadOnly {
		txOptions.AccessMode = pgx.ReadOnly
	}

	switch sql.IsolationLevel(opts.Isolation) {
	case sql.LevelDefault:
		// Leave IsoLevel at its zero value.
	case sql.LevelReadUncommitted:
		txOptions.IsoLevel = pgx.ReadUncommitted
	case sql.LevelReadCommitted:
		txOptions.IsoLevel = pgx.ReadCommitted
	case sql.LevelRepeatableRead, sql.LevelSnapshot:
		txOptions.IsoLevel = pgx.RepeatableRead
	case sql.LevelSerializable:
		txOptions.IsoLevel = pgx.Serializable
	default:
		return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
	}

	pgxTx, err := c.conn.BeginTx(ctx, txOptions)
	if err != nil {
		return nil, err
	}
	return wrapTx{ctx: ctx, tx: pgxTx}, nil
}
// ExecContext implements driver.ExecerContext.
//
// Errors for which pgconn.SafeToRetry reports true are mapped to
// driver.ErrBadConn, which database/sql treats as a signal to retry on
// another connection. On any error the Result is nil rather than a
// partially meaningful RowsAffected value.
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
	if c.conn.IsClosed() {
		return nil, driver.ErrBadConn
	}

	args := namedValueToInterface(argsV)
	commandTag, err := c.conn.Exec(ctx, query, args...)
	if err != nil {
		if pgconn.SafeToRetry(err) {
			return nil, driver.ErrBadConn
		}
		return nil, err
	}
	return driver.RowsAffected(commandTag.RowsAffected()), nil
}
// QueryContext implements driver.QueryerContext.
//
// Results for the OIDs listed in databaseSQLResultFormats are requested in
// binary format. The first row is fetched eagerly so that an error raised
// while starting the query is returned here instead of from a later
// Rows.Next. The returned Rows replays that pre-fetched row first.
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
	if c.conn.IsClosed() {
		return nil, driver.ErrBadConn
	}

	args := append([]any{databaseSQLResultFormats}, namedValueToInterface(argsV)...)
	rows, err := c.conn.Query(ctx, query, args...)
	switch {
	case err != nil && pgconn.SafeToRetry(err):
		return nil, driver.ErrBadConn
	case err != nil:
		return nil, err
	}

	hasFirst := rows.Next()
	if err := rows.Err(); err != nil {
		rows.Close()
		return nil, err
	}
	return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: hasFirst}, nil
}
// Ping implements driver.Pinger. A closed connection, or one whose ping
// fails, is reported as driver.ErrBadConn. A connection whose ping fails is
// also closed.
func (c *Conn) Ping(ctx context.Context) error {
	if c.conn.IsClosed() {
		return driver.ErrBadConn
	}
	if err := c.conn.Ping(ctx); err != nil {
		// The connection is being discarded; a Close error adds nothing.
		_ = c.Close()
		return driver.ErrBadConn
	}
	return nil
}

// CheckNamedValue implements driver.NamedValueChecker. It accepts every
// argument unchanged; the values are later passed straight to pgx.
func (c *Conn) CheckNamedValue(*driver.NamedValue) error { return nil }
// ResetSession implements driver.SessionResetter.
//
// The connection is only probed with CheckConn when more than one second has
// passed since the previous ResetSession call. A failed probe reports
// driver.ErrBadConn without updating the timestamp. Otherwise the timestamp
// is refreshed and the configured reset hook runs.
func (c *Conn) ResetSession(ctx context.Context) error {
	if c.conn.IsClosed() {
		return driver.ErrBadConn
	}

	now := time.Now()
	if now.Sub(c.lastResetSessionTime) > time.Second && c.conn.PgConn().CheckConn() != nil {
		return driver.ErrBadConn
	}
	c.lastResetSessionTime = now

	return c.resetSessionFunc(ctx, c.conn)
}
// Stmt is a prepared statement created by Conn.PrepareContext.
type Stmt struct {
	sd   *pgconn.StatementDescription
	conn *Conn
}

// Close deallocates the prepared statement. The call is bounded by a
// 5 second timeout.
func (s *Stmt) Close() error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return s.conn.conn.Deallocate(ctx, s.sd.Name)
}

// NumInput reports how many parameters the prepared statement takes.
func (s *Stmt) NumInput() int {
	params := s.sd.ParamOIDs
	return len(params)
}

// Exec is the legacy driver.Stmt method and is not implemented;
// ExecContext provides the implementation.
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
	return nil, errors.New("Stmt.Exec deprecated and not implemented")
}

// ExecContext runs the statement through Conn.ExecContext, passing the
// statement's name as the query text.
// NOTE(review): relies on pgx resolving a query string equal to a prepared
// statement's name to that statement — see pgx.Conn.Exec docs.
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
	name := s.sd.Name
	return s.conn.ExecContext(ctx, name, argsV)
}

// Query is the legacy driver.Stmt method and is not implemented;
// QueryContext provides the implementation.
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
	return nil, errors.New("Stmt.Query deprecated and not implemented")
}

// QueryContext runs the statement through Conn.QueryContext, passing the
// statement's name as the query text.
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
	name := s.sd.Name
	return s.conn.QueryContext(ctx, name, argsV)
}
// rowValueFunc converts one column's raw (non-NULL) bytes into a driver.Value.
type rowValueFunc func(src []byte) (driver.Value, error)

// Rows implements driver.Rows on top of pgx.Rows.
type Rows struct {
	conn *Conn
	rows pgx.Rows
	// valueFuncs holds one converter per result column; built on first Next.
	valueFuncs []rowValueFunc
	// skipNext is true while the row pre-fetched by Conn.QueryContext has
	// not been consumed yet; skipNextMore is the result of that fetch.
	skipNext     bool
	skipNextMore bool
	// columnNames caches the result of Columns.
	columnNames []string
}

// Columns implements driver.Rows. The names are computed once and cached.
func (r *Rows) Columns() []string {
	if r.columnNames != nil {
		return r.columnNames
	}

	fds := r.rows.FieldDescriptions()
	names := make([]string, 0, len(fds))
	for _, fd := range fds {
		names = append(names, string(fd.Name))
	}
	r.columnNames = names
	return names
}

// ColumnTypeDatabaseTypeName implements driver.RowsColumnTypeDatabaseTypeName.
// It returns the upper-cased type name known to the connection's type map.
// If the type is unknown, it returns the decimal OID instead.
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
	oid := r.rows.FieldDescriptions()[index].DataTypeOID

	dt, ok := r.conn.conn.TypeMap().TypeForOID(oid)
	if !ok {
		return strconv.FormatInt(int64(oid), 10)
	}
	return strings.ToUpper(dt.Name)
}
// varHeaderSize is the number of bytes PostgreSQL adds to the type modifier
// of length-limited types. The modifier of varchar(n) / character(n) is
// n + varHeaderSize, and the numeric modifier is offset by the same amount.
const varHeaderSize = 4

// ColumnTypeLength implements driver.RowsColumnTypeLength.
//
//   - text and bytea are unbounded and report math.MaxInt64.
//   - varchar(n) and character(n) (bpchar) report n, decoded from the type
//     modifier. When no length was declared, the modifier is smaller than
//     varHeaderSize (PostgreSQL uses -1). Such columns are reported as
//     unbounded rather than with a negative length.
//   - every other type reports ok == false.
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
	fd := r.rows.FieldDescriptions()[index]

	switch fd.DataTypeOID {
	case pgtype.TextOID, pgtype.ByteaOID:
		return math.MaxInt64, true
	case pgtype.VarcharOID, pgtype.BPCharOID:
		if fd.TypeModifier < varHeaderSize {
			// No declared length: behaves like text.
			return math.MaxInt64, true
		}
		return int64(fd.TypeModifier - varHeaderSize), true
	default:
		return 0, false
	}
}
// ColumnTypePrecisionScale implements driver.RowsColumnTypePrecisionScale.
// Only numeric columns are supported. Their type modifier is decoded as
// ((precision << 16) | scale) + varHeaderSize.
//
// NOTE(review): an unconstrained numeric column has a type modifier of -1.
// This decodes that as precision 65535 / scale 65531 with ok == true.
// Confirm whether callers expect that or would prefer ok == false.
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	fd := r.rows.FieldDescriptions()[index]
	if fd.DataTypeOID != pgtype.NumericOID {
		return 0, 0, false
	}

	mod := fd.TypeModifier - varHeaderSize
	return int64((mod >> 16) & 0xffff), int64(mod & 0xffff), true
}
// ColumnTypeScanType implements driver.RowsColumnTypeScanType by mapping the
// column's type OID to a Go type. Unlisted types report string.
//
// NOTE(review): numeric is reported as float64, but Rows.Next returns
// numeric values as string (its default branch). json/jsonb are reported as
// string, but Rows.Next returns []byte. Confirm whether these mismatches are
// intended before relying on them.
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
	var zero any
	switch r.rows.FieldDescriptions()[index].DataTypeOID {
	case pgtype.Float8OID, pgtype.NumericOID:
		zero = float64(0)
	case pgtype.Float4OID:
		zero = float32(0)
	case pgtype.Int8OID:
		zero = int64(0)
	case pgtype.Int4OID:
		zero = int32(0)
	case pgtype.Int2OID:
		zero = int16(0)
	case pgtype.BoolOID:
		zero = false
	case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
		zero = time.Time{}
	case pgtype.ByteaOID:
		zero = []byte(nil)
	default:
		zero = ""
	}
	return reflect.TypeOf(zero)
}
// Close implements driver.Rows. It closes the underlying pgx rows and
// reports any error they recorded.
func (r *Rows) Close() error {
	r.rows.Close()
	err := r.rows.Err()
	return err
}
// Next implements driver.Rows. It fills dest with one driver.Value per
// column (nil for SQL NULL). When the result set is exhausted it returns
// io.EOF, or the underlying rows error if there was one. If
// Conn.QueryContext already fetched the first row, that fetch result is
// consumed before the underlying rows are advanced again.
//
// Column converters are built once, on the first call. A conversion error
// wraps its cause with %w so callers can use errors.Is / errors.As.
func (r *Rows) Next(dest []driver.Value) error {
	if r.valueFuncs == nil {
		m := r.conn.conn.TypeMap()
		fieldDescriptions := r.rows.FieldDescriptions()
		r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))
		for i, fd := range fieldDescriptions {
			r.valueFuncs[i] = newRowValueFunc(m, fd)
		}
	}

	var more bool
	if r.skipNext {
		more = r.skipNextMore
		r.skipNext = false
	} else {
		more = r.rows.Next()
	}

	if !more {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return io.EOF
	}

	for i, rv := range r.rows.RawValues() {
		if rv == nil {
			dest[i] = nil
			continue
		}
		v, err := r.valueFuncs[i](rv)
		if err != nil {
			return fmt.Errorf("convert field %d failed: %w", i, err)
		}
		dest[i] = v
	}
	return nil
}

// newRowValueFunc builds the converter Rows.Next uses for one result column.
// Integer and float types are widened to int64 / float64. Date/time and
// OID-like types go through their pgtype Value method. Any type not listed
// here (numeric included) is scanned into a string.
func newRowValueFunc(m *pgtype.Map, fd pgconn.FieldDescription) rowValueFunc {
	oid, format := fd.DataTypeOID, fd.Format

	switch oid {
	case pgtype.BoolOID:
		return planRowValue(m, oid, format, func(d *bool) (driver.Value, error) { return *d, nil })
	case pgtype.ByteaOID, pgtype.JSONOID, pgtype.JSONBOID:
		return planRowValue(m, oid, format, func(d *[]byte) (driver.Value, error) { return *d, nil })
	case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID:
		return planRowValue(m, oid, format, func(d *pgtype.Uint32) (driver.Value, error) { return d.Value() })
	case pgtype.DateOID:
		return planRowValue(m, oid, format, func(d *pgtype.Date) (driver.Value, error) { return d.Value() })
	case pgtype.Float4OID:
		return planRowValue(m, oid, format, func(d *float32) (driver.Value, error) { return float64(*d), nil })
	case pgtype.Float8OID:
		return planRowValue(m, oid, format, func(d *float64) (driver.Value, error) { return *d, nil })
	case pgtype.Int2OID:
		return planRowValue(m, oid, format, func(d *int16) (driver.Value, error) { return int64(*d), nil })
	case pgtype.Int4OID:
		return planRowValue(m, oid, format, func(d *int32) (driver.Value, error) { return int64(*d), nil })
	case pgtype.Int8OID:
		return planRowValue(m, oid, format, func(d *int64) (driver.Value, error) { return *d, nil })
	case pgtype.TimestampOID:
		return planRowValue(m, oid, format, func(d *pgtype.Timestamp) (driver.Value, error) { return d.Value() })
	case pgtype.TimestamptzOID:
		return planRowValue(m, oid, format, func(d *pgtype.Timestamptz) (driver.Value, error) { return d.Value() })
	default:
		return planRowValue(m, oid, format, func(d *string) (driver.Value, error) { return *d, nil })
	}
}

// planRowValue plans a scan of type oid (in the given wire format) into a T
// once. It returns a rowValueFunc that scans each raw value into that T and
// converts it with toValue. The T variable is reused for every row of the
// column.
func planRowValue[T any](m *pgtype.Map, oid uint32, format int16, toValue func(*T) (driver.Value, error)) rowValueFunc {
	var d T
	plan := m.PlanScan(oid, format, &d)
	return func(src []byte) (driver.Value, error) {
		if err := plan.Scan(src, &d); err != nil {
			return nil, err
		}
		return toValue(&d)
	}
}
// valueToInterface copies driver values, in order, into a []any suitable
// for variadic query arguments. nil entries stay nil. The result is never
// nil, even for empty input.
func valueToInterface(argsV []driver.Value) []any {
	args := make([]any, len(argsV))
	for i, v := range argsV {
		args[i] = v
	}
	return args
}
// namedValueToInterface extracts the Value of each named argument, in slice
// order, into a []any suitable for variadic query arguments. Names and
// ordinals are ignored, and nil values stay nil. The result is never nil,
// even for empty input.
func namedValueToInterface(argsV []driver.NamedValue) []any {
	args := make([]any, len(argsV))
	for i := range argsV {
		args[i] = argsV[i].Value
	}
	return args
}
// wrapTx adapts a pgx.Tx to driver.Tx. Commit and Rollback run with the
// context that was passed to Conn.BeginTx.
type wrapTx struct {
	ctx context.Context
	tx  pgx.Tx
}

// Commit implements driver.Tx.
func (wtx wrapTx) Commit() error {
	return wtx.tx.Commit(wtx.ctx)
}

// Rollback implements driver.Tx.
func (wtx wrapTx) Rollback() error {
	return wtx.tx.Rollback(wtx.ctx)
}
The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.