package pgx
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/sanitize"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
// ConnConfig contains all the options used to establish a connection. It embeds
// pgconn.Config and adds pgx-level settings. It must be created by ParseConfig
// (or ParseConfigWithOptions), which sets createdByParseConfig; connect panics
// for a ConnConfig that was constructed any other way.
type ConnConfig struct {
	pgconn.Config

	// Tracer, when non-nil, becomes the connection's QueryTracer. connect also
	// type-asserts it to ConnectTracer, BatchTracer, CopyFromTracer and
	// PrepareTracer and uses whichever of those it implements.
	Tracer QueryTracer

	// connString is the original connection string given to ParseConfig.
	connString string

	// StatementCacheCapacity is the maximum size of the per-connection
	// prepared-statement LRU cache; a value <= 0 disables the cache.
	// ParseConfig defaults it to 512 (runtime param statement_cache_capacity).
	StatementCacheCapacity int

	// DescriptionCacheCapacity is the maximum size of the per-connection
	// statement-description LRU cache; a value <= 0 disables the cache.
	// ParseConfig defaults it to 512 (runtime param description_cache_capacity).
	DescriptionCacheCapacity int

	// DefaultQueryExecMode is used when a query does not pass an explicit
	// QueryExecMode argument. ParseConfig defaults it to
	// QueryExecModeCacheStatement (runtime param default_query_exec_mode).
	DefaultQueryExecMode QueryExecMode

	// createdByParseConfig enforces the "must be created by ParseConfig" rule.
	createdByParseConfig bool
}
// ParseConfigOptions carries options for ParseConfigWithOptions. The embedded
// pgconn.ParseConfigOptions is passed through unchanged to
// pgconn.ParseConfigWithOptions.
type ParseConfigOptions struct {
	pgconn.ParseConfigOptions
}
// Copy returns a copy of cc. The embedded pgconn.Config is duplicated with its
// own Copy method rather than shared, so modifying the result does not affect cc
// through that field.
func (cc *ConnConfig) Copy() *ConnConfig {
	dup := *cc
	dup.Config = *cc.Config.Copy()
	return &dup
}
// ConnString returns the connection string that was parsed to build this config.
func (cc *ConnConfig) ConnString() string {
	return cc.connString
}
// Conn is a PostgreSQL connection handle. It wraps a *pgconn.PgConn and adds
// type mapping, statement/description caching, tracing hooks and notification
// buffering. No locking is visible in this file; treat a Conn as not safe for
// concurrent use (NOTE(review): confirm against package docs).
type Conn struct {
	pgConn *pgconn.PgConn
	config *ConnConfig // config used when establishing this connection

	// preparedStatements holds explicitly prepared statements, keyed by name (see Prepare).
	preparedStatements map[string]*pgconn.StatementDescription
	statementCache     stmtcache.Cache // nil when StatementCacheCapacity <= 0
	descriptionCache   stmtcache.Cache // nil when DescriptionCacheCapacity <= 0

	queryTracer    QueryTracer
	batchTracer    BatchTracer    // queryTracer, if it also implements BatchTracer
	copyFromTracer CopyFromTracer // queryTracer, if it also implements CopyFromTracer
	prepareTracer  PrepareTracer  // queryTracer, if it also implements PrepareTracer

	// notifications is a FIFO filled by bufferNotifications and drained by WaitForNotification.
	notifications []*pgconn.Notification

	doneChan   chan struct{} // created in connect; not otherwise used in this file
	closedChan chan error    // created in connect; not otherwise used in this file

	typeMap *pgtype.Map

	wbuf []byte // allocated with capacity 1024 in connect; not otherwise used in this file

	// eqb is reused to build parameters for extended-protocol queries.
	eqb ExtendedQueryBuilder
}
// Identifier is a PostgreSQL identifier or name. Identifiers can be composed of
// multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize returns the identifier quoted for safe SQL interpolation. Each part
// has NUL bytes removed, embedded double quotes doubled, and is wrapped in
// double quotes; the quoted parts are joined with ".". An empty Identifier
// yields "".
func (ident Identifier) Sanitize() string {
	var b strings.Builder
	for idx, part := range ident {
		if idx > 0 {
			b.WriteByte('.')
		}
		cleaned := strings.ReplaceAll(part, "\x00", "")
		b.WriteByte('"')
		b.WriteString(strings.ReplaceAll(cleaned, `"`, `""`))
		b.WriteByte('"')
	}
	return b.String()
}
var (
	// ErrNoRows is the sentinel error for a result set that contains no rows.
	// Compare against it with errors.Is.
	ErrNoRows = errors.New("no rows in result set")

	// errDisabledStatementCache is returned when QueryExecModeCacheStatement is
	// requested but the connection has no statement cache.
	errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")

	// errDisabledDescriptionCache is returned when QueryExecModeCacheDescribe is
	// requested but the connection has no description cache.
	errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
)
// Connect parses connString with ParseConfig and establishes a connection to a
// PostgreSQL server using the resulting config.
func Connect(ctx context.Context, connString string) (*Conn, error) {
	cfg, parseErr := ParseConfig(connString)
	if parseErr != nil {
		return nil, parseErr
	}
	return connect(ctx, cfg)
}
// ConnectWithOptions behaves like Connect but parses connString with
// ParseConfigWithOptions using options.
func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
	cfg, parseErr := ParseConfigWithOptions(connString, options)
	if parseErr != nil {
		return nil, parseErr
	}
	return connect(ctx, cfg)
}
// ConnectConfig establishes a connection using connConfig, which must have been
// created by ParseConfig. The config is copied first because connect may modify
// the config it is given (it installs an OnNotification handler).
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
	return connect(ctx, connConfig.Copy())
}
// ParseConfigWithOptions builds a ConnConfig from connString. Parsing is
// delegated to pgconn.ParseConfigWithOptions, and then three pgx-specific
// runtime parameters are read and removed from RuntimeParams:
//   - statement_cache_capacity: 32-bit integer, default 512
//   - description_cache_capacity: 32-bit integer, default 512
//   - default_query_exec_mode: one of cache_statement (default), cache_describe,
//     describe_exec, exec, simple_protocol
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
	config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
	if err != nil {
		return nil, err
	}

	// takeCapacity removes key from RuntimeParams and parses it as a 32-bit
	// integer, returning fallback when the key is absent.
	takeCapacity := func(key string, fallback int) (int, error) {
		raw, present := config.RuntimeParams[key]
		if !present {
			return fallback, nil
		}
		delete(config.RuntimeParams, key)
		parsed, parseErr := strconv.ParseInt(raw, 10, 32)
		if parseErr != nil {
			return 0, fmt.Errorf("cannot parse %s: %w", key, parseErr)
		}
		return int(parsed), nil
	}

	statementCacheCapacity, err := takeCapacity("statement_cache_capacity", 512)
	if err != nil {
		return nil, err
	}
	descriptionCacheCapacity, err := takeCapacity("description_cache_capacity", 512)
	if err != nil {
		return nil, err
	}

	defaultQueryExecMode := QueryExecModeCacheStatement
	if raw, present := config.RuntimeParams["default_query_exec_mode"]; present {
		delete(config.RuntimeParams, "default_query_exec_mode")
		modesByName := map[string]QueryExecMode{
			"cache_statement": QueryExecModeCacheStatement,
			"cache_describe":  QueryExecModeCacheDescribe,
			"describe_exec":   QueryExecModeDescribeExec,
			"exec":            QueryExecModeExec,
			"simple_protocol": QueryExecModeSimpleProtocol,
		}
		selected, known := modesByName[raw]
		if !known {
			return nil, fmt.Errorf("invalid default_query_exec_mode: %s", raw)
		}
		defaultQueryExecMode = selected
	}

	return &ConnConfig{
		Config:                   *config,
		createdByParseConfig:     true,
		StatementCacheCapacity:   statementCacheCapacity,
		DescriptionCacheCapacity: descriptionCacheCapacity,
		DefaultQueryExecMode:     defaultQueryExecMode,
		connString:               connString,
	}, nil
}
// ParseConfig creates a ConnConfig from connString with default parse options.
// It is equivalent to ParseConfigWithOptions with a zero ParseConfigOptions.
func ParseConfig(connString string) (*ConnConfig, error) {
	var defaults ParseConfigOptions
	return ParseConfigWithOptions(connString, defaults)
}
// connect opens a connection using config, which must have been produced by
// ParseConfig (it panics otherwise). It wires up the tracers implemented by
// config.Tracer, installs notification buffering unless the caller supplied an
// OnNotification handler, connects via pgconn, and then initializes
// per-connection state. On failure it returns a nil *Conn.
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
	if tracer, isConnectTracer := config.Tracer.(ConnectTracer); isConnectTracer {
		ctx = tracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config})
		// The named results let the deferred call report the final outcome.
		defer func() {
			tracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err})
		}()
	}

	if !config.createdByParseConfig {
		panic("config must be created by ParseConfig")
	}

	c = &Conn{
		config:      config,
		typeMap:     pgtype.NewMap(),
		queryTracer: config.Tracer,
	}

	// Optional tracer capabilities are discovered by interface assertion; a
	// failed assertion leaves the field nil.
	c.batchTracer, _ = c.queryTracer.(BatchTracer)
	c.copyFromTracer, _ = c.queryTracer.(CopyFromTracer)
	c.prepareTracer, _ = c.queryTracer.(PrepareTracer)

	// Buffer notifications internally only when no other handler is present.
	if config.Config.OnNotification == nil {
		config.Config.OnNotification = c.bufferNotifications
	}

	if c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config); err != nil {
		return nil, err
	}

	c.preparedStatements = make(map[string]*pgconn.StatementDescription)
	c.doneChan = make(chan struct{})
	c.closedChan = make(chan error)
	c.wbuf = make([]byte, 0, 1024)

	if capacity := c.config.StatementCacheCapacity; capacity > 0 {
		c.statementCache = stmtcache.NewLRUCache(capacity)
	}
	if capacity := c.config.DescriptionCacheCapacity; capacity > 0 {
		c.descriptionCache = stmtcache.NewLRUCache(capacity)
	}

	return c, nil
}
// Close closes the underlying connection. Calling Close on an already-closed
// connection is a no-op that returns nil.
func (c *Conn) Close(ctx context.Context) error {
	if c.IsClosed() {
		return nil
	}
	return c.pgConn.Close(ctx)
}
// Prepare creates a prepared statement called name for sql and returns its
// description.
//
// If name is non-empty and a statement with that name has already been
// prepared on this connection with identical SQL, the stored description is
// returned without contacting the server. A non-empty name is recorded in
// preparedStatements on success. An empty name prepares the unnamed statement,
// which is not recorded.
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
	tracer := c.prepareTracer
	if tracer != nil {
		ctx = tracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
	}

	if name != "" {
		if existing, found := c.preparedStatements[name]; found && existing.SQL == sql {
			if tracer != nil {
				tracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true})
			}
			return existing, nil
		}
	}

	if tracer != nil {
		defer func() {
			tracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err})
		}()
	}

	if sd, err = c.pgConn.Prepare(ctx, name, sql, nil); err != nil {
		return nil, err
	}

	if name != "" {
		c.preparedStatements[name] = sd
	}
	return sd, nil
}
// Deallocate releases the prepared statement called name. The local record is
// removed before the DEALLOCATE command is sent, so it is dropped even if the
// server call fails.
func (c *Conn) Deallocate(ctx context.Context, name string) error {
	delete(c.preparedStatements, name)
	results := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name))
	_, err := results.ReadAll()
	return err
}
// DeallocateAll releases every prepared statement on the connection. It forgets
// all explicitly prepared statements and replaces any enabled statement and
// description caches with fresh empty ones before sending "deallocate all".
func (c *Conn) DeallocateAll(ctx context.Context) error {
	c.preparedStatements = make(map[string]*pgconn.StatementDescription)

	if n := c.config.StatementCacheCapacity; n > 0 {
		c.statementCache = stmtcache.NewLRUCache(n)
	}
	if n := c.config.DescriptionCacheCapacity; n > 0 {
		c.descriptionCache = stmtcache.NewLRUCache(n)
	}

	results := c.pgConn.Exec(ctx, "deallocate all")
	_, err := results.ReadAll()
	return err
}
// bufferNotifications is the OnNotification handler that connect installs when
// the caller did not provide one. It queues each notification for
// WaitForNotification.
func (c *Conn) bufferNotifications(_ *pgconn.PgConn, notification *pgconn.Notification) {
	c.notifications = append(c.notifications, notification)
}
// WaitForNotification returns the oldest buffered notification immediately when
// one is queued. Otherwise it blocks in pgConn.WaitForNotification and then
// returns the oldest buffered notification (nil if none arrived) together with
// that call's error.
//
// If the caller installed its own OnNotification handler, connect does not
// install bufferNotifications, so nothing is buffered and the returned
// notification is nil.
func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
	// popOldest removes and returns the head of the queue, or nil when it is empty.
	popOldest := func() *pgconn.Notification {
		if len(c.notifications) == 0 {
			return nil
		}
		head := c.notifications[0]
		c.notifications = c.notifications[1:]
		return head
	}

	if len(c.notifications) > 0 {
		return popOldest(), nil
	}

	waitErr := c.pgConn.WaitForNotification(ctx)
	return popOldest(), waitErr
}
// IsClosed reports whether the underlying pgconn connection is closed.
func (c *Conn) IsClosed() bool {
	closed := c.pgConn.IsClosed()
	return closed
}
// die closes the underlying connection unless it is already closed. It passes
// pgConn.Close a context that has already been cancelled (presumably so Close
// does not wait on a graceful shutdown; confirm against pgconn). The err
// argument is not used here.
func (c *Conn) die(err error) {
	if !c.IsClosed() {
		cancelledCtx, cancel := context.WithCancel(context.Background())
		cancel()
		c.pgConn.Close(cancelledCtx)
	}
}
// quoteIdentifier quotes s as a single PostgreSQL identifier, using the same
// rules as Identifier.Sanitize:
//   - NUL bytes are removed;
//   - embedded double quotes are doubled;
//   - the result is wrapped in double quotes.
//
// Stripping NUL matters because frontend/backend protocol strings are
// NUL-terminated. An embedded NUL would truncate the generated command text
// (e.g. in Deallocate) and leave the remaining bytes to be misread by the server.
func quoteIdentifier(s string) string {
	s = strings.ReplaceAll(s, "\x00", "")
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
// Ping delegates to the underlying pgconn connection's Ping to check that the
// connection is usable.
func (c *Conn) Ping(ctx context.Context) error {
	pg := c.pgConn
	return pg.Ping(ctx)
}
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch for
// functionality not exposed by Conn.
func (c *Conn) PgConn() *pgconn.PgConn {
	return c.pgConn
}
// TypeMap returns the connection's type map.
func (c *Conn) TypeMap() *pgtype.Map {
	return c.typeMap
}
// Config returns a copy of the config that was used to establish this
// connection; changing the copy does not affect the connection.
func (c *Conn) Config() *ConnConfig {
	return c.config.Copy()
}
// Exec executes sql, which may be either SQL text or the name of a statement
// prepared with Prepare. arguments are bound positionally; leading
// QueryExecMode / QueryRewriter values are consumed as options (see exec).
//
// When a QueryTracer is configured, every TraceQueryStart is matched by a
// TraceQueryEnd. This includes the case where deallocating invalidated cached
// statements fails before the query runs.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
	if c.queryTracer != nil {
		ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments})
	}

	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
		// End the trace started above so tracers do not leak an open span;
		// this mirrors Query's handling of the same failure.
		if c.queryTracer != nil {
			c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
		}
		return pgconn.CommandTag{}, err
	}

	commandTag, err := c.exec(ctx, sql, arguments...)

	if c.queryTracer != nil {
		c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err})
	}

	return commandTag, err
}
// exec implements Exec once tracing and cache invalidation have been handled.
//
// Leading arguments of type QueryExecMode or QueryRewriter are consumed as
// options; the remaining arguments are query parameters. The simple protocol is
// used when no parameters remain. If sql names an explicitly prepared
// statement, that statement is executed regardless of mode.
//
// In QueryExecModeCacheDescribe, a description obtained from the server is
// stored in the description cache, so later calls skip the Describe round trip.
// This matches getStatementDescription's handling of the same mode.
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
	mode := c.config.DefaultQueryExecMode
	var queryRewriter QueryRewriter

optionLoop:
	for len(arguments) > 0 {
		switch arg := arguments[0].(type) {
		case QueryExecMode:
			mode = arg
			arguments = arguments[1:]
		case QueryRewriter:
			queryRewriter = arg
			arguments = arguments[1:]
		default:
			break optionLoop
		}
	}

	if queryRewriter != nil {
		sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
		if err != nil {
			return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err)
		}
	}

	// Always use the simple protocol when there are no arguments.
	if len(arguments) == 0 {
		mode = QueryExecModeSimpleProtocol
	}

	if sd, ok := c.preparedStatements[sql]; ok {
		return c.execPrepared(ctx, sd, arguments)
	}

	switch mode {
	case QueryExecModeCacheStatement:
		if c.statementCache == nil {
			return pgconn.CommandTag{}, errDisabledStatementCache
		}
		sd := c.statementCache.Get(sql)
		if sd == nil {
			sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
			if err != nil {
				return pgconn.CommandTag{}, err
			}
			c.statementCache.Put(sd)
		}
		return c.execPrepared(ctx, sd, arguments)
	case QueryExecModeCacheDescribe:
		if c.descriptionCache == nil {
			return pgconn.CommandTag{}, errDisabledDescriptionCache
		}
		sd := c.descriptionCache.Get(sql)
		if sd == nil {
			sd, err = c.Prepare(ctx, "", sql)
			if err != nil {
				return pgconn.CommandTag{}, err
			}
			// Without this Put the cache is never populated on the Exec path and
			// every call pays for a Describe.
			c.descriptionCache.Put(sd)
		}
		return c.execParams(ctx, sd, arguments)
	case QueryExecModeDescribeExec:
		sd, err := c.Prepare(ctx, "", sql)
		if err != nil {
			return pgconn.CommandTag{}, err
		}
		return c.execPrepared(ctx, sd, arguments)
	case QueryExecModeExec:
		return c.execSQLParams(ctx, sql, arguments)
	case QueryExecModeSimpleProtocol:
		return c.execSimpleProtocol(ctx, sql, arguments)
	default:
		return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
	}
}
// execSimpleProtocol runs sql over the simple query protocol. When arguments are
// present they are interpolated client-side by sanitizeForSimpleQuery.
//
// All results are drained and the command tag of the last one is returned.
// Errors from individual results are not inspected here; the returned error is
// whatever the multi-result reader's Close reports.
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) {
	query := sql
	if len(arguments) != 0 {
		if query, err = c.sanitizeForSimpleQuery(sql, arguments...); err != nil {
			return pgconn.CommandTag{}, err
		}
	}

	mrr := c.pgConn.Exec(ctx, query)
	for mrr.NextResult() {
		commandTag, _ = mrr.ResultReader().Close()
	}
	return commandTag, mrr.Close()
}
// execParams executes sd.SQL as an unnamed parameterized query (ExecParams),
// sending the parameter OIDs recorded in sd.
func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
	if buildErr := c.eqb.Build(c.typeMap, sd, arguments); buildErr != nil {
		return pgconn.CommandTag{}, buildErr
	}

	reader := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
	outcome := reader.Read()
	// Reset the shared builder once the result has been read.
	c.eqb.reset()
	return outcome.CommandTag, outcome.Err
}
// execPrepared executes the server-side statement named sd.Name with arguments.
func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
	if buildErr := c.eqb.Build(c.typeMap, sd, arguments); buildErr != nil {
		return pgconn.CommandTag{}, buildErr
	}

	reader := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
	outcome := reader.Read()
	// Reset the shared builder once the result has been read.
	c.eqb.reset()
	return outcome.CommandTag, outcome.Err
}
// unknownArgumentTypeQueryExecModeExecError reports a query argument whose Go
// type is not registered and so cannot be used in QueryExecModeExec.
type unknownArgumentTypeQueryExecModeExecError struct {
	arg any // the offending argument; only its dynamic type is reported
}

// Error names the dynamic type of the rejected argument.
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
	const format = "cannot use unregistered type %T as query argument in QueryExecModeExec"
	return fmt.Sprintf(format, e.arg)
}
// execSQLParams executes sql as an unnamed parameterized query without a
// statement description; no parameter OIDs are sent.
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
	if buildErr := c.eqb.Build(c.typeMap, nil, args); buildErr != nil {
		return pgconn.CommandTag{}, buildErr
	}

	reader := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
	outcome := reader.Read()
	// Reset the shared builder once the result has been read.
	c.eqb.reset()
	return outcome.CommandTag, outcome.Err
}
// getRows allocates a baseRows for sql/args that is bound to this connection,
// its type map and query tracer, with startTime set to the current time.
func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
	return &baseRows{
		ctx:         ctx,
		queryTracer: c.queryTracer,
		typeMap:     c.typeMap,
		startTime:   time.Now(),
		sql:         sql,
		args:        args,
		conn:        c,
	}
}
// QueryExecMode selects how Exec, Query and SendBatch send a query to the
// server. The zero value is not a valid mode.
type QueryExecMode int32

const (
	// QueryExecModeCacheStatement prepares a named statement on first use,
	// keeps it in the connection's statement cache and executes it by name.
	QueryExecModeCacheStatement QueryExecMode = iota + 1

	// QueryExecModeCacheDescribe describes the query with an unnamed Prepare,
	// keeps the description in the connection's description cache, and executes
	// the SQL as a parameterized query using the cached parameter OIDs.
	QueryExecModeCacheDescribe

	// QueryExecModeDescribeExec prepares the unnamed statement on every
	// execution and then executes it.
	QueryExecModeDescribeExec

	// QueryExecModeExec sends the SQL as a parameterized query without a
	// statement description (no parameter OIDs are sent).
	QueryExecModeExec

	// QueryExecModeSimpleProtocol interpolates arguments client-side and uses
	// the simple query protocol.
	QueryExecModeSimpleProtocol
)

// String returns a human-readable name for m, or "invalid" for any value that
// is not one of the defined modes.
func (m QueryExecMode) String() string {
	names := [...]string{
		QueryExecModeCacheStatement: "cache statement",
		QueryExecModeCacheDescribe:  "cache describe",
		QueryExecModeDescribeExec:   "describe exec",
		QueryExecModeExec:           "exec",
		QueryExecModeSimpleProtocol: "simple protocol",
	}
	if m < QueryExecModeCacheStatement || int(m) >= len(names) {
		return "invalid"
	}
	return names[m]
}
// QueryResultFormats, passed as a leading argument to Query, sets the result
// format code of each result column by position. Query honours it only on the
// statement-description path (explicitly prepared statements and the
// cache-statement, cache-describe and describe-exec modes); there it is used
// instead of the builder's ResultFormats.
type QueryResultFormats []int16

// QueryResultFormatsByOID, passed as a leading argument to Query, maps a result
// column's data type OID to a format code. Query expands it into a per-column
// QueryResultFormats using the statement description's Fields; OIDs absent from
// the map get format code 0.
type QueryResultFormatsByOID map[uint32]int16

// QueryRewriter, passed as a leading argument to Exec, Query, or a queued batch
// query, rewrites the SQL and arguments before the query is executed. A rewrite
// error aborts the query with an error prefixed "rewrite query failed: ".
type QueryRewriter interface {
	RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
}
// Query executes sql and returns a Rows for reading the results.
//
// Leading args of type QueryResultFormats, QueryResultFormatsByOID,
// QueryExecMode or QueryRewriter are consumed as options; the remaining args
// are query parameters. An empty sql string forces the simple protocol. If sql
// names an explicitly prepared statement, that statement is used whatever the
// mode.
//
// Every failure path returns a non-nil Rows together with the error, and the
// error is also recorded in the Rows. In simple-protocol mode, sql is sanitized
// client-side and only the first result is attached as the Rows' resultReader.
// The MultiResultReader is kept on the Rows as well.
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
	if c.queryTracer != nil {
		ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args})
	}

	// Deallocate statements the caches have invalidated before running anything new.
	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
		if c.queryTracer != nil {
			c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
		}
		return &baseRows{err: err, closed: true}, err
	}

	var resultFormats QueryResultFormats
	var resultFormatsByOID QueryResultFormatsByOID
	mode := c.config.DefaultQueryExecMode
	var queryRewriter QueryRewriter

	// Consume leading option arguments; stop at the first real query parameter.
optionLoop:
	for len(args) > 0 {
		switch arg := args[0].(type) {
		case QueryResultFormats:
			resultFormats = arg
			args = args[1:]
		case QueryResultFormatsByOID:
			resultFormatsByOID = arg
			args = args[1:]
		case QueryExecMode:
			mode = arg
			args = args[1:]
		case QueryRewriter:
			queryRewriter = arg
			args = args[1:]
		default:
			break optionLoop
		}
	}

	if queryRewriter != nil {
		var err error
		// Keep the pre-rewrite values so the error Rows reflects what the caller passed.
		originalSQL := sql
		originalArgs := args
		sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
		if err != nil {
			rows := c.getRows(ctx, originalSQL, originalArgs)
			err = fmt.Errorf("rewrite query failed: %v", err)
			rows.fatal(err)
			return rows, err
		}
	}

	// An empty query bypasses statement preparation and caching entirely.
	if sql == "" {
		mode = QueryExecModeSimpleProtocol
	}

	c.eqb.reset()
	anynil.NormalizeSlice(args)
	rows := c.getRows(ctx, sql, args)

	var err error
	sd, explicitPreparedStatement := c.preparedStatements[sql]
	// Statement-description path: an explicitly prepared statement, or one of the
	// modes that obtains a description from the server or a cache.
	if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
		if sd == nil {
			sd, err = c.getStatementDescription(ctx, mode, sql)
			if err != nil {
				rows.fatal(err)
				return rows, err
			}
		}

		// Reject an argument-count mismatch locally, before anything is sent.
		if len(sd.ParamOIDs) != len(args) {
			rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
			return rows, rows.err
		}

		rows.sql = sd.SQL

		err = c.eqb.Build(c.typeMap, sd, args)
		if err != nil {
			rows.fatal(err)
			return rows, rows.err
		}

		// Expand per-OID formats into a per-column list using the described result fields.
		if resultFormatsByOID != nil {
			resultFormats = make([]int16, len(sd.Fields))
			for i := range resultFormats {
				resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
			}
		}

		// Fall back to the formats chosen by the builder.
		if resultFormats == nil {
			resultFormats = c.eqb.ResultFormats
		}

		// Cached descriptions come from an unnamed Prepare (see getStatementDescription),
		// so they run as a parameterized query with the cached OIDs rather than by name.
		if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe {
			rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats)
		} else {
			rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
		}
	} else if mode == QueryExecModeExec {
		// No statement description: the requested resultFormats are not applied on this path.
		err := c.eqb.Build(c.typeMap, nil, args)
		if err != nil {
			rows.fatal(err)
			return rows, rows.err
		}

		rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
	} else if mode == QueryExecModeSimpleProtocol {
		sql, err = c.sanitizeForSimpleQuery(sql, args...)
		if err != nil {
			rows.fatal(err)
			return rows, err
		}

		mrr := c.pgConn.Exec(ctx, sql)
		if mrr.NextResult() {
			// Expose the first result; the multi-result reader stays attached to rows.
			rows.resultReader = mrr.ResultReader()
			rows.multiResultReader = mrr
		} else {
			err = mrr.Close()
			rows.fatal(err)
			return rows, err
		}

		// Returns early: the trailing eqb.reset() below is not reached on this path.
		return rows, nil
	} else {
		err = fmt.Errorf("unknown QueryExecMode: %v", mode)
		rows.fatal(err)
		return rows, rows.err
	}

	// Reset the shared builder now that its values have been handed to pgconn.
	c.eqb.reset()

	return rows, rows.err
}
// getStatementDescription returns a statement description for sql according to
// mode:
//   - QueryExecModeCacheStatement: look up the statement cache; on a miss,
//     prepare under a fresh generated name and cache the result.
//   - QueryExecModeCacheDescribe: look up the description cache; on a miss,
//     prepare the unnamed statement and cache the description.
//   - QueryExecModeDescribeExec: always prepare the unnamed statement.
//
// The cache modes fail if the corresponding cache is disabled. Any other mode
// yields (nil, nil).
func (c *Conn) getStatementDescription(ctx context.Context, mode QueryExecMode, sql string) (sd *pgconn.StatementDescription, err error) {
	switch mode {
	case QueryExecModeDescribeExec:
		return c.Prepare(ctx, "", sql)

	case QueryExecModeCacheStatement:
		if c.statementCache == nil {
			return nil, errDisabledStatementCache
		}
		if cached := c.statementCache.Get(sql); cached != nil {
			return cached, nil
		}
		if sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql); err != nil {
			return nil, err
		}
		c.statementCache.Put(sd)
		return sd, nil

	case QueryExecModeCacheDescribe:
		if c.descriptionCache == nil {
			return nil, errDisabledDescriptionCache
		}
		if cached := c.descriptionCache.Get(sql); cached != nil {
			return cached, nil
		}
		if sd, err = c.Prepare(ctx, "", sql); err != nil {
			return nil, err
		}
		c.descriptionCache.Put(sd)
		return sd, nil

	default:
		// Callers only use this for the three modes above.
		return nil, nil
	}
}
// QueryRow runs Query and returns its result as a single Row. The error that
// Query returns is discarded here. Query records its errors on the returned rows
// (via rows.fatal or the err field), so they are expected to surface from the
// Row's Scan.
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
	rows, _ := c.Query(ctx, sql, args...)
	base := rows.(*baseRows)
	return (*connRow)(base)
}
// SendBatch sends every query queued in b and returns a BatchResults for reading
// their results.
//
// Before sending:
//   - any QueryRewriter at the head of a queued query's arguments is applied,
//     rewriting that query's SQL and arguments in place (b is mutated);
//   - in every mode except the simple protocol, a queued query whose SQL names an
//     explicitly prepared statement is bound to that statement.
//
// The connection's DefaultQueryExecMode selects the send strategy. An
// unrecognized mode is reported as an error through the returned BatchResults
// rather than by panicking, which matches how Exec and Query handle an unknown
// mode. It also keeps the batch tracer's deferred earlyError check from
// type-asserting a nil result.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
	if c.batchTracer != nil {
		ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
		// Only errors reported by earlyError end the trace here.
		defer func() {
			err := br.(interface{ earlyError() error }).earlyError()
			if err != nil {
				c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err})
			}
		}()
	}

	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
		return &batchResults{ctx: ctx, conn: c, err: err}
	}

	mode := c.config.DefaultQueryExecMode

	for _, bi := range b.queuedQueries {
		var queryRewriter QueryRewriter
		sql := bi.query
		arguments := bi.arguments

	optionLoop:
		for len(arguments) > 0 {
			switch arg := arguments[0].(type) {
			case QueryRewriter:
				queryRewriter = arg
				arguments = arguments[1:]
			default:
				break optionLoop
			}
		}

		if queryRewriter != nil {
			var err error
			sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
			if err != nil {
				return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)}
			}
		}

		bi.query = sql
		bi.arguments = arguments
	}

	if mode == QueryExecModeSimpleProtocol {
		return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
	}

	// All other modes use the extended protocol and can reuse explicitly prepared statements.
	for _, bi := range b.queuedQueries {
		if sd, ok := c.preparedStatements[bi.query]; ok {
			bi.sd = sd
		}
	}

	switch mode {
	case QueryExecModeExec:
		return c.sendBatchQueryExecModeExec(ctx, b)
	case QueryExecModeCacheStatement:
		return c.sendBatchQueryExecModeCacheStatement(ctx, b)
	case QueryExecModeCacheDescribe:
		return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
	case QueryExecModeDescribeExec:
		return c.sendBatchQueryExecModeDescribeExec(ctx, b)
	default:
		return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("unknown QueryExecMode: %v", mode)}
	}
}
// sendBatchQueryExecModeSimpleProtocol sanitizes each queued query with its
// arguments and sends all of them as one ";"-separated simple-protocol query.
// The first sanitization error is returned in the batch results and nothing is
// sent.
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
	statements := make([]string, 0, len(b.queuedQueries))
	for _, bi := range b.queuedQueries {
		stmt, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
		if err != nil {
			return &batchResults{ctx: ctx, conn: c, err: err}
		}
		statements = append(statements, stmt)
	}

	mrr := c.pgConn.Exec(ctx, strings.Join(statements, ";"))
	return &batchResults{ctx: ctx, conn: c, mrr: mrr, b: b, qqIdx: 0}
}
// sendBatchQueryExecModeExec sends the batch with pgconn.ExecBatch. Queries
// bound to an explicitly prepared statement run by name; all others run as
// unnamed parameterized queries without parameter OIDs.
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
	batch := &pgconn.Batch{}

	for _, bi := range b.queuedQueries {
		// bi.sd is nil for queries without an explicitly prepared statement, which
		// is exactly the description the builder needs in that case.
		if err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments); err != nil {
			return &batchResults{ctx: ctx, conn: c, err: err}
		}

		if sd := bi.sd; sd != nil {
			batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
		} else {
			batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
		}
	}

	// Reset the shared builder once every query has been queued.
	c.eqb.reset()

	return &batchResults{
		ctx:   ctx,
		conn:  c,
		mrr:   c.pgConn.ExecBatch(ctx, batch),
		b:     b,
		qqIdx: 0,
	}
}
// sendBatchQueryExecModeCacheStatement binds each unbound queued query to a
// statement from the statement cache. If the query is not cached, it is bound to
// a new named description, created once per distinct SQL. Those new
// descriptions are prepared and cached by sendBatchExtendedWithDescription.
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	if c.statementCache == nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
	}

	pending := make([]*pgconn.StatementDescription, 0)
	pendingBySQL := map[string]*pgconn.StatementDescription{}

	for _, bi := range b.queuedQueries {
		if bi.sd != nil {
			continue
		}
		if cached := c.statementCache.Get(bi.query); cached != nil {
			bi.sd = cached
			continue
		}
		sd, seen := pendingBySQL[bi.query]
		if !seen {
			sd = &pgconn.StatementDescription{
				Name: stmtcache.NextStatementName(),
				SQL:  bi.query,
			}
			pendingBySQL[bi.query] = sd
			pending = append(pending, sd)
		}
		bi.sd = sd
	}

	return c.sendBatchExtendedWithDescription(ctx, b, pending, c.statementCache)
}
// sendBatchQueryExecModeCacheDescribe binds each unbound queued query to a
// description from the description cache. If the query is not cached, it is
// bound to a new unnamed description, created once per distinct SQL. Those new
// descriptions are described and cached by sendBatchExtendedWithDescription.
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	if c.descriptionCache == nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
	}

	pending := make([]*pgconn.StatementDescription, 0)
	pendingBySQL := map[string]*pgconn.StatementDescription{}

	for _, bi := range b.queuedQueries {
		if bi.sd != nil {
			continue
		}
		if cached := c.descriptionCache.Get(bi.query); cached != nil {
			bi.sd = cached
			continue
		}
		sd, seen := pendingBySQL[bi.query]
		if !seen {
			sd = &pgconn.StatementDescription{SQL: bi.query}
			pendingBySQL[bi.query] = sd
			pending = append(pending, sd)
		}
		bi.sd = sd
	}

	return c.sendBatchExtendedWithDescription(ctx, b, pending, c.descriptionCache)
}
// sendBatchQueryExecModeDescribeExec binds each unbound queued query to a new
// unnamed description, created once per distinct SQL. The descriptions are
// described by sendBatchExtendedWithDescription but not cached.
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	pending := make([]*pgconn.StatementDescription, 0)
	pendingBySQL := map[string]*pgconn.StatementDescription{}

	for _, bi := range b.queuedQueries {
		if bi.sd != nil {
			continue
		}
		sd, seen := pendingBySQL[bi.query]
		if !seen {
			sd = &pgconn.StatementDescription{SQL: bi.query}
			pendingBySQL[bi.query] = sd
			pending = append(pending, sd)
		}
		bi.sd = sd
	}

	return c.sendBatchExtendedWithDescription(ctx, b, pending, nil)
}
// sendBatchExtendedWithDescription sends a batch over the extended protocol in
// pipeline mode.
//
// Steps:
//  1. distinctNewQueries are descriptions the server has not described yet.
//     They are prepared (named or unnamed, according to their Name) in a first
//     pipeline round trip.
//  2. Their ParamOIDs and Fields are filled in from the server's replies.
//  3. When sdCache is non-nil, they are added to it.
//  4. Every queued query is sent (SendQueryPrepared for named statements,
//     SendQueryParams otherwise), followed by a Sync.
//
// The returned results read from the still-open pipeline. On any early failure,
// closed results carrying the error are returned, and the deferred handler
// closes the pipeline.
//
// The pipeline is started with the caller's ctx, not context.Background(). That
// way the caller's cancellation and deadline apply while the batch is sent and
// while its results are read.
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
	pipeline := c.pgConn.StartPipeline(ctx)
	defer func() {
		if pbr != nil && pbr.err != nil {
			pipeline.Close()
		}
	}()

	// Prepare any statements that have not been described yet.
	if len(distinctNewQueries) > 0 {
		for _, sd := range distinctNewQueries {
			pipeline.SendPrepare(sd.Name, sd.SQL, nil)
		}

		err := pipeline.Sync()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
		}

		for _, sd := range distinctNewQueries {
			results, err := pipeline.GetResults()
			if err != nil {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
			}

			resultSD, ok := results.(*pgconn.StatementDescription)
			if !ok {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
			}

			// Fill in the previously empty / pending statement descriptions.
			sd.ParamOIDs = resultSD.ParamOIDs
			sd.Fields = resultSD.Fields
		}

		results, err := pipeline.GetResults()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
		}

		_, ok := results.(*pgconn.PipelineSync)
		if !ok {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
		}
	}

	// Put all newly described statements into the cache.
	if sdCache != nil {
		for _, sd := range distinctNewQueries {
			sdCache.Put(sd)
		}
	}

	// Queue the queries.
	for _, bi := range b.queuedQueries {
		err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
		if err != nil {
			// Name the failing query so the caller can tell which batch entry is at fault.
			err = fmt.Errorf("error building query %s: %w", bi.query, err)
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
		}

		if bi.sd.Name == "" {
			pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
		} else {
			pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
		}
	}

	err := pipeline.Sync()
	if err != nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
	}

	return &pipelineBatchResults{
		ctx:      ctx,
		conn:     c,
		pipeline: pipeline,
		b:        b,
	}
}
// sanitizeForSimpleQuery interpolates args into sql for use with the simple
// query protocol. The session must report standard_conforming_strings=on and
// client_encoding=UTF8. Each argument is converted with convertSimpleArgument
// before being handed to sanitize.SanitizeSQL.
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
	switch {
	case c.pgConn.ParameterStatus("standard_conforming_strings") != "on":
		return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
	case c.pgConn.ParameterStatus("client_encoding") != "UTF8":
		return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
	}

	converted := make([]any, len(args))
	for idx := range args {
		value, convErr := convertSimpleArgument(c.typeMap, args[idx])
		if convErr != nil {
			return "", convErr
		}
		converted[idx] = value
	}

	return sanitize.SanitizeSQL(sql, converted...)
}
// LoadType inspects the database for typeName and builds a pgtype.Type for it.
//
// typeName is resolved to an OID with a ::regtype cast. The pg_type.typtype
// value then selects the codec:
//   - "b": treated as an array: ArrayCodec over pg_type.typelem
//   - "c": CompositeCodec built from the relation's live attributes
//   - "d": the codec of the domain's base type (pg_type.typbasetype)
//   - "e": EnumCodec
//   - "r": RangeCodec over pg_range.rngsubtype
//   - "m": MultirangeCodec over the range type found in pg_range
//
// Element, base and field types must already be registered in c.TypeMap().
// On every error path, including an unsupported typtype, the returned
// *pgtype.Type is nil.
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
	var oid uint32

	err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
	if err != nil {
		return nil, err
	}

	var typtype string
	var typbasetype uint32

	err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
	if err != nil {
		return nil, err
	}

	switch typtype {
	case "b":
		elementOID, err := c.getArrayElementOID(ctx, oid)
		if err != nil {
			return nil, err
		}

		dt, ok := c.TypeMap().TypeForOID(elementOID)
		if !ok {
			return nil, errors.New("array element OID not registered")
		}

		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil
	case "c":
		fields, err := c.getCompositeFields(ctx, oid)
		if err != nil {
			return nil, err
		}

		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
	case "d":
		dt, ok := c.TypeMap().TypeForOID(typbasetype)
		if !ok {
			return nil, errors.New("domain base type OID not registered")
		}

		return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
	case "e":
		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
	case "r":
		elementOID, err := c.getRangeElementOID(ctx, oid)
		if err != nil {
			return nil, err
		}

		dt, ok := c.TypeMap().TypeForOID(elementOID)
		if !ok {
			return nil, errors.New("range element OID not registered")
		}

		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
	case "m":
		elementOID, err := c.getMultiRangeElementOID(ctx, oid)
		if err != nil {
			return nil, err
		}

		dt, ok := c.TypeMap().TypeForOID(elementOID)
		if !ok {
			return nil, errors.New("multirange element OID not registered")
		}

		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
	default:
		// Return nil like every other error path, not a half-populated Type.
		return nil, errors.New("unknown typtype")
	}
}
// getArrayElementOID returns pg_type.typelem for the type with the given oid.
func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) {
	var elementOID uint32
	if err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&elementOID); err != nil {
		return 0, err
	}
	return elementOID, nil
}
// getRangeElementOID returns pg_range.rngsubtype for the range type with the given oid.
func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
	var subtypeOID uint32
	if err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&subtypeOID); err != nil {
		return 0, err
	}
	return subtypeOID, nil
}
// getMultiRangeElementOID returns the OID of the range type (pg_range.rngtypid)
// whose multirange type has the given oid.
func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
	var rangeOID uint32
	if err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&rangeOID); err != nil {
		return 0, err
	}
	return rangeOID, nil
}
// getCompositeFields returns the fields of the composite type with the given
// oid, in attribute order, skipping dropped and system attributes. Every field
// type must already be registered in c.TypeMap().
func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
	var typrelid uint32
	if err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid); err != nil {
		return nil, err
	}

	var (
		fields    []pgtype.CompositeCodecField
		fieldName string
		fieldOID  uint32
	)

	// Query's error is not checked here; it is presumably surfaced through rows
	// by ForEachRow (NOTE(review): confirm).
	rows, _ := c.Query(ctx, `select attname, atttypid
from pg_attribute
where attrelid=$1
and not attisdropped
and attnum > 0
order by attnum`,
		typrelid,
	)

	collect := func() error {
		dt, ok := c.TypeMap().TypeForOID(fieldOID)
		if !ok {
			return fmt.Errorf("unknown composite type field OID: %v", fieldOID)
		}
		fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
		return nil
	}

	if _, err := ForEachRow(rows, []any{&fieldName, &fieldOID}, collect); err != nil {
		return nil, err
	}
	return fields, nil
}
// deallocateInvalidatedCachedStatements runs only while the connection is idle
// (TxStatus 'I', i.e. not inside a transaction). It clears invalidated entries
// from the description cache and deallocates on the server any statements the
// statement cache has invalidated, forgetting them locally as well.
//
// Nothing is sent inside a transaction; presumably this is so a DEALLOCATE
// cannot interfere with the user's transaction.
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
	if c.pgConn.TxStatus() != 'I' {
		return nil
	}

	if c.descriptionCache != nil {
		c.descriptionCache.HandleInvalidated()
	}

	var stale []*pgconn.StatementDescription
	if c.statementCache != nil {
		stale = c.statementCache.HandleInvalidated()
	}
	if len(stale) == 0 {
		return nil
	}

	pipeline := c.pgConn.StartPipeline(ctx)
	defer pipeline.Close()

	for _, sd := range stale {
		pipeline.SendDeallocate(sd.Name)
		delete(c.preparedStatements, sd.Name)
	}

	wrap := func(err error) error {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}
	if err := pipeline.Sync(); err != nil {
		return wrap(err)
	}
	if err := pipeline.Close(); err != nil {
		return wrap(err)
	}
	return nil
}
// These pages were generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.