package pgconn
import (
"context"
"crypto/md5"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"math"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
// Connection lifecycle states stored in PgConn.status. The numeric order is
// significant: IsClosed reports true for every state below connStatusIdle
// (uninitialized, connecting and closed), so the values must not be reordered.
const (
	connStatusUninitialized = 0 // zero value of a PgConn that was never connected
	connStatusConnecting    = 1 // set by connect while the startup handshake runs
	connStatusClosed        = 2 // connection torn down; terminal state
	connStatusIdle          = 3 // ready for a new operation (see lock)
	connStatusBusy          = 4 // an operation holds the connection (see lock/unlock)
)
// Notice is a NoticeResponse from the server. It has exactly the layout of
// PgError, which lets noticeResponseToNotice build one by converting the
// message through ErrorResponseToPgError.
type Notice PgError
// Notification is built by receiveMessage from a NotificationResponse message
// and handed to Config.OnNotification when that callback is set.
type Notification struct {
	// PID is copied from the NotificationResponse's PID field.
	PID uint32
	// Channel and Payload are copied verbatim from the message.
	Channel, Payload string
}
// Function types for the pluggable pieces of a connection.
type (
	// DialFunc opens the network connection to addr on network; connect and
	// CancelRequest call it through Config.DialFunc.
	DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

	// LookupFunc resolves host into addresses. expandWithIPs accepts results
	// that are either bare hosts or "host:port" pairs.
	LookupFunc func(ctx context.Context, host string) (addrs []string, err error)

	// BuildFrontendFunc constructs the protocol frontend that reads from r and
	// writes to w.
	BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend

	// NoticeHandler receives notices sent by the server on a connection.
	NoticeHandler func(*PgConn, *Notice)

	// NotificationHandler receives notifications delivered on a connection.
	NotificationHandler func(*PgConn, *Notification)
)
// PgConn is a low-level PostgreSQL connection. An operation takes exclusive use
// of it through lock/unlock. Those methods only check and set the status field
// without any synchronization, so a PgConn must not be used from several
// goroutines at once.
type PgConn struct {
// conn is the network connection; connect replaces it with the TLS-wrapped conn when TLS was negotiated.
conn net .Conn
// pid and secretKey come from the server's BackendKeyData; CancelRequest sends them.
pid uint32
secretKey uint32
// parameterStatuses holds the latest value of each ParameterStatus reported by the server.
parameterStatuses map [string ]string
// txStatus is the TxStatus byte of the most recently received ReadyForQuery.
txStatus byte
frontend *pgproto3 .Frontend
// bgReader is what the frontend reads from. slowWriteTimer is created stopped in connect and starts bgReader
// when it fires. NOTE(review): presumably armed by flushWithPotentialWriteReadDeadlock; confirm there.
bgReader *bgreader .BGReader
slowWriteTimer *time .Timer
config *Config
// status is one of the connStatus* constants.
status byte
// bufferingReceive is true while a background receive started by signalMessage is outstanding.
// bufferingReceiveMux is locked by signalMessage and unlocked by its goroutine once bufferingReceiveMsg /
// bufferingReceiveErr have been stored, so peekMessage blocks on it until the result is ready.
bufferingReceive bool
bufferingReceiveMux sync .Mutex
bufferingReceiveMsg pgproto3 .BackendMessage
bufferingReceiveErr error
// peekedMsg is a message returned by peekMessage but not yet consumed by receiveMessage.
peekedMsg pgproto3 .BackendMessage
// resultReader and multiResultReader are reusable values handed out by pointer (&pgConn.resultReader etc.)
// so a query does not allocate a reader; they are reset at the start of each operation.
resultReader ResultReader
multiResultReader MultiResultReader
pipeline Pipeline
contextWatcher *ctxwatch .ContextWatcher
// fieldDescriptions is a fixed backing array that convertRowDescription reuses for results of up to 16 columns.
fieldDescriptions [16 ]FieldDescription
// cleanupDone is closed once the connection has been torn down (Close, asyncClose, a FATAL error, or CopyFrom's receive failure).
cleanupDone chan struct {}
}
// Connect parses connString with ParseConfig and establishes a connection with
// ConnectConfig. A parse error is returned as-is, without attempting to
// connect.
func Connect(ctx context.Context, connString string) (*PgConn, error) {
	cfg, err := ParseConfig(connString)
	if err == nil {
		return ConnectConfig(ctx, cfg)
	}
	return nil, err
}
// ConnectWithOptions behaves like Connect but parses connString with
// ParseConfigWithOptions, so parseConfigOptions can customize parsing. A parse
// error is returned as-is, without attempting to connect.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
	cfg, err := ParseConfigWithOptions(connString, parseConfigOptions)
	if err == nil {
		return ConnectConfig(ctx, cfg)
	}
	return nil, err
}
// ConnectConfig establishes a connection using config. config must have been
// created by ParseConfig; a hand-built Config makes it panic.
//
// The primary host and every entry of config.Fallbacks are expanded by
// expandWithIPs into one candidate per resolved address, and the candidates
// are tried in order. The first candidate that connects wins.
//
// A server error after which other hosts are not expected to help stops the
// search: invalid password, missing database, insufficient privilege, or an
// invalid authorization specification on a TLS attempt. A candidate that
// ValidateConnect rejects with a *NotPreferredError is remembered. If no
// preferred server is found, the last such candidate is reconnected with the
// preference check ignored.
//
// With config.ConnectTimeout set, a fresh timeout context derived from octx is
// created for the first candidate and whenever the (expanded) Host differs
// from the previous candidate's. That last context is also used for the
// not-preferred retry and for config.AfterConnect. An AfterConnect failure
// closes the connection.
func ConnectConfig (octx context .Context , config *Config ) (pgConn *PgConn , err error ) {
// Defaults are applied by ParseConfig, so a hand-built Config is rejected outright.
if !config .createdByParseConfig {
panic ("config must be created by ParseConfig" )
}
// Treat the primary host like any fallback so a single loop handles every candidate.
fallbackConfigs := []*FallbackConfig {
{
Host : config .Host ,
Port : config .Port ,
TLSConfig : config .TLSConfig ,
},
}
fallbackConfigs = append (fallbackConfigs , config .Fallbacks ...)
ctx := octx
fallbackConfigs , err = expandWithIPs (ctx , config .LookupFunc , fallbackConfigs )
if err != nil {
return nil , &connectError {config : config , msg : "hostname resolving error" , err : err }
}
if len (fallbackConfigs ) == 0 {
return nil , &connectError {config : config , msg : "hostname resolving error" , err : errors .New ("ip addr wasn't found" )}
}
foundBestServer := false
// fallbackConfig remembers the last candidate rejected only as "not preferred".
var fallbackConfig *FallbackConfig
for i , fc := range fallbackConfigs {
if config .ConnectTimeout != 0 {
// New timeout for the first candidate and whenever Host changes; consecutive
// candidates with the same Host share the remaining time.
if i == 0 || (fallbackConfigs [i ].Host != fallbackConfigs [i -1 ].Host ) {
var cancel context .CancelFunc
ctx , cancel = context .WithTimeout (octx , config .ConnectTimeout )
// Deferred rather than cancelled per iteration: the final ctx is still used
// below for the not-preferred retry and for AfterConnect.
defer cancel ()
}
} else {
ctx = octx
}
pgConn , err = connect (ctx , config , fc , false )
if err == nil {
foundBestServer = true
break
} else if pgerr , ok := err .(*PgError ); ok {
err = &connectError {config : config , msg : "server error" , err : pgerr }
// SQLSTATE codes after which trying the remaining candidates is pointless.
const ERRCODE_INVALID_PASSWORD = "28P01"
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000"
const ERRCODE_INVALID_CATALOG_NAME = "3D000"
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501"
if pgerr .Code == ERRCODE_INVALID_PASSWORD ||
pgerr .Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc .TLSConfig != nil ||
pgerr .Code == ERRCODE_INVALID_CATALOG_NAME ||
pgerr .Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr , ok := err .(*connectError ); ok {
if _ , ok := cerr .err .(*NotPreferredError ); ok {
fallbackConfig = fc
}
}
}
// No preferred server: reconnect to the last not-preferred one, accepting it this time.
if !foundBestServer && fallbackConfig != nil {
pgConn , err = connect (ctx , config , fallbackConfig , true )
if pgerr , ok := err .(*PgError ); ok {
err = &connectError {config : config , msg : "server error" , err : pgerr }
}
}
// Every non-PgError from connect is already a *connectError, so no extra wrapping here.
if err != nil {
return nil , err
}
if config .AfterConnect != nil {
err := config .AfterConnect (ctx , pgConn )
if err != nil {
pgConn .conn .Close ()
return nil , &connectError {config : config , msg : "AfterConnect error" , err : err }
}
}
return pgConn , nil
}
// expandWithIPs resolves each fallback's host with lookupFn and returns one
// FallbackConfig per resolved address, keeping the input order.
//
// Hosts that are absolute paths (unix socket locations) are passed through
// unchanged. A lookup result of the form "host:port" overrides the fallback's
// port; any other result keeps the configured port. A port that cannot be
// parsed is a hard error.
//
// Lookup failures are tolerated as long as at least one address is produced.
// If none is produced, the first lookup error is returned.
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
	var (
		expanded []*FallbackConfig
		firstErr error
	)

	for _, fb := range fallbacks {
		if isAbsolutePath(fb.Host) {
			expanded = append(expanded, &FallbackConfig{Host: fb.Host, Port: fb.Port, TLSConfig: fb.TLSConfig})
			continue
		}

		addrs, lookupErr := lookupFn(ctx, fb.Host)
		if lookupErr != nil {
			if firstErr == nil {
				firstErr = lookupErr
			}
			continue
		}

		for _, addr := range addrs {
			host, portText, splitErr := net.SplitHostPort(addr)
			if splitErr != nil {
				// Not host:port; the whole entry is the host and the configured port applies.
				expanded = append(expanded, &FallbackConfig{Host: addr, Port: fb.Port, TLSConfig: fb.TLSConfig})
				continue
			}
			port, parseErr := strconv.ParseUint(portText, 10, 16)
			if parseErr != nil {
				return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", portText, parseErr)
			}
			expanded = append(expanded, &FallbackConfig{Host: host, Port: uint16(port), TLSConfig: fb.TLSConfig})
		}
	}

	if len(expanded) == 0 && firstErr != nil {
		return nil, firstErr
	}
	return expanded, nil
}
// connect dials a single candidate server described by fallbackConfig. It
// optionally negotiates TLS, sends the startup message, and runs the
// authentication exchange until the server reports ReadyForQuery.
//
// On success the returned PgConn is idle. If config.ValidateConnect rejects the
// server with a *NotPreferredError and ignoreNotPreferredErr is true, the
// connection is returned anyway; ConnectConfig uses this to fall back to a
// non-preferred server.
//
// Errors: an ErrorResponse from the server is returned as a bare *PgError, and
// every other failure is wrapped in *connectError. The network connection is
// closed on every error path. ctx is watched for the whole handshake, so a
// cancelled or expired context aborts blocked I/O.
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
	ignoreNotPreferredErr bool,
) (*PgConn, error) {
	pgConn := new(PgConn)
	pgConn.config = config
	pgConn.cleanupDone = make(chan struct{})

	network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
	netConn, err := config.DialFunc(ctx, network, address)
	if err != nil {
		return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
	}

	pgConn.conn = netConn
	pgConn.contextWatcher = newContextWatcher(netConn)
	pgConn.contextWatcher.Watch(ctx)

	if fallbackConfig.TLSConfig != nil {
		nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
		// The watcher is bound to netConn; always stop it before switching conns.
		pgConn.contextWatcher.Unwatch()
		if err != nil {
			netConn.Close()
			// Normalize like the dial and startup paths so a context that expires
			// during the SSLRequest exchange surfaces as a context error.
			return nil, &connectError{config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
		}

		pgConn.conn = nbTLSConn
		pgConn.contextWatcher = newContextWatcher(nbTLSConn)
		pgConn.contextWatcher.Watch(ctx)
	}

	defer pgConn.contextWatcher.Unwatch()

	pgConn.parameterStatuses = make(map[string]string)
	pgConn.status = connStatusConnecting
	pgConn.bgReader = bgreader.New(pgConn.conn)
	// Created stopped; when it fires it starts the background reader.
	// NOTE(review): presumably armed by flushWithPotentialWriteReadDeadlock.
	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
	pgConn.slowWriteTimer.Stop()
	pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)

	startupMsg := pgproto3.StartupMessage{
		ProtocolVersion: pgproto3.ProtocolVersionNumber,
		Parameters:      make(map[string]string),
	}
	// Runtime params first, so the explicit user/database settings below win.
	for k, v := range config.RuntimeParams {
		startupMsg.Parameters[k] = v
	}
	startupMsg.Parameters["user"] = config.User
	if config.Database != "" {
		startupMsg.Parameters["database"] = config.Database
	}

	pgConn.frontend.Send(&startupMsg)
	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
		pgConn.conn.Close()
		return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
	}

	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.conn.Close()
			if pgErr, ok := err.(*PgError); ok {
				return nil, pgErr
			}
			return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
		}

		switch msg := msg.(type) {
		case *pgproto3.BackendKeyData:
			pgConn.pid = msg.ProcessID
			pgConn.secretKey = msg.SecretKey

		case *pgproto3.AuthenticationOk:
		case *pgproto3.AuthenticationCleartextPassword:
			err = pgConn.txPasswordMessage(pgConn.config.Password)
			if err != nil {
				pgConn.conn.Close()
				return nil, &connectError{config: config, msg: "failed to write password message", err: err}
			}
		case *pgproto3.AuthenticationMD5Password:
			digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
			err = pgConn.txPasswordMessage(digestedPassword)
			if err != nil {
				pgConn.conn.Close()
				return nil, &connectError{config: config, msg: "failed to write password message", err: err}
			}
		case *pgproto3.AuthenticationSASL:
			err = pgConn.scramAuth(msg.AuthMechanisms)
			if err != nil {
				pgConn.conn.Close()
				return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
			}
		case *pgproto3.AuthenticationGSS:
			err = pgConn.gssAuth()
			if err != nil {
				pgConn.conn.Close()
				return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
			}

		case *pgproto3.ReadyForQuery:
			pgConn.status = connStatusIdle
			if config.ValidateConnect != nil {
				// ValidateConnect may run operations that watch a context themselves,
				// so end our watch first. The deferred Unwatch above is then a no-op.
				pgConn.contextWatcher.Unwatch()

				err := config.ValidateConnect(ctx, pgConn)
				if err != nil {
					if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
						return pgConn, nil
					}
					pgConn.conn.Close()
					return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
				}
			}
			return pgConn, nil

		case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
			// Already applied by receiveMessage.
		case *pgproto3.ErrorResponse:
			pgConn.conn.Close()
			return nil, ErrorResponseToPgError(msg)
		default:
			pgConn.conn.Close()
			// The receive error is always nil here; report which message arrived
			// instead of wrapping a nil error.
			return nil, &connectError{config: config, msg: "received unexpected message", err: fmt.Errorf("unexpected message type %T", msg)}
		}
	}
}
// newContextWatcher returns a ContextWatcher bound to conn.
//
// Its first callback pushes conn's deadline to a fixed instant in year 1, so
// blocked reads and writes fail immediately. Its second callback clears the
// deadline again. NOTE(review): presumably ctxwatch invokes the first when the
// watched context is done and the second when that watch ends; confirm in
// ctxwatch.
func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
	longAgo := time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
	interrupt := func() { conn.SetDeadline(longAgo) }
	restore := func() { conn.SetDeadline(time.Time{}) }
	return ctxwatch.NewContextWatcher(interrupt, restore)
}
// startTLS sends a PostgreSQL SSLRequest on conn and reads the server's
// one-byte answer. On 'S' it returns conn wrapped as a TLS client configured
// by tlsConfig. The handshake is not run here; crypto/tls performs it on first
// I/O. Any other answer yields a "server refused TLS connection" error. Write
// and read errors are returned unchanged.
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
	// SSLRequest: big-endian int32 length (8), then the request code 80877103.
	var request [8]byte
	binary.BigEndian.PutUint32(request[0:4], 8)
	binary.BigEndian.PutUint32(request[4:8], 80877103)
	if _, err := conn.Write(request[:]); err != nil {
		return nil, err
	}

	var answer [1]byte
	if _, err := io.ReadFull(conn, answer[:]); err != nil {
		return nil, err
	}
	if answer[0] == 'S' {
		return tls.Client(conn, tlsConfig), nil
	}
	return nil, errors.New("server refused TLS connection")
}
// txPasswordMessage queues a PasswordMessage carrying password and flushes it
// to the server. The caller supplies either the cleartext password or an
// already-computed digest. The flush error, if any, is returned.
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
	reply := &pgproto3.PasswordMessage{Password: password}
	pgConn.frontend.Send(reply)
	err = pgConn.flushWithPotentialWriteReadDeadlock()
	return err
}
// hexMD5 returns the MD5 digest of s as 32 lowercase hex characters. connect
// nests it to build the "md5"-prefixed MD5 password response.
func hexMD5(s string) string {
	digest := md5.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
// signalMessage starts a goroutine that receives the next backend message in
// the background. It returns a channel that is closed once that message, or
// the receive error, has been stored in bufferingReceiveMsg /
// bufferingReceiveErr. CopyFrom uses it to notice server messages while copy
// data is being sent.
//
// bufferingReceiveMux is locked here and unlocked by the goroutine, so a
// concurrent peekMessage blocks on the mutex until the result is stored.
// Calling signalMessage while a buffered receive is outstanding panics.
func (pgConn *PgConn ) signalMessage () chan struct {} {
if pgConn .bufferingReceive {
panic ("BUG: signalMessage when already in progress" )
}
pgConn .bufferingReceive = true
// Held until the goroutine below has stored its result.
pgConn .bufferingReceiveMux .Lock ()
ch := make (chan struct {})
go func () {
pgConn .bufferingReceiveMsg , pgConn .bufferingReceiveErr = pgConn .frontend .Receive ()
pgConn .bufferingReceiveMux .Unlock ()
close (ch )
}()
return ch
}
// ReceiveMessage reads the next backend message. It is a low-level escape
// hatch.
//
// The connection is locked for the duration of the call, so it fails if
// another operation is in progress. If ctx is already done, a context error is
// returned without reading; otherwise ctx is watched while waiting. A receive
// failure is wrapped in a *pgconnError with safeToRetry set.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
	if lockErr := pgConn.lock(); lockErr != nil {
		return nil, lockErr
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
			pgConn.contextWatcher.Watch(ctx)
			defer pgConn.contextWatcher.Unwatch()
		}
	}

	msg, err := pgConn.receiveMessage()
	if err == nil {
		return msg, nil
	}
	return msg, &pgconnError{
		msg:         "receive message failed",
		err:         normalizeTimeoutError(ctx, err),
		safeToRetry: true,
	}
}
// peekMessage returns the next backend message without consuming it. The
// message is kept in peekedMsg and returned again by later calls until
// receiveMessage clears it.
//
// If signalMessage started a background receive, that result is collected
// (blocking on bufferingReceiveMux until it is ready) instead of reading
// again. A background timeout error is discarded and the read is retried
// synchronously.
//
// Any receive error other than a net.Error timeout is treated as fatal and
// closes the connection asynchronously. Timeout errors are returned without
// closing.
func (pgConn *PgConn ) peekMessage () (pgproto3 .BackendMessage , error ) {
if pgConn .peekedMsg != nil {
return pgConn .peekedMsg , nil
}
var msg pgproto3 .BackendMessage
var err error
if pgConn .bufferingReceive {
// Wait for the signalMessage goroutine to finish, then take its result.
pgConn .bufferingReceiveMux .Lock ()
msg = pgConn .bufferingReceiveMsg
err = pgConn .bufferingReceiveErr
pgConn .bufferingReceiveMux .Unlock ()
pgConn .bufferingReceive = false
// A timeout in the background read is not final; try the read again now.
var netErr net .Error
if errors .As (err , &netErr ) && netErr .Timeout () {
msg , err = pgConn .frontend .Receive ()
}
} else {
msg , err = pgConn .frontend .Receive ()
}
if err != nil {
// Everything except a timeout leaves the protocol stream unusable.
var netErr net .Error
isNetErr := errors .As (err , &netErr )
if !(isNetErr && netErr .Timeout ()) {
pgConn .asyncClose ()
}
return nil , err
}
pgConn .peekedMsg = msg
return msg , nil
}
// receiveMessage consumes the next backend message, using one already peeked
// by peekMessage if present, and applies its connection-level side effects:
//   - ReadyForQuery updates txStatus;
//   - ParameterStatus records the parameter value;
//   - a FATAL ErrorResponse marks the connection closed, closes the socket and
//     cleanupDone, and is returned as a *PgError with a nil message;
//   - NoticeResponse / NotificationResponse are passed to Config.OnNotice /
//     Config.OnNotification when those callbacks are set.
//
// Every other message, including non-FATAL errors, is returned unchanged.
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
	next, err := pgConn.peekMessage()
	if err != nil {
		return nil, err
	}
	// The peeked message is consumed from here on.
	pgConn.peekedMsg = nil

	switch m := next.(type) {
	case *pgproto3.ReadyForQuery:
		pgConn.txStatus = m.TxStatus
	case *pgproto3.ParameterStatus:
		pgConn.parameterStatuses[m.Name] = m.Value
	case *pgproto3.ErrorResponse:
		if m.Severity != "FATAL" {
			break
		}
		// A FATAL error is treated as the end of the session; the close error is
		// irrelevant next to the server error being returned.
		pgConn.status = connStatusClosed
		pgConn.conn.Close()
		close(pgConn.cleanupDone)
		return nil, ErrorResponseToPgError(m)
	case *pgproto3.NoticeResponse:
		if onNotice := pgConn.config.OnNotice; onNotice != nil {
			onNotice(pgConn, noticeResponseToNotice(m))
		}
	case *pgproto3.NotificationResponse:
		if onNotification := pgConn.config.OnNotification; onNotification != nil {
			onNotification(pgConn, &Notification{PID: m.PID, Channel: m.Channel, Payload: m.Payload})
		}
	}
	return next, nil
}
// Conn returns the underlying net.Conn (the TLS-wrapped conn when TLS is in
// use). NOTE(review): reading or writing it directly bypasses the protocol
// state that PgConn tracks.
func (pgConn *PgConn ) Conn () net .Conn {
return pgConn .conn
}
// PID returns the backend process ID received in BackendKeyData during connect.
func (pgConn *PgConn ) PID () uint32 {
return pgConn .pid
}
// TxStatus returns the TxStatus byte of the most recently received ReadyForQuery.
func (pgConn *PgConn ) TxStatus () byte {
return pgConn .txStatus
}
// SecretKey returns the secret key received in BackendKeyData; CancelRequest sends it.
func (pgConn *PgConn ) SecretKey () uint32 {
return pgConn .secretKey
}
// Frontend returns the pgproto3 frontend used to send and receive protocol messages.
func (pgConn *PgConn ) Frontend () *pgproto3 .Frontend {
return pgConn .frontend
}
// Close sends a Terminate message and closes the network connection. It
// returns nil immediately if the connection is already closed.
//
// Errors sending Terminate are deliberately ignored; the returned error is the
// one from closing the net.Conn. If ctx is not context.Background, it is
// watched while Terminate is written. cleanupDone is closed on return.
func (pgConn *PgConn ) Close (ctx context .Context ) error {
if pgConn .status == connStatusClosed {
return nil
}
pgConn .status = connStatusClosed
defer close (pgConn .cleanupDone )
// Also deferred so the socket is closed even if sending below panics; on the normal
// path this second Close follows the explicit one and its error is ignored.
defer pgConn .conn .Close ()
if ctx != context .Background () {
// End any watch left by an operation that was interrupted mid-flight before watching ctx.
// NOTE(review): presumably Watch must not be called while another watch is active.
pgConn .contextWatcher .Unwatch ()
pgConn .contextWatcher .Watch (ctx )
defer pgConn .contextWatcher .Unwatch ()
}
// Best effort: the result of sending Terminate is intentionally ignored.
pgConn .frontend .Send (&pgproto3 .Terminate {})
pgConn .flushWithPotentialWriteReadDeadlock ()
return pgConn .conn .Close ()
}
// asyncClose marks the connection closed immediately and tears it down in a
// background goroutine. The goroutine sends a CancelRequest and then a
// Terminate message, both bounded by a 15-second deadline, and finally closes
// the network connection and cleanupDone.
//
// Callers use it after I/O or protocol failures, where the connection can no
// longer be trusted and the caller should not block on cleanup. It is a no-op
// if the connection is already closed.
func (pgConn *PgConn ) asyncClose () {
if pgConn .status == connStatusClosed {
return
}
pgConn .status = connStatusClosed
go func () {
defer close (pgConn .cleanupDone )
defer pgConn .conn .Close ()
deadline := time .Now ().Add (time .Second * 15 )
ctx , cancel := context .WithDeadline (context .Background (), deadline )
defer cancel ()
// Errors are ignored: this is best-effort cleanup of an already failed connection.
pgConn .CancelRequest (ctx )
pgConn .conn .SetDeadline (deadline )
pgConn .frontend .Send (&pgproto3 .Terminate {})
pgConn .flushWithPotentialWriteReadDeadlock ()
}()
}
// CleanupDone returns a channel that is closed once the connection has been
// torn down. After asyncClose, that happens only when its background
// goroutine has finished.
func (pgConn *PgConn ) CleanupDone () chan (struct {}) {
return pgConn .cleanupDone
}
// IsClosed reports whether the connection is unusable: it is true for every status ordered below
// connStatusIdle, i.e. uninitialized, connecting and closed.
func (pgConn *PgConn ) IsClosed () bool {
return pgConn .status < connStatusIdle
}
// IsBusy reports whether an operation currently holds the connection lock.
func (pgConn *PgConn ) IsBusy () bool {
return pgConn .status == connStatusBusy
}
// lock marks the connection busy so one operation has exclusive use of it. It
// returns a *connLockError if the connection is already busy, closed, or was
// never initialized. Any other state (idle or connecting) becomes busy. This
// is a plain status check, not a mutex: it is not goroutine-safe.
func (pgConn *PgConn) lock() error {
	var reason string
	switch pgConn.status {
	case connStatusBusy:
		reason = "conn busy"
	case connStatusClosed:
		reason = "conn closed"
	case connStatusUninitialized:
		reason = "conn uninitialized"
	default:
		pgConn.status = connStatusBusy
		return nil
	}
	return &connLockError{status: reason}
}
// unlock releases the lock taken by lock and returns a busy connection to
// idle. It is a no-op for a connection that was closed while locked. Any other
// state means unlock was called without a matching lock, which is a bug, so it
// panics.
func (pgConn *PgConn) unlock() {
	if pgConn.status == connStatusBusy {
		pgConn.status = connStatusIdle
		return
	}
	if pgConn.status != connStatusClosed {
		panic("BUG: cannot unlock unlocked connection")
	}
}
// ParameterStatus returns the most recent value the server reported for the
// run-time parameter key in a ParameterStatus message, or "" if none has been
// reported.
func (pgConn *PgConn ) ParameterStatus (key string ) string {
return pgConn .parameterStatuses [key ]
}
// CommandTag is the completion string the server sends in CommandComplete
// (for example "INSERT 0 1" or "SELECT 5").
type CommandTag struct {
	s string
}

// NewCommandTag wraps s as a CommandTag.
func NewCommandTag(s string) CommandTag {
	var ct CommandTag
	ct.s = s
	return ct
}

// RowsAffected returns the number formed by the trailing run of ASCII digits
// in the tag. It returns 0 when the tag does not end in a digit.
func (ct CommandTag) RowsAffected() int64 {
	// Walk back over the trailing digits byte by byte.
	start := len(ct.s)
	for start > 0 && ct.s[start-1] >= '0' && ct.s[start-1] <= '9' {
		start--
	}

	var n int64
	for i := start; i < len(ct.s); i++ {
		n = n*10 + int64(ct.s[i]-'0')
	}
	return n
}

// String returns the raw command tag text.
func (ct CommandTag) String() string { return ct.s }

// Insert reports whether the tag starts with "INSERT".
func (ct CommandTag) Insert() bool { return strings.HasPrefix(ct.s, "INSERT") }

// Update reports whether the tag starts with "UPDATE".
func (ct CommandTag) Update() bool { return strings.HasPrefix(ct.s, "UPDATE") }

// Delete reports whether the tag starts with "DELETE".
func (ct CommandTag) Delete() bool { return strings.HasPrefix(ct.s, "DELETE") }

// Select reports whether the tag starts with "SELECT".
func (ct CommandTag) Select() bool { return strings.HasPrefix(ct.s, "SELECT") }
// FieldDescription describes one result column. convertRowDescription fills
// it from a RowDescription message, copying each field as received.
type FieldDescription struct {
Name string
// TableOID and TableAttributeNumber identify the source table column as reported by the server
// (presumably 0 when the column is not a plain table column — confirm against the protocol docs).
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
// Format is the column's format code as sent by the server (presumably 0 = text, 1 = binary).
Format int16
}
// convertRowDescription copies the column metadata of rd into dst and returns
// the result, one FieldDescription per field.
//
// dst's backing array is reused when its capacity is large enough; callers
// pass pgConn.fieldDescriptions[:] to avoid allocating for small results, or
// nil to get fresh storage. Otherwise a new slice is allocated. Names are
// converted with string(...), and FieldDescription holds only value fields.
func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription {
	n := len(rd.Fields)
	if cap(dst) < n {
		dst = make([]FieldDescription, n)
	} else {
		dst = dst[:n:n]
	}

	for i := range dst {
		f := rd.Fields[i]
		dst[i] = FieldDescription{
			Name:                 string(f.Name),
			TableOID:             f.TableOID,
			TableAttributeNumber: f.TableAttributeNumber,
			DataTypeOID:          f.DataTypeOID,
			DataTypeSize:         f.DataTypeSize,
			TypeModifier:         f.TypeModifier,
			Format:               f.Format,
		}
	}
	return dst
}
// StatementDescription is the result of Prepare: the statement's name and SQL
// plus the parameter OIDs and result columns the server reported for it.
type StatementDescription struct {
Name string
SQL string
// ParamOIDs is copied from the server's ParameterDescription.
ParamOIDs []uint32
// Fields is nil when the server sent no RowDescription for the statement.
Fields []FieldDescription
}
// Prepare creates the prepared statement name from sql and returns its
// description.
//
// It sends Parse (paramOIDs is passed through as the ParameterOIDs), Describe
// of the statement ('S'), and Sync, then reads until ReadyForQuery.
//
// An ErrorResponse is returned as a *PgError only after ReadyForQuery has been
// read, so the connection stays usable. Flush and receive failures close the
// connection asynchronously.
func (pgConn *PgConn ) Prepare (ctx context .Context , name , sql string , paramOIDs []uint32 ) (*StatementDescription , error ) {
if err := pgConn .lock (); err != nil {
return nil , err
}
defer pgConn .unlock ()
if ctx != context .Background () {
select {
case <- ctx .Done ():
return nil , newContextAlreadyDoneError (ctx )
default :
}
pgConn .contextWatcher .Watch (ctx )
defer pgConn .contextWatcher .Unwatch ()
}
pgConn .frontend .SendParse (&pgproto3 .Parse {Name : name , Query : sql , ParameterOIDs : paramOIDs })
pgConn .frontend .SendDescribe (&pgproto3 .Describe {ObjectType : 'S' , Name : name })
pgConn .frontend .SendSync (&pgproto3 .Sync {})
err := pgConn .flushWithPotentialWriteReadDeadlock ()
if err != nil {
pgConn .asyncClose ()
return nil , err
}
psd := &StatementDescription {Name : name , SQL : sql }
// A server error is remembered and returned after ReadyForQuery so the protocol stays in sync.
var parseErr error
readloop :
for {
msg , err := pgConn .receiveMessage ()
if err != nil {
pgConn .asyncClose ()
return nil , normalizeTimeoutError (ctx , err )
}
switch msg := msg .(type ) {
case *pgproto3 .ParameterDescription :
// Copy so the result does not alias the message's slice.
psd .ParamOIDs = make ([]uint32 , len (msg .ParameterOIDs ))
copy (psd .ParamOIDs , msg .ParameterOIDs )
case *pgproto3 .RowDescription :
// nil dst: allocate fresh storage rather than reuse pgConn.fieldDescriptions.
psd .Fields = pgConn .convertRowDescription (nil , msg )
case *pgproto3 .ErrorResponse :
parseErr = ErrorResponseToPgError (msg )
case *pgproto3 .ReadyForQuery :
break readloop
}
}
if parseErr != nil {
return nil , parseErr
}
return psd , nil
}
// ErrorResponseToPgError converts an ErrorResponse message into a *PgError,
// copying each message field into the PgError field of the same name.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
	pgErr := new(PgError)
	pgErr.Severity = msg.Severity
	pgErr.Code = string(msg.Code)
	pgErr.Message = string(msg.Message)
	pgErr.Detail = string(msg.Detail)
	pgErr.Hint = msg.Hint
	pgErr.Position = msg.Position
	pgErr.InternalPosition = msg.InternalPosition
	pgErr.InternalQuery = string(msg.InternalQuery)
	pgErr.Where = string(msg.Where)
	pgErr.SchemaName = string(msg.SchemaName)
	pgErr.TableName = string(msg.TableName)
	pgErr.ColumnName = string(msg.ColumnName)
	pgErr.DataTypeName = string(msg.DataTypeName)
	pgErr.ConstraintName = msg.ConstraintName
	pgErr.File = string(msg.File)
	pgErr.Line = msg.Line
	pgErr.Routine = string(msg.Routine)
	return pgErr
}
// noticeResponseToNotice converts a NoticeResponse into a *Notice. The
// message converts directly to an ErrorResponse, so the work is delegated to
// ErrorResponseToPgError and the resulting *PgError is reinterpreted as a
// *Notice.
func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
	asErrorResponse := (*pgproto3.ErrorResponse)(msg)
	return (*Notice)(ErrorResponseToPgError(asErrorResponse))
}
// CancelRequest asks the server to cancel whatever this connection is
// currently running. It opens a separate connection to the same server and
// writes a CancelRequest message containing this connection's PID and secret
// key. No reply is read, so a nil error only means the request was written; it
// does not confirm that anything was cancelled.
//
// The target address comes from the live connection's RemoteAddr rather than
// the config. That way the request reaches the server actually in use, even
// when a fallback host or one of several resolved addresses was chosen. For
// unix sockets, the configured host/port is dialed instead of the reported
// peer address.
//
// If ctx is not context.Background, it is watched while the request is sent.
// Dial and write errors are returned.
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
	serverAddr := pgConn.conn.RemoteAddr()
	network, address := serverAddr.Network(), serverAddr.String()
	if network == "unix" {
		network, address = NetworkAddress(pgConn.config.Host, pgConn.config.Port)
	}

	// The unix case already dials the configured address, so a second attempt
	// against the same address would only repeat the failure.
	cancelConn, err := pgConn.config.DialFunc(ctx, network, address)
	if err != nil {
		return err
	}
	defer cancelConn.Close()

	if ctx != context.Background() {
		contextWatcher := newContextWatcher(cancelConn)
		contextWatcher.Watch(ctx)
		defer contextWatcher.Unwatch()
	}

	// CancelRequest: int32 length (16), request code 80877102, process ID, secret key.
	buf := make([]byte, 16)
	binary.BigEndian.PutUint32(buf[0:4], 16)
	binary.BigEndian.PutUint32(buf[4:8], 80877102)
	binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
	binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey)
	_, err = cancelConn.Write(buf)
	return err
}
// WaitForNotification blocks until a NotificationResponse is received.
//
// Messages that arrive in the meantime are consumed through receiveMessage,
// which still applies their side effects (including the OnNotification
// callback for the notification itself). It returns a context error if ctx is
// already done, and the normalized error if receiving fails.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
	if err := pgConn.lock(); err != nil {
		return err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return newContextAlreadyDoneError(ctx)
		default:
			pgConn.contextWatcher.Watch(ctx)
			defer pgConn.contextWatcher.Unwatch()
		}
	}

	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			return normalizeTimeoutError(ctx, err)
		}
		if _, isNotification := msg.(*pgproto3.NotificationResponse); isNotification {
			return nil
		}
	}
}
// Exec sends sql with the simple query protocol and returns a
// MultiResultReader over its results.
//
// The connection stays locked, and ctx stays watched, until the reader reaches
// ReadyForQuery or fails. Both are released inside
// MultiResultReader.receiveMessage, so callers must drain the reader
// (NextResult/Close).
//
// If the connection cannot be locked, or ctx is already done, an already
// closed reader carrying the error is returned. The lock check comes first so
// a busy connection's in-progress multiResultReader is not overwritten.
func (pgConn *PgConn ) Exec (ctx context .Context , sql string ) *MultiResultReader {
if err := pgConn .lock (); err != nil {
return &MultiResultReader {
closed : true ,
err : err ,
}
}
// Reuse the connection-owned reader to avoid an allocation per query.
pgConn .multiResultReader = MultiResultReader {
pgConn : pgConn ,
ctx : ctx ,
}
multiResult := &pgConn .multiResultReader
if ctx != context .Background () {
select {
case <- ctx .Done ():
multiResult .closed = true
multiResult .err = newContextAlreadyDoneError (ctx )
pgConn .unlock ()
return multiResult
default :
}
// Not deferred: the watch must outlive this call and ends when the reader finishes.
pgConn .contextWatcher .Watch (ctx )
}
pgConn .frontend .SendQuery (&pgproto3 .Query {String : sql })
err := pgConn .flushWithPotentialWriteReadDeadlock ()
if err != nil {
pgConn .asyncClose ()
pgConn .contextWatcher .Unwatch ()
multiResult .closed = true
multiResult .err = err
pgConn .unlock ()
return multiResult
}
return multiResult
}
// ExecParams executes sql through the extended protocol in one round trip and
// returns a ResultReader for the result.
//
// It sends Parse into the unnamed statement and Bind, followed by the messages
// added by execExtendedSuffix. paramValues, paramOIDs, paramFormats and
// resultFormats are passed through to Parse and Bind. If the connection cannot
// be used, the returned reader is already closed and carries the error.
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
	rr := pgConn.execExtendedPrefix(ctx, paramValues)
	if !rr.closed {
		pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
		pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
		pgConn.execExtendedSuffix(rr)
	}
	return rr
}
// ExecPrepared executes the previously prepared statement stmtName and returns
// a ResultReader for the result.
//
// It sends Bind for stmtName, followed by the messages added by
// execExtendedSuffix. If the connection cannot be used, the returned reader is
// already closed and carries the error.
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
	rr := pgConn.execExtendedPrefix(ctx, paramValues)
	if !rr.closed {
		bind := &pgproto3.Bind{
			PreparedStatement:    stmtName,
			ParameterFormatCodes: paramFormats,
			Parameters:           paramValues,
			ResultFormatCodes:    resultFormats,
		}
		pgConn.frontend.SendBind(bind)
		pgConn.execExtendedSuffix(rr)
	}
	return rr
}
// execExtendedPrefix prepares the connection for an extended-protocol
// execution. It locks the connection, checks the parameter count against the
// protocol's uint16 limit, and starts watching ctx.
//
// On success it returns the connection's reusable ResultReader, reset for this
// call; the caller then sends its messages and calls execExtendedSuffix. On
// failure it returns a ResultReader that is already closed and carries the
// error, and the connection is not left locked.
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
	if err := pgConn.lock(); err != nil {
		// Leave pgConn.resultReader alone: the lock usually fails because another
		// operation holds the connection and may still be reading through that
		// reader. Resetting it would mark the in-progress reader closed and
		// concluded, so its Close would never release the lock. Exec checks the
		// lock before resetting multiResultReader for the same reason.
		result := &ResultReader{pgConn: pgConn, ctx: ctx}
		result.concludeCommand(CommandTag{}, err)
		result.closed = true
		return result
	}

	pgConn.resultReader = ResultReader{
		pgConn: pgConn,
		ctx:    ctx,
	}
	result := &pgConn.resultReader

	if len(paramValues) > math.MaxUint16 {
		result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
		result.closed = true
		pgConn.unlock()
		return result
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx))
			result.closed = true
			pgConn.unlock()
			return result
		default:
		}
		// Not deferred: the watch ends when the ResultReader finishes.
		pgConn.contextWatcher.Watch(ctx)
	}

	return result
}
// execExtendedSuffix finishes an execution started by execExtendedPrefix. It
// appends Describe of the unnamed portal ('P'), Execute and Sync, flushes, and
// reads ahead (readUntilRowDescription) so the result's field descriptions are
// available.
//
// If the flush fails, the connection is closed asynchronously, result is
// concluded with the error and marked closed, the context watch ends, and the
// lock is released.
func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
	pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
	pgConn.frontend.SendExecute(&pgproto3.Execute{})
	pgConn.frontend.SendSync(&pgproto3.Sync{})

	flushErr := pgConn.flushWithPotentialWriteReadDeadlock()
	if flushErr == nil {
		result.readUntilRowDescription()
		return
	}

	pgConn.asyncClose()
	result.concludeCommand(CommandTag{}, flushErr)
	pgConn.contextWatcher.Unwatch()
	result.closed = true
	pgConn.unlock()
}
// CopyTo runs sql with the simple query protocol and writes the payload of
// every CopyData message to w. sql is presumably a COPY ... TO STDOUT
// statement.
//
// It returns the CommandComplete tag and any server error once ReadyForQuery
// has been read. A failure writing to w, or a receive error, closes the
// connection asynchronously. Those paths do not unlock, because the closed
// status already makes the lock irrelevant.
func (pgConn *PgConn ) CopyTo (ctx context .Context , w io .Writer , sql string ) (CommandTag , error ) {
if err := pgConn .lock (); err != nil {
return CommandTag {}, err
}
if ctx != context .Background () {
select {
case <- ctx .Done ():
pgConn .unlock ()
return CommandTag {}, newContextAlreadyDoneError (ctx )
default :
}
pgConn .contextWatcher .Watch (ctx )
defer pgConn .contextWatcher .Unwatch ()
}
pgConn .frontend .SendQuery (&pgproto3 .Query {String : sql })
err := pgConn .flushWithPotentialWriteReadDeadlock ()
if err != nil {
pgConn .asyncClose ()
pgConn .unlock ()
return CommandTag {}, err
}
var commandTag CommandTag
var pgErr error
for {
msg , err := pgConn .receiveMessage ()
if err != nil {
pgConn .asyncClose ()
return CommandTag {}, normalizeTimeoutError (ctx , err )
}
switch msg := msg .(type ) {
case *pgproto3 .CopyDone :
case *pgproto3 .CopyData :
_ , err := w .Write (msg .Data )
if err != nil {
// The rest of the COPY stream cannot be skipped safely, so give up on the connection.
pgConn .asyncClose ()
return CommandTag {}, err
}
case *pgproto3 .ReadyForQuery :
// Only a clean finish releases the lock; error paths above leave the status closed.
pgConn .unlock ()
return commandTag , pgErr
case *pgproto3 .CommandComplete :
commandTag = pgConn .makeCommandTag (msg .CommandTag )
case *pgproto3 .ErrorResponse :
pgErr = ErrorResponseToPgError (msg )
}
}
}
// CopyFrom runs sql with the simple query protocol and streams the contents of
// r to the server as CopyData messages. sql is presumably a
// COPY ... FROM STDIN statement.
//
// A separate goroutine reads r and writes to the connection. Meanwhile this
// goroutine waits, via signalMessage, for any server message, so a server
// error can stop the copy before r is exhausted.
//
// When r reports io.EOF, or the server has reported an error, CopyDone is
// sent. Any other read or write error is reported to the server with CopyFail.
// The CommandComplete tag and any server error are returned after
// ReadyForQuery. Flush and receive failures close the connection.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return CommandTag{}, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return CommandTag{}, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return CommandTag{}, err
	}

	abortCopyChan := make(chan struct{})
	copyErrChan := make(chan error, 1)
	signalMessageChan := pgConn.signalMessage()
	var wg sync.WaitGroup
	wg.Add(1)

	// stopCopy tells the copy goroutine to stop and waits until it no longer
	// uses r or the connection.
	stopCopy := func() {
		close(abortCopyChan)
		wg.Wait()
	}

	go func() {
		defer wg.Done()
		buf := iobufpool.Get(65536)
		defer iobufpool.Put(buf)
		// Each message is framed in place: byte 0 is the 'd' (CopyData) type,
		// bytes 1-4 the int32 length (payload + 4), then the payload.
		(*buf)[0] = 'd'

		for {
			n, readErr := r.Read((*buf)[5:cap(*buf)])
			if n > 0 {
				*buf = (*buf)[0 : n+5]
				pgio.SetInt32((*buf)[1:], int32(n+4))

				writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
				if writeErr != nil {
					// asyncClose mutates pgConn.status and must not run on this
					// goroutine; just close the socket and report the error.
					pgConn.conn.Close()
					copyErrChan <- writeErr
					return
				}
			}
			if readErr != nil {
				copyErrChan <- readErr
				return
			}

			select {
			case <-abortCopyChan:
				return
			default:
			}
		}
	}()

	var pgErr error
	var copyErr error
	for copyErr == nil && pgErr == nil {
		select {
		case copyErr = <-copyErrChan:
		case <-signalMessageChan:
			// Check the background receive result directly. Letting receiveMessage
			// call asyncClose would send CancelRequest/Terminate while the copy
			// goroutine may still be writing, so tear the connection down here.
			if err := pgConn.bufferingReceiveErr; err != nil {
				pgConn.status = connStatusClosed
				pgConn.conn.Close()
				close(pgConn.cleanupDone)
				stopCopy()
				return CommandTag{}, normalizeTimeoutError(ctx, err)
			}

			// With the buffered error already checked, receiveMessage only fails
			// here on a FATAL ErrorResponse. In that case it has already closed the
			// connection and cleanupDone. Starting another signalMessage on the
			// closed connection would close cleanupDone a second time (panic), so
			// stop the copy and return the server error.
			msg, err := pgConn.receiveMessage()
			if err != nil {
				stopCopy()
				return CommandTag{}, err
			}

			switch msg := msg.(type) {
			case *pgproto3.ErrorResponse:
				pgErr = ErrorResponseToPgError(msg)
			default:
				signalMessageChan = pgConn.signalMessage()
			}
		}
	}
	// The copy goroutine must be finished before anything else is written.
	stopCopy()

	if copyErr == io.EOF || pgErr != nil {
		pgConn.frontend.Send(&pgproto3.CopyDone{})
	} else {
		pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
	}
	err = pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return CommandTag{}, err
	}

	var commandTag CommandTag
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return CommandTag{}, normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = pgConn.makeCommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}
// MultiResultReader iterates over the results of a query that may produce
// several results (see Exec). The value returned by Exec lives inside the
// PgConn and is reset by the next Exec.
type MultiResultReader struct {
pgConn *PgConn
ctx context .Context
// pipeline is non-nil when the reader belongs to a Pipeline; ReadyForQuery then updates the pipeline
// instead of releasing the connection.
pipeline *Pipeline
// rr is the reader for the current result, set by NextResult.
rr *ResultReader
// closed becomes true at ReadyForQuery or on a receive error.
closed bool
// err holds the receive error or the most recent ErrorResponse seen.
err error
}
// ReadAll reads every remaining result fully into memory and then closes mrr.
// The returned error is the one from Close. Errors of individual results are
// also available in each Result's Err field.
func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
	var results []*Result
	for mrr.NextResult() {
		current := mrr.ResultReader()
		results = append(results, current.Read())
	}
	return results, mrr.Close()
}
// receiveMessage reads the next message on behalf of the multi-result reader.
//
// On a receive error it stops watching the context, records the normalized
// error, marks the reader closed, and closes the connection asynchronously.
//
// On ReadyForQuery the reader is closed. A pipeline reader decrements the
// pipeline's expected ReadyForQuery count; otherwise the context watch ends
// and the connection lock is released.
//
// An ErrorResponse is recorded in mrr.err and, like every other message, is
// also returned to the caller.
func (mrr *MultiResultReader ) receiveMessage () (pgproto3 .BackendMessage , error ) {
msg , err := mrr .pgConn .receiveMessage ()
if err != nil {
mrr .pgConn .contextWatcher .Unwatch ()
mrr .err = normalizeTimeoutError (mrr .ctx , err )
mrr .closed = true
mrr .pgConn .asyncClose ()
return nil , mrr .err
}
switch msg := msg .(type ) {
case *pgproto3 .ReadyForQuery :
mrr .closed = true
if mrr .pipeline != nil {
mrr .pipeline .expectedReadyForQueryCount --
} else {
// End of the Exec started in PgConn.Exec: release what it acquired.
mrr .pgConn .contextWatcher .Unwatch ()
mrr .pgConn .unlock ()
}
case *pgproto3 .ErrorResponse :
mrr .err = ErrorResponseToPgError (msg )
}
return msg , nil
}
// NextResult advances to the next result and reports whether one is available
// from ResultReader.
//
// A RowDescription starts a row-returning result whose rows are read through
// mrr. A CommandComplete without a preceding RowDescription yields an
// already-concluded result carrying only the command tag. Both reuse
// pgConn.resultReader (and its fieldDescriptions buffer), so the previous
// ResultReader becomes invalid.
//
// It returns false once the reader is closed or an error is recorded, and also
// on EmptyQueryResponse; in that last case the reader is not yet closed and
// Close drains it.
func (mrr *MultiResultReader ) NextResult () bool {
for !mrr .closed && mrr .err == nil {
msg , err := mrr .receiveMessage ()
if err != nil {
return false
}
switch msg := msg .(type ) {
case *pgproto3 .RowDescription :
mrr .pgConn .resultReader = ResultReader {
pgConn : mrr .pgConn ,
multiResultReader : mrr ,
ctx : mrr .ctx ,
fieldDescriptions : mrr .pgConn .convertRowDescription (mrr .pgConn .fieldDescriptions [:], msg ),
}
mrr .rr = &mrr .pgConn .resultReader
return true
case *pgproto3 .CommandComplete :
// No rows: a closed, concluded reader that only carries the tag.
mrr .pgConn .resultReader = ResultReader {
commandTag : mrr .pgConn .makeCommandTag (msg .CommandTag ),
commandConcluded : true ,
closed : true ,
}
mrr .rr = &mrr .pgConn .resultReader
return true
case *pgproto3 .EmptyQueryResponse :
return false
}
}
return false
}
// ResultReader returns the reader for the current result, as set by the most
// recent NextResult call that returned true.
func (mrr *MultiResultReader ) ResultReader () *ResultReader {
return mrr .rr
}
// Close consumes the remaining messages until the reader closes (at
// ReadyForQuery or on a receive error) and returns the error recorded on mrr,
// if any. For a non-pipeline reader, reaching ReadyForQuery also releases the
// connection. Calling Close on an already closed reader just returns that
// error.
func (mrr *MultiResultReader) Close() error {
	for !mrr.closed {
		if _, err := mrr.receiveMessage(); err != nil {
			break
		}
	}
	return mrr.err
}
// ResultReader reads the result of a single command. Readers returned by
// ExecParams, ExecPrepared and MultiResultReader.NextResult point at storage
// inside the PgConn, which the next operation reuses.
type ResultReader struct {
pgConn *PgConn
// multiResultReader is set when the reader was created by NextResult; messages are then read through it.
multiResultReader *MultiResultReader
// pipeline is set for pipeline results; Close then does not read through ReadyForQuery.
pipeline *Pipeline
ctx context .Context
fieldDescriptions []FieldDescription
// rowValues are the values of the current DataRow, as received (not copied).
rowValues [][]byte
commandTag CommandTag
// commandConcluded becomes true on CommandComplete, EmptyQueryResponse, ErrorResponse or a receive error.
commandConcluded bool
closed bool
err error
}
// Result is a fully buffered command result, as produced by ResultReader.Read.
type Result struct {
FieldDescriptions []FieldDescription
// Rows holds one [][]byte per row with the column values copied out of the protocol messages.
Rows [][][]byte
CommandTag CommandTag
Err error
}
// Read consumes all remaining rows of the result into memory and closes rr.
//
// Values are deep-copied, because Values returns the slices taken directly
// from the received DataRow, and later reads may overwrite them. A nil value
// (SQL NULL in a DataRow) is kept as nil so it stays distinguishable from an
// empty, non-NULL value.
//
// FieldDescriptions is copied when the first row arrives, so it is nil for a
// result with no rows. CommandTag and Err come from Close.
func (rr *ResultReader) Read() *Result {
	br := &Result{}

	for rr.NextRow() {
		if br.FieldDescriptions == nil {
			br.FieldDescriptions = make([]FieldDescription, len(rr.FieldDescriptions()))
			copy(br.FieldDescriptions, rr.FieldDescriptions())
		}

		values := rr.Values()
		row := make([][]byte, len(values))
		for i, v := range values {
			// Preserve NULL: make([]byte, 0) would turn it into a non-nil empty slice.
			if v != nil {
				row[i] = make([]byte, len(v))
				copy(row[i], v)
			}
		}
		br.Rows = append(br.Rows, row)
	}

	br.CommandTag, br.Err = rr.Close()
	return br
}
// NextRow advances to the next data row and reports whether one is available
// through Values. It returns false once the command has concluded (command
// complete, empty query, server error) or a receive error occurred; Close
// reports the error.
func (rr *ResultReader) NextRow() bool {
	for !rr.commandConcluded {
		msg, err := rr.receiveMessage()
		if err != nil {
			return false
		}
		if row, ok := msg.(*pgproto3.DataRow); ok {
			rr.rowValues = row.Values
			return true
		}
	}
	return false
}
// FieldDescriptions returns the column descriptions of the result. The slice
// may share the connection's reusable buffer, which later results overwrite;
// copy it to keep it.
func (rr *ResultReader ) FieldDescriptions () []FieldDescription {
return rr .fieldDescriptions
}
// Values returns the values of the current row as taken from the DataRow message by NextRow. They are not
// copied, and concludeCommand resets them to nil; copy them to keep them.
func (rr *ResultReader ) Values () [][]byte {
return rr .rowValues
}
// Close discards any remaining rows of the command and returns its command tag
// and rr.err.
//
// A standalone reader (neither part of a MultiResultReader nor of a pipeline)
// also reads through ReadyForQuery, recording any ErrorResponse seen on the
// way, then stops watching the context and unlocks the connection. On a
// receive error, a zero CommandTag and the error are returned. Repeated calls
// return the stored tag and error.
func (rr *ResultReader ) Close () (CommandTag , error ) {
if rr .closed {
return rr .commandTag , rr .err
}
rr .closed = true
// Drain rows of the current command.
for !rr .commandConcluded {
_ , err := rr .receiveMessage ()
if err != nil {
return CommandTag {}, rr .err
}
}
// Standalone extended-protocol execution: consume up to the Sync's ReadyForQuery and release the connection.
if rr .multiResultReader == nil && rr .pipeline == nil {
for {
msg , err := rr .receiveMessage ()
if err != nil {
return CommandTag {}, rr .err
}
switch msg := msg .(type ) {
case *pgproto3 .ErrorResponse :
rr .err = ErrorResponseToPgError (msg )
case *pgproto3 .ReadyForQuery :
rr .pgConn .contextWatcher .Unwatch ()
rr .pgConn .unlock ()
return rr .commandTag , rr .err
}
}
}
return rr .commandTag , rr .err
}
// readUntilRowDescription advances the reader until one of three things
// happens: a RowDescription has been consumed (so rr.fieldDescriptions is
// populated), a DataRow is the next message, or the command concludes.
//
// It returns nothing. A read error is recorded in rr.err by receiveMessage,
// which also concludes the command and so ends the loop.
func (rr *ResultReader ) readUntilRowDescription () {
for !rr .commandConcluded {
// Peek first so a DataRow is left unconsumed when no RowDescription
// precedes it; NextRow will then read it. Peek errors are ignored here and
// surface on the receiveMessage call below.
msg , _ := rr .pgConn .peekMessage ()
if _ , ok := msg .(*pgproto3 .DataRow ); ok {
return
}
// Consume the message; receiveMessage updates reader state as a side effect.
msg , _ = rr .receiveMessage ()
if _ , ok := msg .(*pgproto3 .RowDescription ); ok {
return
}
}
}
// receiveMessage reads the next backend message for this result and updates
// the reader's state from it. The message comes either directly from the
// connection or through the owning MultiResultReader.
//
// On a read error it does the following, then returns rr.err (which may be an
// earlier error, since concludeCommand keeps the first one):
//   - concludes the command with the timeout-normalized error;
//   - stops the context watcher;
//   - marks the reader closed;
//   - when not reading through a MultiResultReader, closes the connection
//     asynchronously.
func (rr *ResultReader ) receiveMessage () (msg pgproto3 .BackendMessage , err error ) {
if rr .multiResultReader == nil {
msg , err = rr .pgConn .receiveMessage ()
} else {
msg , err = rr .multiResultReader .receiveMessage ()
}
if err != nil {
err = normalizeTimeoutError (rr .ctx , err )
rr .concludeCommand (CommandTag {}, err )
rr .pgConn .contextWatcher .Unwatch ()
rr .closed = true
// A MultiResultReader handles connection teardown itself.
if rr .multiResultReader == nil {
rr .pgConn .asyncClose ()
}
return nil , rr .err
}
switch msg := msg .(type ) {
case *pgproto3 .RowDescription :
// Reuse the connection's fieldDescriptions array to avoid an allocation.
rr .fieldDescriptions = rr .pgConn .convertRowDescription (rr .pgConn .fieldDescriptions [:], msg )
case *pgproto3 .CommandComplete :
rr .concludeCommand (rr .pgConn .makeCommandTag (msg .CommandTag ), nil )
case *pgproto3 .EmptyQueryResponse :
rr .concludeCommand (CommandTag {}, nil )
case *pgproto3 .ErrorResponse :
rr .concludeCommand (CommandTag {}, ErrorResponseToPgError (msg ))
}
return msg , nil
}
// concludeCommand records that the current command has finished.
//
// Only the first non-nil error is kept. The command tag is recorded (and the
// current row cleared) only on the first call; later calls can still record
// an error if none was stored yet.
func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
	if rr.err == nil && err != nil {
		rr.err = err
	}

	if !rr.commandConcluded {
		rr.commandTag = commandTag
		rr.rowValues = nil
		rr.commandConcluded = true
	}
}
// Batch accumulates encoded extended-protocol messages that PgConn.ExecBatch
// sends to the server in a single write.
type Batch struct {
buf []byte // encoded frontend messages; ExecBatch appends a Sync before writing
}
// ExecParams queues sql as the unnamed prepared statement (a Parse message)
// and then queues its execution with the given parameters via ExecPrepared.
// paramOIDs are passed to the server as the statement's parameter types.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
	parse := pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}
	batch.buf = parse.Encode(batch.buf)

	// "" refers to the unnamed statement just parsed.
	batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
}
// ExecPrepared queues execution of the prepared statement stmtName. It appends
// three messages: a Bind with the given parameters and formats, a Describe of
// the unnamed portal, and an Execute of that portal.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
	bind := &pgproto3.Bind{
		PreparedStatement:    stmtName,
		ParameterFormatCodes: paramFormats,
		Parameters:           paramValues,
		ResultFormatCodes:    resultFormats,
	}

	buf := bind.Encode(batch.buf)
	buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
	batch.buf = (&pgproto3.Execute{}).Encode(buf)
}
// ExecBatch sends every query queued in batch, followed by a Sync, in a single
// write, and returns a MultiResultReader for reading the results.
//
// The returned reader is already closed and carries the error when any of the
// following happens:
//   - the connection cannot be locked;
//   - ctx is already done;
//   - the write fails.
//
// A write failure is fatal to the connection. Note that the Sync message is
// appended to batch's own buffer.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
	if err := pgConn.lock(); err != nil {
		return &MultiResultReader{
			closed: true,
			err:    err,
		}
	}

	pgConn.multiResultReader = MultiResultReader{
		pgConn: pgConn,
		ctx:    ctx,
	}
	multiResult := &pgConn.multiResultReader
	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			multiResult.closed = true
			multiResult.err = newContextAlreadyDoneError(ctx)
			pgConn.unlock()
			return multiResult
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)

	pgConn.enterPotentialWriteReadDeadlock()
	defer pgConn.exitPotentialWriteReadDeadlock()
	_, err := pgConn.conn.Write(batch.buf)
	if err != nil {
		// A failed or partial write leaves the protocol stream in an unknown
		// state, so the connection must not be handed back for reuse. Stop the
		// context watcher started above (otherwise the next Watch would find
		// one still in progress), report the error the same way
		// ResultReader.receiveMessage does, and close the connection.
		pgConn.contextWatcher.Unwatch()
		multiResult.closed = true
		multiResult.err = normalizeTimeoutError(ctx, err)
		pgConn.asyncClose()
		pgConn.unlock()
		return multiResult
	}

	return multiResult
}
// EscapeString escapes s for use inside a single-quoted SQL string literal by
// doubling every single quote.
//
// It refuses to run unless the server reports standard_conforming_strings=on
// and client_encoding=UTF8, returning an error instead.
func (pgConn *PgConn) EscapeString(s string) (string, error) {
	switch {
	case pgConn.ParameterStatus("standard_conforming_strings") != "on":
		return "", errors.New("EscapeString must be run with standard_conforming_strings=on")
	case pgConn.ParameterStatus("client_encoding") != "UTF8":
		return "", errors.New("EscapeString must be run with client_encoding=UTF8")
	}

	return strings.ReplaceAll(s, "'", "''"), nil
}
// CheckConn probes the connection by trying to receive a message with a 1ms
// timeout. A timeout is treated as healthy and yields nil; any other error is
// returned. A message that does arrive is ignored.
func (pgConn *PgConn) CheckConn() error {
	probeCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()

	if _, err := pgConn.ReceiveMessage(probeCtx); err != nil && !Timeout(err) {
		return err
	}
	return nil
}
// Ping round-trips a query consisting only of an SQL comment and returns the
// error from closing its result, if any.
func (pgConn *PgConn) Ping(ctx context.Context) error {
	results := pgConn.Exec(ctx, "-- ping")
	return results.Close()
}
// makeCommandTag wraps the raw CommandComplete tag bytes in a CommandTag.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
	// Converting to string copies buf, so the tag does not alias the receive
	// buffer.
	text := string(buf)
	return CommandTag{s: text}
}
// enterPotentialWriteReadDeadlock arms slowWriteTimer for 15ms before a write
// that could block. If the write is still in progress when the timer fires,
// the timer's callback runs; Construct sets that callback to bgReader.Start.
// Every call must be paired with exitPotentialWriteReadDeadlock.
//
// It panics if the timer was already active, i.e. on an unbalanced enter.
func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
	const slowWriteThreshold = 15 * time.Millisecond

	if wasActive := pgConn.slowWriteTimer.Reset(slowWriteThreshold); wasActive {
		panic("BUG: slow write timer already active")
	}
}
// exitPotentialWriteReadDeadlock ends a section begun with
// enterPotentialWriteReadDeadlock. It disarms the slow-write timer and then
// asks the background reader to stop.
func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
	// Whether the timer already fired (slow write) or not (fast write) does
	// not matter here, so Stop's result is deliberately ignored.
	pgConn.slowWriteTimer.Stop()
	pgConn.bgReader.Stop()
}
// flushWithPotentialWriteReadDeadlock flushes the frontend's buffered messages
// inside an enter/exit potential-deadlock section and returns Flush's error.
func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
	pgConn.enterPotentialWriteReadDeadlock()
	defer pgConn.exitPotentialWriteReadDeadlock()

	return pgConn.frontend.Flush()
}
// SyncConn tries to bring the connection to a synchronized state. That means
// the background reader is stopped and no unread data is buffered in the
// frontend.
//
// It issues up to 10 Pings, checking the state before each one. It returns nil
// once synchronized, a wrapped error if a Ping fails, or an error if the
// connection never synchronizes.
func (pgConn *PgConn) SyncConn(ctx context.Context) error {
	const maxAttempts = 10

	for attempt := 0; attempt < maxAttempts; attempt++ {
		synced := pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0
		if synced {
			return nil
		}

		if err := pgConn.Ping(ctx); err != nil {
			return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err)
		}
	}

	return errors.New("SyncConn: conn never synchronized")
}
// HijackedConn holds the connection state that Hijack extracts from a PgConn.
// Construct turns it back into a PgConn.
type HijackedConn struct {
Conn net .Conn // the underlying network connection
PID uint32 // copied from PgConn.pid
SecretKey uint32 // copied from PgConn.secretKey
ParameterStatuses map [string ]string // shared with the hijacked PgConn, not copied
TxStatus byte
// Frontend is not reused by Construct, which builds a new one with
// Config.BuildFrontend.
Frontend *pgproto3 .Frontend
Config *Config // must be non-nil for Construct
}
// Hijack takes over the underlying connection and its protocol state and
// returns them as a HijackedConn; Construct can build a new PgConn from the
// result. If the connection cannot be locked, lock's error is returned.
// Otherwise the PgConn is left in the closed state.
func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}

	hc := &HijackedConn{
		Conn:              pgConn.conn,
		PID:               pgConn.pid,
		SecretKey:         pgConn.secretKey,
		ParameterStatuses: pgConn.parameterStatuses,
		TxStatus:          pgConn.txStatus,
		Frontend:          pgConn.frontend,
		Config:            pgConn.config,
	}

	// The lock is never released. Marking the PgConn closed hands ownership of
	// the underlying connection over to hc.
	pgConn.status = connStatusClosed
	return hc, nil
}
// Construct creates a PgConn from a HijackedConn, such as one returned by
// Hijack. The new PgConn is idle and gets a fresh context watcher, background
// reader, and slow-write timer. The error result is always nil here.
//
// hc.Config must be non-nil: its BuildFrontend is called unconditionally.
//
// NOTE(review): hc.Frontend is assigned and then replaced by
// hc.Config.BuildFrontend, so the hijacked Frontend is not reused. Any data it
// had buffered is not carried over. This is presumably why SyncConn should be
// called before Hijack; confirm.
func Construct (hc *HijackedConn ) (*PgConn , error ) {
pgConn := &PgConn {
conn : hc .Conn ,
pid : hc .PID ,
secretKey : hc .SecretKey ,
parameterStatuses : hc .ParameterStatuses ,
txStatus : hc .TxStatus ,
frontend : hc .Frontend ,
config : hc .Config ,
status : connStatusIdle ,
cleanupDone : make (chan struct {}),
}
pgConn .contextWatcher = newContextWatcher (pgConn .conn )
pgConn .bgReader = bgreader .New (pgConn .conn )
// Create the timer stopped. enterPotentialWriteReadDeadlock arms it with
// Reset; when it fires it starts the background reader.
pgConn .slowWriteTimer = time .AfterFunc (time .Duration (math .MaxInt64 ), pgConn .bgReader .Start )
pgConn .slowWriteTimer .Stop ()
// Read through bgReader so buffered background reads reach the frontend.
pgConn .frontend = hc .Config .BuildFrontend (pgConn .bgReader , pgConn .conn )
return pgConn , nil
}
// Pipeline is a connection in pipeline mode, started by PgConn.StartPipeline.
// Requests are queued with the Send* methods, written with Flush or Sync, and
// their results read with GetResults. Close releases the connection.
type Pipeline struct {
conn *PgConn // nil when StartPipeline could not lock the connection
ctx context .Context
expectedReadyForQueryCount int // incremented by each successful Sync, decremented per ReadyForQuery in GetResults
pendingSync bool // true when requests were queued since the last successful Sync
err error // error reported once the pipeline is closed
closed bool
}
type (
	// PipelineSync is returned by Pipeline.GetResults when a ReadyForQuery
	// arrives, meaning one Sync point has been fully processed.
	PipelineSync struct{}

	// CloseComplete is returned by Pipeline.GetResults when the server confirms
	// a Close request, such as one queued by SendDeallocate.
	CloseComplete struct{}
)
// StartPipeline locks the connection and puts it in pipeline mode, returning a
// Pipeline through which requests are queued and results read.
//
// If the connection cannot be locked, or ctx is already done, the returned
// Pipeline is already closed and holds the error. Otherwise ctx is watched
// until the pipeline is closed.
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
	if err := pgConn.lock(); err != nil {
		return &Pipeline{closed: true, err: err}
	}

	pgConn.pipeline = Pipeline{conn: pgConn, ctx: ctx}
	pipeline := &pgConn.pipeline

	// context.Background is never canceled, so there is nothing to watch.
	if ctx == context.Background() {
		return pipeline
	}

	select {
	case <-ctx.Done():
		pipeline.closed = true
		pipeline.err = newContextAlreadyDoneError(ctx)
		pgConn.unlock()
		return pipeline
	default:
	}

	pgConn.contextWatcher.Watch(ctx)
	return pipeline
}
// SendPrepare queues a Parse of sql as the statement name, followed by a
// Describe of that statement. Flush or Sync writes the queued messages to the
// server. It does nothing on a closed pipeline.
func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
	if !p.closed {
		p.pendingSync = true

		parse := &pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}
		describe := &pgproto3.Describe{ObjectType: 'S', Name: name}
		p.conn.frontend.SendParse(parse)
		p.conn.frontend.SendDescribe(describe)
	}
}
// SendDeallocate queues a Close of the prepared statement name. Its
// confirmation is returned by GetResults as *CloseComplete. It does nothing on
// a closed pipeline.
func (p *Pipeline) SendDeallocate(name string) {
	if !p.closed {
		p.pendingSync = true

		closeStmt := &pgproto3.Close{ObjectType: 'S', Name: name}
		p.conn.frontend.SendClose(closeStmt)
	}
}
// SendQueryParams queues a one-off parameterized query. It sends a Parse into
// the unnamed statement, then a Bind into the unnamed portal, a Describe of
// that portal, and an Execute. It does nothing on a closed pipeline.
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
	if p.closed {
		return
	}
	p.pendingSync = true

	fe := p.conn.frontend
	fe.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
	fe.SendBind(&pgproto3.Bind{
		ParameterFormatCodes: paramFormats,
		Parameters:           paramValues,
		ResultFormatCodes:    resultFormats,
	})
	fe.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
	fe.SendExecute(&pgproto3.Execute{})
}
// SendQueryPrepared queues execution of the prepared statement stmtName. It
// sends a Bind into the unnamed portal, a Describe of that portal, and an
// Execute. It does nothing on a closed pipeline.
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
	if p.closed {
		return
	}
	p.pendingSync = true

	fe := p.conn.frontend
	fe.SendBind(&pgproto3.Bind{
		PreparedStatement:    stmtName,
		ParameterFormatCodes: paramFormats,
		Parameters:           paramValues,
		ResultFormatCodes:    resultFormats,
	})
	fe.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
	fe.SendExecute(&pgproto3.Execute{})
}
// Flush writes all queued requests to the server.
//
// On a closed pipeline it returns the pipeline's error, or a generic
// "pipeline closed" error if none was recorded. A write failure is fatal, and
// the (timeout-normalized) error is returned after this cleanup:
//   - the connection is closed asynchronously;
//   - the context watcher is stopped;
//   - the connection is unlocked;
//   - the pipeline is closed with the error.
func (p *Pipeline) Flush() error {
	if p.closed {
		if p.err == nil {
			return errors.New("pipeline closed")
		}
		return p.err
	}

	if err := p.conn.flushWithPotentialWriteReadDeadlock(); err != nil {
		err = normalizeTimeoutError(p.ctx, err)

		p.conn.asyncClose()
		p.conn.contextWatcher.Unwatch()
		p.conn.unlock()

		p.closed = true
		p.err = err
		return err
	}
	return nil
}
// Sync queues a Sync message and flushes all pending requests. On success the
// pipeline has no unsynced requests and expects one more ReadyForQuery, which
// GetResults reports as *PipelineSync.
//
// On a closed pipeline Sync returns the pipeline's error (or "pipeline
// closed") without touching the connection. Flush errors are returned as-is.
func (p *Pipeline) Sync() error {
	// Check before queueing anything. A Pipeline whose StartPipeline failed to
	// lock has a nil conn, and a closed pipeline no longer owns its connection.
	// Buffering a Sync into it would either panic or corrupt another user's
	// next write.
	if p.closed {
		if p.err != nil {
			return p.err
		}
		return errors.New("pipeline closed")
	}

	p.conn.frontend.SendSync(&pgproto3.Sync{})
	if err := p.Flush(); err != nil {
		return err
	}

	p.pendingSync = false
	p.expectedReadyForQueryCount++
	return nil
}
// GetResults returns the next result from the pipeline, or (nil, nil) when no
// Sync is outstanding. The result is one of:
//   - *ResultReader for a query result (after a RowDescription, or an already
//     concluded reader for a bare CommandComplete);
//   - *StatementDescription when a ParseComplete is followed by a
//     ParameterDescription;
//   - *CloseComplete when the server confirms a Close;
//   - *PipelineSync when a ReadyForQuery arrives.
//
// A server error is returned as a *PgError and does not end the pipeline. A
// failure to read from the connection closes it asynchronously. That error is
// returned normalized against the pipeline's context, matching
// getResultsPrepare.
func (p *Pipeline) GetResults() (results any, err error) {
	if p.expectedReadyForQueryCount == 0 {
		return nil, nil
	}

	for {
		msg, err := p.conn.receiveMessage()
		if err != nil {
			// The protocol position is unknown after a failed read, so the
			// connection cannot continue. Normalize so context cancellation is
			// reported as such rather than as a bare i/o timeout.
			p.conn.asyncClose()
			return nil, normalizeTimeoutError(p.ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.RowDescription:
			p.conn.resultReader = ResultReader{
				pgConn:            p.conn,
				pipeline:          p,
				ctx:               p.ctx,
				fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg),
			}
			return &p.conn.resultReader, nil
		case *pgproto3.CommandComplete:
			// A command without a row set: hand back an already concluded reader.
			p.conn.resultReader = ResultReader{
				commandTag:       p.conn.makeCommandTag(msg.CommandTag),
				commandConcluded: true,
				closed:           true,
			}
			return &p.conn.resultReader, nil
		case *pgproto3.ParseComplete:
			peekedMsg, err := p.conn.peekMessage()
			if err != nil {
				p.conn.asyncClose()
				return nil, normalizeTimeoutError(p.ctx, err)
			}
			if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
				return p.getResultsPrepare()
			}
		case *pgproto3.CloseComplete:
			return &CloseComplete{}, nil
		case *pgproto3.ReadyForQuery:
			p.expectedReadyForQueryCount--
			return &PipelineSync{}, nil
		case *pgproto3.ErrorResponse:
			pgErr := ErrorResponseToPgError(msg)
			return nil, pgErr
		}
	}
}
// getResultsPrepare collects the statement description that follows a
// ParseComplete. It records the ParameterDescription, then stops at the
// RowDescription or NoData message.
//
// A server error is returned as a *PgError. A read failure, or an out-of-place
// CommandComplete or ReadyForQuery, closes the connection asynchronously.
func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
	desc := &StatementDescription{}

	for {
		msg, err := p.conn.receiveMessage()
		if err != nil {
			p.conn.asyncClose()
			return nil, normalizeTimeoutError(p.ctx, err)
		}

		switch m := msg.(type) {
		case *pgproto3.ParameterDescription:
			desc.ParamOIDs = make([]uint32, len(m.ParameterOIDs))
			copy(desc.ParamOIDs, m.ParameterOIDs)
		case *pgproto3.RowDescription:
			// Pass nil so the fields get their own storage rather than the
			// connection's shared buffer.
			desc.Fields = p.conn.convertRowDescription(nil, m)
			return desc, nil
		case *pgproto3.NoData:
			return desc, nil
		case *pgproto3.ErrorResponse:
			pgErr := ErrorResponseToPgError(m)
			return nil, pgErr
		case *pgproto3.CommandComplete:
			p.conn.asyncClose()
			return nil, errors.New("BUG: received CommandComplete while handling Describe")
		case *pgproto3.ReadyForQuery:
			p.conn.asyncClose()
			return nil, errors.New("BUG: received ReadyForQuery while handling Describe")
		}
	}
}
// Close ends pipeline mode and releases the connection. On an already closed
// pipeline it returns the stored error and does nothing else.
//
// If requests were queued without a following Sync, the connection is closed
// and an error is returned. Otherwise every outstanding result is read and
// discarded:
//   - a server error (*PgError) is recorded, but draining continues;
//   - any other error closes the connection and stops draining.
//
// Close returns the last recorded error.
func (p *Pipeline ) Close () error {
if p .closed {
return p .err
}
p .closed = true
// Unsynced requests leave the server mid-pipeline; the connection cannot be
// brought back to a known state, so it is closed.
if p .pendingSync {
p .conn .asyncClose ()
p .err = errors .New ("pipeline has unsynced requests" )
p .conn .contextWatcher .Unwatch ()
p .conn .unlock ()
return p .err
}
// Drain results up to the final ReadyForQuery.
for p .expectedReadyForQueryCount > 0 {
_ , err := p .GetResults ()
if err != nil {
p .err = err
var pgErr *PgError
// Only server-reported errors leave the protocol stream intact.
if !errors .As (err , &pgErr ) {
p .conn .asyncClose ()
break
}
}
}
p .conn .contextWatcher .Unwatch ()
p .conn .unlock ()
return p .err
}
These pages were generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
Pull requests and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the QR code on the left) to get the latest news about Golds.