package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
// Backend receives frontend messages from an io.Reader and sends backend
// messages to an io.Writer. Create one with NewBackend.
type Backend struct {
// cr supplies the incoming bytes; NewBackend builds it from the reader it is given.
cr *chunkReader
// w is where Flush writes the buffered outgoing messages.
w io .Writer
// tracer, when non-nil (see Trace/Untrace), is handed every message passed
// to Send and every message returned by Receive.
tracer *tracer
// wbuf accumulates messages encoded by Send until Flush writes them to w.
wbuf []byte
// Reusable message values. ReceiveStartupMessage and Receive decode into
// these fields and return pointers to them, so a returned message is
// overwritten by a later call that decodes a message of the same type.
bind Bind
cancelRequest CancelRequest
_close Close
copyFail CopyFail
copyData CopyData
copyDone CopyDone
describe Describe
execute Execute
flush Flush
functionCall FunctionCall
gssEncRequest GSSEncRequest
parse Parse
query Query
sslRequest SSLRequest
startupMessage StartupMessage
sync Sync
terminate Terminate
// Header of the message Receive is currently reading. partialMsg is true
// once the 5-byte header has been consumed but the body has not yet been
// read successfully, so the next Receive skips the header and retries the body.
bodyLen int
msgType byte
partialMsg bool
// authType selects which concrete type Receive decodes a 'p' message into;
// it is changed via SetAuthType.
authType uint32
}
// Size limits, in bytes, that ReceiveStartupMessage enforces on the body of a
// startup packet (the declared length minus the 4-byte length field itself).
const (
	// maxStartupPacketLen is the largest body accepted.
	// NOTE(review): presumably mirrors PostgreSQL's MAX_STARTUP_PACKET_LENGTH — confirm.
	maxStartupPacketLen = 10000

	// minStartupPacketLen is the smallest body accepted: enough to hold the
	// 4-byte code that ReceiveStartupMessage reads first.
	minStartupPacketLen = 4
)
// NewBackend returns a Backend that reads frontend messages from r and
// writes backend messages to w. Outgoing messages are buffered by Send and
// only written to w by Flush.
func NewBackend(r io.Reader, w io.Writer) *Backend {
	b := &Backend{w: w}
	b.cr = newChunkReader(r, 0)
	return b
}
// Send encodes msg and appends it to the Backend's write buffer. Nothing is
// written to the underlying writer until Flush is called. When tracing is
// enabled, the message is traced as backend ('B') traffic together with the
// number of bytes it encoded to.
func (b *Backend) Send(msg BackendMessage) {
	start := len(b.wbuf)
	b.wbuf = msg.Encode(b.wbuf)

	if t := b.tracer; t != nil {
		encodedLen := int32(len(b.wbuf) - start)
		t.traceMessage('B', encodedLen, msg)
	}
}
// Flush writes every message buffered by Send to the underlying writer in a
// single Write call and then empties the buffer, even if the write failed.
//
// On failure it returns a *writeError whose safeToRetry is true only when the
// writer reports that zero bytes were written.
func (b *Backend) Flush() error {
	written, writeErr := b.w.Write(b.wbuf)

	// Keep the existing backing array for small flushes; if this flush was
	// larger than maxRetainedLen, start over with a fresh, modestly sized
	// buffer so the large array can be released.
	const maxRetainedLen = 1024
	if len(b.wbuf) <= maxRetainedLen {
		b.wbuf = b.wbuf[:0]
	} else {
		b.wbuf = make([]byte, 0, maxRetainedLen)
	}

	if writeErr == nil {
		return nil
	}
	return &writeError{err: writeErr, safeToRetry: written == 0}
}
// Trace enables tracing: from now on every message passed to Send and every
// message returned by Receive is reported to a tracer that writes to w and
// is configured by options. Calling Trace again replaces the previous tracer.
func (b *Backend) Trace(w io.Writer, options TracerOptions) {
	t := &tracer{
		TracerOptions: options,
		buf:           new(bytes.Buffer),
		w:             w,
	}
	b.tracer = t
}
// Untrace turns off tracing previously enabled with Trace by discarding the
// tracer. Calling it when tracing is not enabled has no effect.
func (b *Backend) Untrace() {
	b.tracer = nil
}
// ReceiveStartupMessage reads the first message of a connection. Unlike the
// messages handled by Receive, it has no leading type byte. It consists of a
// 4-byte big-endian length that counts itself, followed by a body whose first
// 4 bytes are a code identifying the request.
//
// The result is one of *StartupMessage, *SSLRequest, *CancelRequest or
// *GSSEncRequest. It points at storage owned by b and is overwritten by a
// later call that decodes the same kind of message.
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
	lenBuf, err := b.cr.Next(4)
	if err != nil {
		return nil, err
	}

	// The subtraction happens in uint32, so a declared length below 4 wraps
	// around and is rejected by the range check that follows.
	bodyLen := int(binary.BigEndian.Uint32(lenBuf) - 4)
	if bodyLen < minStartupPacketLen || bodyLen > maxStartupPacketLen {
		return nil, fmt.Errorf("invalid length of startup packet: %d", bodyLen)
	}

	body, err := b.cr.Next(bodyLen)
	if err != nil {
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}

	// Pick the reusable destination for this request code, then decode once.
	var msg FrontendMessage
	switch code := binary.BigEndian.Uint32(body); code {
	case ProtocolVersionNumber:
		msg = &b.startupMessage
	case sslRequestNumber:
		msg = &b.sslRequest
	case cancelRequestCode:
		msg = &b.cancelRequest
	case gssEncReqNumber:
		msg = &b.gssEncRequest
	default:
		return nil, fmt.Errorf("unknown startup message code: %d", code)
	}

	if err := msg.Decode(body); err != nil {
		return nil, err
	}
	return msg, nil
}
// Receive reads the next regular frontend message. Each message is a 1-byte
// type, then a 4-byte big-endian length that counts itself but not the type
// byte, then the body.
//
// Most returned messages point at storage owned by b and are overwritten by
// the next Receive of the same type. 'p' messages are allocated fresh on each
// call, and their concrete type depends on the value last set with
// SetAuthType.
//
// A declared length that cannot describe a valid body (less than 4, or too
// large for int on this platform) is rejected with an error. If the header
// was read but reading the body fails, the header is remembered and the next
// call retries the body read without consuming a new header.
func (b *Backend) Receive() (FrontendMessage, error) {
	if !b.partialMsg {
		header, err := b.cr.Next(5)
		if err != nil {
			return nil, translateEOFtoErrUnexpectedEOF(err)
		}

		// The length comes straight from the (untrusted) client. Anything
		// below 4 would yield a negative body length for b.cr.Next. On
		// platforms with a 32-bit int, a huge length converts to a negative
		// int, and the bodyLen < 0 test catches that too. The uint32
		// subtraction may wrap when msgLen < 4, but that case is rejected by
		// the first condition.
		msgLen := binary.BigEndian.Uint32(header[1:])
		bodyLen := int(msgLen - 4)
		if msgLen < 4 || bodyLen < 0 {
			return nil, fmt.Errorf("invalid message length: %d", msgLen)
		}

		b.msgType = header[0]
		b.bodyLen = bodyLen
		b.partialMsg = true
	}

	var msg FrontendMessage
	switch b.msgType {
	case 'B':
		msg = &b.bind
	case 'C':
		msg = &b._close
	case 'D':
		msg = &b.describe
	case 'E':
		msg = &b.execute
	case 'F':
		msg = &b.functionCall
	case 'f':
		msg = &b.copyFail
	case 'd':
		msg = &b.copyData
	case 'c':
		msg = &b.copyDone
	case 'H':
		msg = &b.flush
	case 'P':
		msg = &b.parse
	case 'p':
		// The 'p' type byte is shared by several authentication responses;
		// which one it is depends on the auth type set via SetAuthType.
		switch b.authType {
		case AuthTypeSASL:
			msg = &SASLInitialResponse{}
		case AuthTypeSASLContinue, AuthTypeSASLFinal:
			msg = &SASLResponse{}
		case AuthTypeGSS, AuthTypeGSSCont:
			msg = &GSSResponse{}
		default:
			// AuthTypeCleartextPassword, AuthTypeMD5Password and every
			// other auth type decode as a plain password message.
			msg = &PasswordMessage{}
		}
	case 'Q':
		msg = &b.query
	case 'S':
		msg = &b.sync
	case 'X':
		msg = &b.terminate
	default:
		return nil, fmt.Errorf("unknown message type: %c", b.msgType)
	}

	msgBody, err := b.cr.Next(b.bodyLen)
	if err != nil {
		// partialMsg stays true so the next call retries this body.
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}

	b.partialMsg = false

	if err := msg.Decode(msgBody); err != nil {
		return nil, err
	}

	if b.tracer != nil {
		b.tracer.traceMessage('F', int32(5+len(msgBody)), msg)
	}

	return msg, nil
}
// SetAuthType records the authentication type in effect. Receive uses it to
// choose the concrete type a 'p' message decodes into. If authType is not one
// of the known AuthType* values, SetAuthType returns an error and leaves the
// current setting unchanged.
func (b *Backend) SetAuthType(authType uint32) error {
	switch authType {
	case AuthTypeOk, AuthTypeCleartextPassword, AuthTypeMD5Password,
		AuthTypeSCMCreds, AuthTypeGSS, AuthTypeGSSCont, AuthTypeSSPI,
		AuthTypeSASL, AuthTypeSASLContinue, AuthTypeSASLFinal:
		// Recognized; stored below.
	default:
		return fmt.Errorf("authType not recognized: %d", authType)
	}

	b.authType = authType
	return nil
}
The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.