// Package stdlib is the compatibility layer from pgx to database/sql.
//
// A database/sql connection can be established through sql.Open.
//
//	db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
//	if err != nil {
//	  return err
//	}
//
// Or from a DSN string.
//
//	db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
//	if err != nil {
//	  return err
//	}
//
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
// with sql.Open.
//
//	connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
//	connConfig.Tracer = myTracer
//	connStr := stdlib.RegisterConnConfig(connConfig)
//	db, _ := sql.Open("pgx", connStr)
//
// pgx uses standard PostgreSQL positional parameters in queries, e.g. $1, $2. It does not support named parameters.
//
//	db.QueryRow("select * from users where id=$1", userID)
//
// (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows
// operations that use pgx specific functionality.
//
//	// Given db is a *sql.DB
//	conn, err := db.Conn(context.Background())
//	if err != nil {
//	  // handle error from acquiring connection from DB pool
//	}
//
//	err = conn.Raw(func(driverConn any) error {
//	  conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
//	  // Do pgx specific stuff with conn
//	  conn.CopyFrom(...)
//	  return nil
//	})
//	if err != nil {
//	  // handle error that occurred while using *pgx.Conn
//	}
//
// # PostgreSQL Specific Data Types
//
// The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes
// these types usable as a sql.Scanner.
//
//	m := pgtype.NewMap()
//	var a []int64
//	err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
package stdlib

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"math"
	"math/rand"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
)

// Only intrinsic types should be binary format with database/sql.
var databaseSQLResultFormats pgx.QueryResultFormatsByOID

// pgxDriver is the single Driver instance registered with database/sql by init and shared by every connector
// created in this package.
var pgxDriver *Driver

// init creates the shared driver, registers it with database/sql and fills in the per-OID result format table that
// QueryContext passes along with every query.
func init() {
	pgxDriver = &Driver{
		configs: make(map[string]*pgx.ConnConfig),
	}

	// if pgx driver was already registered by different pgx major version then we
	// skip registration under the default name.
	if !contains(sql.Drivers(), "pgx") {
		sql.Register("pgx", pgxDriver)
	}
	sql.Register("pgx/v5", pgxDriver)

	// A format code of 1 requests the binary format for the listed types.
	databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
		pgtype.BoolOID:        1,
		pgtype.ByteaOID:       1,
		pgtype.CIDOID:         1,
		pgtype.DateOID:        1,
		pgtype.Float4OID:      1,
		pgtype.Float8OID:      1,
		pgtype.Int2OID:        1,
		pgtype.Int4OID:        1,
		pgtype.Int8OID:        1,
		pgtype.OIDOID:         1,
		pgtype.TimestampOID:   1,
		pgtype.TimestamptzOID: 1,
		pgtype.XIDOID:         1,
	}
}

// contains reports whether y is an element of list.
//
// TODO replace by slices.Contains when experimental package will be merged to stdlib
// https://pkg.go.dev/golang.org/x/exp/slices#Contains
func contains(list []string, y string) bool {
	for _, x := range list {
		if x == y {
			return true
		}
	}
	return false
}

// OptionOpenDB options for configuring the driver when opening a new db pool.
type OptionOpenDB func(*connector)

// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
// be used to connect, so only its immediate members should be modified.
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
	return func(dc *connector) {
		dc.BeforeConnect = bc
	}
}

// OptionAfterConnect provides a callback for after connect.
func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
	return func(dc *connector) {
		dc.AfterConnect = ac
	}
}

// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
// connection if the connection has been used before.
// If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
func ( func(context.Context, *pgx.Conn) error) OptionOpenDB { return func( *connector) { .ResetSession = } } // RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a // new host becomes primary each time. This is useful to distribute connections for multi-master databases like // CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well // to ensure that connections are periodically rebalanced across your nodes. func ( context.Context, *pgx.ConnConfig) error { if len(.Fallbacks) == 0 { return nil } := append([]*pgconn.FallbackConfig{{ Host: .Host, Port: .Port, TLSConfig: .TLSConfig, }}, .Fallbacks...) rand.Shuffle(len(), func(, int) { [], [] = [], [] }) // Use the one that sorted last as the primary and keep the rest as the fallbacks := [len()-1] .Host = .Host .Port = .Port .TLSConfig = .TLSConfig .Fallbacks = [:len()-1] return nil } func ( pgx.ConnConfig, ...OptionOpenDB) driver.Connector { := connector{ ConnConfig: , BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default driver: pgxDriver, } for , := range { (&) } return } func ( pgx.ConnConfig, ...OptionOpenDB) *sql.DB { := GetConnector(, ...) 
return sql.OpenDB() } type connector struct { pgx.ConnConfig BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused driver *Driver } // Connect implement driver.Connector interface func ( connector) ( context.Context) (driver.Conn, error) { var ( error *pgx.Conn ) // Create a shallow copy of the config, so that BeforeConnect can safely modify it := .ConnConfig if = .BeforeConnect(, &); != nil { return nil, } if , = pgx.ConnectConfig(, &); != nil { return nil, } if = .AfterConnect(, ); != nil { return nil, } return &Conn{conn: , driver: .driver, connConfig: , resetSessionFunc: .ResetSession}, nil } // Driver implement driver.Connector interface func ( connector) () driver.Driver { return .driver } // GetDefaultDriver returns the driver initialized in the init function // and used when the pgx driver is registered. 
func () driver.Driver { return pgxDriver } type Driver struct { configMutex sync.Mutex configs map[string]*pgx.ConnConfig sequence int } func ( *Driver) ( string) (driver.Conn, error) { , := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout defer () , := .OpenConnector() if != nil { return nil, } return .Connect() } func ( *Driver) ( string) (driver.Connector, error) { return &driverConnector{driver: , name: }, nil } func ( *Driver) ( *pgx.ConnConfig) string { .configMutex.Lock() := fmt.Sprintf("registeredConnConfig%d", .sequence) .sequence++ .configs[] = .configMutex.Unlock() return } func ( *Driver) ( string) { .configMutex.Lock() delete(.configs, ) .configMutex.Unlock() } type driverConnector struct { driver *Driver name string } func ( *driverConnector) ( context.Context) (driver.Conn, error) { var *pgx.ConnConfig .driver.configMutex.Lock() = .driver.configs[.name] .driver.configMutex.Unlock() if == nil { var error , = pgx.ParseConfig(.name) if != nil { return nil, } } , := pgx.ConnectConfig(, ) if != nil { return nil, } := &Conn{ conn: , driver: .driver, connConfig: *, resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, } return , nil } func ( *driverConnector) () driver.Driver { return .driver } // RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open. func ( *pgx.ConnConfig) string { return pgxDriver.registerConnConfig() } // UnregisterConnConfig removes the ConnConfig registration for connStr. 
func ( string) { pgxDriver.unregisterConnConfig() } type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names driver *Driver connConfig pgx.ConnConfig resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused lastResetSessionTime time.Time } // Conn returns the underlying *pgx.Conn func ( *Conn) () *pgx.Conn { return .conn } func ( *Conn) ( string) (driver.Stmt, error) { return .PrepareContext(context.Background(), ) } func ( *Conn) ( context.Context, string) (driver.Stmt, error) { if .conn.IsClosed() { return nil, driver.ErrBadConn } := fmt.Sprintf("pgx_%d", .psCount) .psCount++ , := .conn.Prepare(, , ) if != nil { return nil, } return &Stmt{sd: , conn: }, nil } func ( *Conn) () error { , := context.WithTimeout(context.Background(), time.Second*5) defer () return .conn.Close() } func ( *Conn) () (driver.Tx, error) { return .BeginTx(context.Background(), driver.TxOptions{}) } func ( *Conn) ( context.Context, driver.TxOptions) (driver.Tx, error) { if .conn.IsClosed() { return nil, driver.ErrBadConn } var pgx.TxOptions switch sql.IsolationLevel(.Isolation) { case sql.LevelDefault: case sql.LevelReadUncommitted: .IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: .IsoLevel = pgx.ReadCommitted case sql.LevelRepeatableRead, sql.LevelSnapshot: .IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: .IsoLevel = pgx.Serializable default: return nil, fmt.Errorf("unsupported isolation: %v", .Isolation) } if .ReadOnly { .AccessMode = pgx.ReadOnly } , := .conn.BeginTx(, ) if != nil { return nil, } return wrapTx{ctx: , tx: }, nil } func ( *Conn) ( context.Context, string, []driver.NamedValue) (driver.Result, error) { if .conn.IsClosed() { return nil, driver.ErrBadConn } := namedValueToInterface() , := .conn.Exec(, , ...) 
// if we got a network error before we had a chance to send the query, retry if != nil { if pgconn.SafeToRetry() { return nil, driver.ErrBadConn } } return driver.RowsAffected(.RowsAffected()), } func ( *Conn) ( context.Context, string, []driver.NamedValue) (driver.Rows, error) { if .conn.IsClosed() { return nil, driver.ErrBadConn } := []any{databaseSQLResultFormats} = append(, namedValueToInterface()...) , := .conn.Query(, , ...) if != nil { if pgconn.SafeToRetry() { return nil, driver.ErrBadConn } return nil, } // Preload first row because otherwise we won't know what columns are available when database/sql asks. := .Next() if = .Err(); != nil { .Close() return nil, } return &Rows{conn: , rows: , skipNext: true, skipNextMore: }, nil } func ( *Conn) ( context.Context) error { if .conn.IsClosed() { return driver.ErrBadConn } := .conn.Ping() if != nil { // A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the // failure, but manually close it just to be sure. .Close() return driver.ErrBadConn } return nil } func ( *Conn) (*driver.NamedValue) error { // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. 
return nil } func ( *Conn) ( context.Context) error { if .conn.IsClosed() { return driver.ErrBadConn } := time.Now() if .Sub(.lastResetSessionTime) > time.Second { if := .conn.PgConn().CheckConn(); != nil { return driver.ErrBadConn } } .lastResetSessionTime = return .resetSessionFunc(, .conn) } type Stmt struct { sd *pgconn.StatementDescription conn *Conn } func ( *Stmt) () error { , := context.WithTimeout(context.Background(), time.Second*5) defer () return .conn.conn.Deallocate(, .sd.Name) } func ( *Stmt) () int { return len(.sd.ParamOIDs) } func ( *Stmt) ( []driver.Value) (driver.Result, error) { return nil, errors.New("Stmt.Exec deprecated and not implemented") } func ( *Stmt) ( context.Context, []driver.NamedValue) (driver.Result, error) { return .conn.ExecContext(, .sd.Name, ) } func ( *Stmt) ( []driver.Value) (driver.Rows, error) { return nil, errors.New("Stmt.Query deprecated and not implemented") } func ( *Stmt) ( context.Context, []driver.NamedValue) (driver.Rows, error) { return .conn.QueryContext(, .sd.Name, ) } type rowValueFunc func(src []byte) (driver.Value, error) type Rows struct { conn *Conn rows pgx.Rows valueFuncs []rowValueFunc skipNext bool skipNextMore bool columnNames []string } func ( *Rows) () []string { if .columnNames == nil { := .rows.FieldDescriptions() .columnNames = make([]string, len()) for , := range { .columnNames[] = string(.Name) } } return .columnNames } // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func ( *Rows) ( int) string { if , := .conn.conn.TypeMap().TypeForOID(.rows.FieldDescriptions()[].DataTypeOID); { return strings.ToUpper(.Name) } return strconv.FormatInt(int64(.rows.FieldDescriptions()[].DataTypeOID), 10) } const varHeaderSize = 4 // ColumnTypeLength returns the length of the column type if the column is a // variable length type. If the column is not a variable length type ok // should return false. 
func ( *Rows) ( int) (int64, bool) { := .rows.FieldDescriptions()[] switch .DataTypeOID { case pgtype.TextOID, pgtype.ByteaOID: return math.MaxInt64, true case pgtype.VarcharOID, pgtype.BPCharArrayOID: return int64(.TypeModifier - varHeaderSize), true default: return 0, false } } // ColumnTypePrecisionScale should return the precision and scale for decimal // types. If not applicable, ok should be false. func ( *Rows) ( int) (, int64, bool) { := .rows.FieldDescriptions()[] switch .DataTypeOID { case pgtype.NumericOID: := .TypeModifier - varHeaderSize = int64(( >> 16) & 0xffff) = int64( & 0xffff) return , , true default: return 0, 0, false } } // ColumnTypeScanType returns the value type that can be used to scan types into. func ( *Rows) ( int) reflect.Type { := .rows.FieldDescriptions()[] switch .DataTypeOID { case pgtype.Float8OID: return reflect.TypeOf(float64(0)) case pgtype.Float4OID: return reflect.TypeOf(float32(0)) case pgtype.Int8OID: return reflect.TypeOf(int64(0)) case pgtype.Int4OID: return reflect.TypeOf(int32(0)) case pgtype.Int2OID: return reflect.TypeOf(int16(0)) case pgtype.BoolOID: return reflect.TypeOf(false) case pgtype.NumericOID: return reflect.TypeOf(float64(0)) case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: return reflect.TypeOf(time.Time{}) case pgtype.ByteaOID: return reflect.TypeOf([]byte(nil)) default: return reflect.TypeOf("") } } func ( *Rows) () error { .rows.Close() return .rows.Err() } func ( *Rows) ( []driver.Value) error { := .conn.conn.TypeMap() := .rows.FieldDescriptions() if .valueFuncs == nil { .valueFuncs = make([]rowValueFunc, len()) for , := range { := .DataTypeOID := .Format switch .DataTypeOID { case pgtype.BoolOID: var bool := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return , } case pgtype.ByteaOID: var []byte := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return , } case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: var 
pgtype.Uint32 := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) if != nil { return nil, } return .Value() } case pgtype.DateOID: var pgtype.Date := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) if != nil { return nil, } return .Value() } case pgtype.Float4OID: var float32 := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return float64(), } case pgtype.Float8OID: var float64 := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return , } case pgtype.Int2OID: var int16 := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return int64(), } case pgtype.Int4OID: var int32 := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return int64(), } case pgtype.Int8OID: var int64 := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return , } case pgtype.JSONOID, pgtype.JSONBOID: var []byte := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) if != nil { return nil, } return , nil } case pgtype.TimestampOID: var pgtype.Timestamp := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) if != nil { return nil, } return .Value() } case pgtype.TimestamptzOID: var pgtype.Timestamptz := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) if != nil { return nil, } return .Value() } default: var string := .PlanScan(, , &) .valueFuncs[] = func( []byte) (driver.Value, error) { := .Scan(, &) return , } } } } var bool if .skipNext { = .skipNextMore .skipNext = false } else { = .rows.Next() } if ! 
{ if .rows.Err() == nil { return io.EOF } else { return .rows.Err() } } for , := range .rows.RawValues() { if != nil { var error [], = .valueFuncs[]() if != nil { return fmt.Errorf("convert field %d failed: %v", , ) } } else { [] = nil } } return nil } func valueToInterface( []driver.Value) []any { := make([]any, 0, len()) for , := range { if != nil { = append(, .(any)) } else { = append(, nil) } } return } func namedValueToInterface( []driver.NamedValue) []any { := make([]any, 0, len()) for , := range { if .Value != nil { = append(, .Value.(any)) } else { = append(, nil) } } return } type wrapTx struct { ctx context.Context tx pgx.Tx } func ( wrapTx) () error { return .tx.Commit(.ctx) } func ( wrapTx) () error { return .tx.Rollback(.ctx) }