package pgx
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
type (
	// Rows is an iterator over the rows produced by a query.
	Rows interface {
		// Close finishes the result set.
		Close()
		// Err returns the first error recorded while reading the rows, if any.
		Err() error
		// CommandTag returns the command tag reported for the query.
		CommandTag() pgconn.CommandTag
		// FieldDescriptions describes the columns of the result set.
		FieldDescriptions() []pgconn.FieldDescription
		// Next advances to the next row and reports whether one is available.
		Next() bool
		// Scan copies the current row's columns into dest.
		Scan(dest ...any) error
		// Values returns the current row's columns decoded into Go values.
		Values() ([]any, error)
		// RawValues returns the current row's undecoded column bytes.
		RawValues() [][]byte
		// Conn returns the connection the rows belong to.
		Conn() *Conn
	}

	// Row is a single query result row that can only be scanned.
	Row interface {
		Scan(dest ...any) error
	}

	// RowScanner is implemented by a destination that scans a whole row
	// itself. When it is the only destination passed to the Scan
	// implementation in this file, the Rows are handed to ScanRow.
	RowScanner interface {
		ScanRow(rows Rows) error
	}
)
// connRow is a Row backed by baseRows: it reads at most one row and then
// closes the underlying result set.
type connRow baseRows

// Scan reads the first row into dest and closes the result set. It returns a
// previously recorded error unchanged, ErrNoRows when the query produced no
// rows, and rejects *pgtype.DriverBytes destinations (the result set is closed
// before Scan returns).
func (r *connRow) Scan(dest ...any) (err error) {
	br := (*baseRows)(r)

	if prior := br.Err(); prior != nil {
		return prior
	}

	for _, target := range dest {
		_, isDriverBytes := target.(*pgtype.DriverBytes)
		if !isDriverBytes {
			continue
		}
		br.Close()
		return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
	}

	if !br.Next() {
		if iterErr := br.Err(); iterErr != nil {
			return iterErr
		}
		return ErrNoRows
	}

	// A Scan failure is recorded in br.err by baseRows.fatal, so the return
	// value is read back through Err after closing instead of here.
	_ = br.Scan(dest...)
	br.Close()
	return br.Err()
}
// baseRows is the Rows implementation in this file. It reads rows from a
// pgconn.ResultReader, decodes them with typeMap, and when closed reports the
// outcome to whichever tracer is configured.
type baseRows struct {
	// typeMap is used to plan scans (Scan) and decode values (Values).
	typeMap *pgtype.Map
	// resultReader supplies the rows; Close closes it.
	resultReader *pgconn.ResultReader
	// values holds the raw column bytes of the current row; assigned by Next.
	values [][]byte
	// commandTag is assigned from resultReader.Close in Close.
	commandTag pgconn.CommandTag
	// err is the first error recorded, either via fatal or while closing.
	err error
	// closed is set by the first Close call; subsequent calls return early.
	closed bool
	// scanPlans caches one scan plan per column. scanTypes records the
	// destination type each plan was built for so Scan can re-plan when a
	// caller switches destination types.
	scanPlans []pgtype.ScanPlan
	scanTypes []reflect.Type
	// conn is the owning connection. Close uses it for statement/description
	// cache invalidation and passes it to tracers. It may be nil (for example
	// RowsFromResultReader does not set it).
	conn *Conn
	// multiResultReader, when non-nil, is also closed by Close.
	// NOTE(review): presumably set when these rows come from a multi-result
	// query; confirm at the construction site.
	multiResultReader *pgconn.MultiResultReader
	// queryTracer and batchTracer are notified in Close; batchTracer takes
	// precedence when both are set.
	queryTracer QueryTracer
	batchTracer BatchTracer
	// ctx is the context handed to the tracer callbacks in Close.
	ctx context.Context
	// startTime is not read anywhere in this part of the file.
	// NOTE(review): presumably the query start time; confirm with its users.
	startTime time.Time
	// sql and args describe the query. sql is the key invalidated in the
	// connection's caches on an invalid-statement error; both are reported to
	// the batch tracer.
	sql  string
	args []any
	// rowCount counts the rows returned by Next.
	rowCount int
}
// FieldDescriptions returns the column metadata exposed by the result reader.
func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription {
	reader := rows.resultReader
	return reader.FieldDescriptions()
}
// Close finishes the result set. Only the first call has any effect. It closes
// the result reader (storing its command tag) and the multi-result reader, and
// keeps the earliest error seen. When that error says the prepared statement
// is no longer valid, it evicts rows.sql from the connection's statement and
// description caches. Finally it reports completion to the batch tracer or,
// failing that, the query tracer.
func (rows *baseRows) Close() {
	if rows.closed {
		return
	}
	rows.closed = true

	// recordErr keeps the first error; later close errors never replace it.
	recordErr := func(e error) {
		if rows.err == nil {
			rows.err = e
		}
	}

	if reader := rows.resultReader; reader != nil {
		tag, e := reader.Close()
		rows.commandTag = tag
		recordErr(e)
	}

	if multi := rows.multiResultReader; multi != nil {
		recordErr(multi.Close())
	}

	invalidated := rows.err != nil && rows.conn != nil && rows.sql != "" && stmtcache.IsStatementInvalid(rows.err)
	if invalidated {
		if cache := rows.conn.statementCache; cache != nil {
			cache.Invalidate(rows.sql)
		}
		if cache := rows.conn.descriptionCache; cache != nil {
			cache.Invalidate(rows.sql)
		}
	}

	switch {
	case rows.batchTracer != nil:
		rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
	case rows.queryTracer != nil:
		rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
	}
}
// CommandTag returns the command tag stored for these rows (Close stores the
// result reader's tag).
func (rows *baseRows) CommandTag() pgconn.CommandTag { return rows.commandTag }
// Err returns the first error recorded for these rows, or nil.
func (rows *baseRows) Err() error { return rows.err }
// fatal records err as the rows' error and closes the rows. When an error is
// already recorded it does nothing, so the original cause is preserved.
func (rows *baseRows) fatal(err error) {
	if rows.err == nil {
		rows.err = err
		rows.Close()
	}
}
// Next advances to the next row and reports whether one is available. When
// the result reader has no more rows, the rows are closed and false is
// returned; a closed Rows always returns false.
func (rows *baseRows) Next() bool {
	if rows.closed {
		return false
	}
	if !rows.resultReader.NextRow() {
		rows.Close()
		return false
	}
	rows.rowCount++
	rows.values = rows.resultReader.Values()
	return true
}
// Scan decodes the current row into dest, one destination per column; a nil
// destination skips its column. When dest is a single RowScanner, the whole
// row is delegated to its ScanRow method. Scan plans are cached per column and
// rebuilt whenever a destination's Go type differs from the one the cached
// plan was built for. Every error is recorded through fatal, which also closes
// the rows, and then returned.
func (rows *baseRows) Scan(dest ...any) error {
	typeMap := rows.typeMap
	fds := rows.FieldDescriptions()
	raw := rows.values

	if len(fds) != len(raw) {
		err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fds), len(raw))
		rows.fatal(err)
		return err
	}

	if len(dest) == 1 {
		if scanner, ok := dest[0].(RowScanner); ok {
			err := scanner.ScanRow(rows)
			if err != nil {
				rows.fatal(err)
			}
			return err
		}
	}

	if len(fds) != len(dest) {
		err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fds), len(dest))
		rows.fatal(err)
		return err
	}

	// Build the per-column plan cache on first use.
	if rows.scanPlans == nil {
		rows.scanPlans = make([]pgtype.ScanPlan, len(raw))
		rows.scanTypes = make([]reflect.Type, len(raw))
		for col, target := range dest {
			rows.scanPlans[col] = typeMap.PlanScan(fds[col].DataTypeOID, fds[col].Format, target)
			rows.scanTypes[col] = reflect.TypeOf(target)
		}
	}

	for col, target := range dest {
		if target == nil {
			continue
		}

		targetType := reflect.TypeOf(target)
		if rows.scanTypes[col] != targetType {
			rows.scanPlans[col] = typeMap.PlanScan(fds[col].DataTypeOID, fds[col].Format, target)
			rows.scanTypes[col] = targetType
		}

		if err := rows.scanPlans[col].Scan(raw[col], target); err != nil {
			wrapped := ScanArgError{ColumnIndex: col, Err: err}
			rows.fatal(wrapped)
			return wrapped
		}
	}
	return nil
}
// Values decodes the current row into Go values, one per column:
//   - a nil raw value (SQL NULL) becomes nil;
//   - a column whose OID is registered in typeMap is decoded by that type's
//     codec;
//   - otherwise text-format columns become a string and binary-format columns
//     a freshly copied []byte.
//
// It returns an error if the rows are closed, or if the current row's value
// count does not match the field descriptions (for example when Values is
// called before Next). Decoding failures and unknown format codes are recorded
// through fatal, which also closes the rows, and then returned.
func (rows *baseRows) Values() ([]any, error) {
	if rows.closed {
		return nil, errors.New("rows is closed")
	}

	// Fetch the descriptions once rather than on every column access.
	fieldDescriptions := rows.FieldDescriptions()

	// Without this guard, rows.values[i] below panics with an index out of
	// range when there is no current row. Mirrors the check in Scan.
	if len(fieldDescriptions) != len(rows.values) {
		err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(rows.values))
		rows.fatal(err)
		return nil, err
	}

	values := make([]any, 0, len(fieldDescriptions))

	for i := range fieldDescriptions {
		buf := rows.values[i]
		fd := &fieldDescriptions[i]

		if buf == nil {
			values = append(values, nil)
			continue
		}

		if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok {
			value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf)
			if err != nil {
				rows.fatal(err)
			}
			values = append(values, value)
		} else {
			switch fd.Format {
			case TextFormatCode:
				values = append(values, string(buf))
			case BinaryFormatCode:
				// Copy so the returned slice does not alias the row buffer.
				newBuf := make([]byte, len(buf))
				copy(newBuf, buf)
				values = append(values, newBuf)
			default:
				rows.fatal(errors.New("unknown format code"))
			}
		}

		if err := rows.Err(); err != nil {
			return nil, err
		}
	}

	return values, rows.Err()
}
// RawValues returns the undecoded column bytes of the current row, as stored
// by the most recent successful Next call. The slice is returned without
// copying.
func (rows *baseRows) RawValues() [][]byte {
	raw := rows.values
	return raw
}
// Conn returns the connection associated with these rows; it is nil for rows
// built without one (e.g. by RowsFromResultReader).
func (rows *baseRows) Conn() *Conn { return rows.conn }
// ScanArgError reports that scanning column ColumnIndex into its destination
// failed with Err.
type ScanArgError struct {
	ColumnIndex int
	Err         error
}

// Error formats the failing destination index together with the cause.
func (e ScanArgError) Error() string {
	const layout = "can't scan into dest[%d]: %v"
	return fmt.Sprintf(layout, e.ColumnIndex, e.Err)
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e ScanArgError) Unwrap() error { return e.Err }
// ScanRow decodes one row of raw values into dest using typeMap, one
// destination per column; a nil destination skips its column. The number of
// field descriptions must equal both the number of values and the number of
// destinations. A per-column failure is returned as a ScanArgError.
func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error {
	switch {
	case len(fieldDescriptions) != len(values):
		return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
	case len(fieldDescriptions) != len(dest):
		return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
	}

	for col := range dest {
		target := dest[col]
		if target == nil {
			continue
		}
		fd := &fieldDescriptions[col]
		if err := typeMap.Scan(fd.DataTypeOID, fd.Format, values[col], target); err != nil {
			return ScanArgError{ColumnIndex: col, Err: err}
		}
	}
	return nil
}
// RowsFromResultReader wraps resultReader in a Rows that decodes values with
// typeMap. The returned Rows has no connection, tracer, or SQL attached.
func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows {
	br := new(baseRows)
	br.typeMap = typeMap
	br.resultReader = resultReader
	return br
}
// ForEachRow scans every row of rows into scans and calls fn after each scan.
// rows is always closed before ForEachRow returns. If Scan, fn, or rows.Err
// reports an error, that error is returned with an empty command tag;
// otherwise the result of rows.CommandTag() (read before the deferred Close
// runs) is returned.
func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) {
	defer rows.Close()

	var none pgconn.CommandTag
	for rows.Next() {
		if err := rows.Scan(scans...); err != nil {
			return none, err
		}
		if err := fn(); err != nil {
			return none, err
		}
	}

	if err := rows.Err(); err != nil {
		return none, err
	}
	return rows.CommandTag(), nil
}
type (
	// CollectableRow is the row-level view of Rows handed to a RowToFunc:
	// column metadata and access to the current row, but no iteration or
	// closing.
	CollectableRow interface {
		FieldDescriptions() []pgconn.FieldDescription
		Scan(dest ...any) error
		Values() ([]any, error)
		RawValues() [][]byte
	}

	// RowToFunc converts the current row into a T; CollectRows and
	// CollectOneRow call it once per row.
	RowToFunc[T any] func(row CollectableRow) (T, error)
)
// CollectRows converts every row of rows with fn and returns the results in
// order. rows is always closed. The returned slice is non-nil even when there
// are no rows. On any error, from fn or from rows.Err, it returns nil and that
// error.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
	defer rows.Close()

	// Non-nil on purpose: an empty result is an empty slice, not nil.
	collected := make([]T, 0)
	for rows.Next() {
		item, err := fn(rows)
		if err != nil {
			return nil, err
		}
		collected = append(collected, item)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return collected, nil
}
// CollectOneRow converts the first row of rows with fn and closes rows. It
// returns ErrNoRows when there is no row (unless rows reports an error, which
// takes precedence). Any further rows are discarded; an error surfaced while
// closing is returned alongside the value.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
	defer rows.Close()

	if !rows.Next() {
		var zero T
		if err := rows.Err(); err != nil {
			return zero, err
		}
		return zero, ErrNoRows
	}

	value, err := fn(rows)
	if err != nil {
		return value, err
	}

	// Close explicitly (the deferred call then does nothing further for the
	// Rows in this file) so a close-time error can be reported.
	rows.Close()
	return value, rows.Err()
}
// RowTo scans the row into a new T by passing a *T as the sole destination to
// row.Scan. The value is returned even when Scan reports an error.
func RowTo[T any](row CollectableRow) (T, error) {
	var dst T
	if err := row.Scan(&dst); err != nil {
		return dst, err
	}
	return dst, nil
}
// RowToAddrOf is like RowTo but returns a pointer to a newly allocated T. The
// pointer is non-nil even when Scan reports an error.
func RowToAddrOf[T any](row CollectableRow) (*T, error) {
	dst := new(T)
	err := row.Scan(dst)
	return dst, err
}
// RowToMap returns the row as a map from column name to decoded value.
func RowToMap(row CollectableRow) (map[string]any, error) {
	var m map[string]any
	scanner := (*mapRowScanner)(&m)
	err := row.Scan(scanner)
	return m, err
}
// mapRowScanner is a RowScanner that stores a row as column name -> value.
type mapRowScanner map[string]any

// ScanRow replaces *rs with a new map holding the current row's decoded
// values keyed by column name. Columns sharing a name overwrite each other,
// with the last one winning.
func (rs *mapRowScanner) ScanRow(rows Rows) error {
	values, err := rows.Values()
	if err != nil {
		return err
	}

	*rs = make(mapRowScanner, len(values))
	fields := rows.FieldDescriptions()
	for col, v := range values {
		(*rs)[string(fields[col].Name)] = v
	}
	return nil
}
// RowToStructByPos scans the row into a new T, assigning columns to T's
// exported fields in declaration order.
func RowToStructByPos[T any](row CollectableRow) (T, error) {
	var result T
	scanner := &positionalStructRowScanner{ptrToStruct: &result}
	err := row.Scan(scanner)
	return result, err
}
// RowToAddrOfStructByPos is like RowToStructByPos but returns a pointer to a
// newly allocated T.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
	result := new(T)
	err := row.Scan(&positionalStructRowScanner{ptrToStruct: result})
	return result, err
}
// positionalStructRowScanner is a RowScanner that assigns the row's columns,
// in order, to the exported fields of the struct ptrToStruct points to.
type positionalStructRowScanner struct {
	ptrToStruct any
}

// ScanRow scans the current row into the struct. Fields of embedded structs
// are flattened in declaration order. Unexported fields, and non-embedded
// fields tagged `db:"-"`, are skipped. It returns an error if ptrToStruct is
// not a non-nil pointer to a struct, or if the row has more columns than the
// struct has target fields. Other count mismatches are reported by rows.Scan.
func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
	dst := rs.ptrToStruct
	dstValue := reflect.ValueOf(dst)
	if dstValue.Kind() != reflect.Ptr {
		return fmt.Errorf("dst not a pointer")
	}

	dstElemValue := dstValue.Elem()
	// Without this guard reflect panics: NumField on a non-struct (e.g.
	// RowToStructByPos[int]) or Type on the zero Value of a nil pointer.
	if dstElemValue.Kind() != reflect.Struct {
		return fmt.Errorf("dst not a pointer to a struct, got %T", dst)
	}

	scanTargets := rs.appendScanTargets(dstElemValue, nil)

	if len(rows.RawValues()) > len(scanTargets) {
		return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
	}

	return rows.Scan(scanTargets...)
}

// appendScanTargets appends, in declaration order, a pointer to each scannable
// field of the struct dstElemValue, recursing into embedded structs.
func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
	dstElemType := dstElemValue.Type()

	if scanTargets == nil {
		scanTargets = make([]any, 0, dstElemType.NumField())
	}

	for i := 0; i < dstElemType.NumField(); i++ {
		sf := dstElemType.Field(i)
		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
			scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
		} else if sf.PkgPath == "" {
			dbTag, _ := sf.Tag.Lookup(structTagKey)
			if dbTag == "-" {
				continue
			}
			scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
		}
	}

	return scanTargets
}
// RowToStructByName scans the row into a new T, matching columns to T's
// exported fields by `db` tag or field name. Every struct field and every
// column must be matched.
func RowToStructByName[T any](row CollectableRow) (T, error) {
	var result T
	scanner := &namedStructRowScanner{ptrToStruct: &result}
	err := row.Scan(scanner)
	return result, err
}
// RowToAddrOfStructByName is like RowToStructByName but returns a pointer to a
// newly allocated T.
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
	result := new(T)
	err := row.Scan(&namedStructRowScanner{ptrToStruct: result})
	return result, err
}
// RowToStructByNameLax is like RowToStructByName, but struct fields with no
// matching column are left untouched instead of causing an error. Every column
// must still map to a field.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
	var result T
	scanner := &namedStructRowScanner{ptrToStruct: &result, lax: true}
	err := row.Scan(scanner)
	return result, err
}
// RowToAddrOfStructByNameLax is like RowToStructByNameLax but returns a
// pointer to a newly allocated T.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
	result := new(T)
	err := row.Scan(&namedStructRowScanner{ptrToStruct: result, lax: true})
	return result, err
}
// namedStructRowScanner is a RowScanner that assigns columns to the fields of
// the struct ptrToStruct points to by matching names. With lax set, struct
// fields that have no matching column are tolerated.
type namedStructRowScanner struct {
	ptrToStruct any
	lax         bool
}

// ScanRow scans the current row into the struct by column name. It returns an
// error if ptrToStruct is not a non-nil pointer to a struct, if a field cannot
// be matched (non-lax only), or if some column has no corresponding field.
func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
	dst := rs.ptrToStruct
	dstValue := reflect.ValueOf(dst)
	if dstValue.Kind() != reflect.Ptr {
		return fmt.Errorf("dst not a pointer")
	}

	dstElemValue := dstValue.Elem()
	// Without this guard reflect panics: NumField on a non-struct (e.g.
	// RowToStructByName[int]) or Type on the zero Value of a nil pointer.
	if dstElemValue.Kind() != reflect.Struct {
		return fmt.Errorf("dst not a pointer to a struct, got %T", dst)
	}

	fldDescs := rows.FieldDescriptions()
	scanTargets, err := rs.appendScanTargets(dstElemValue, nil, fldDescs)
	if err != nil {
		return err
	}

	// Each column must be claimed by some struct field, even in lax mode.
	for i, t := range scanTargets {
		if t == nil {
			return fmt.Errorf("struct doesn't have corresponding row field %s", fldDescs[i].Name)
		}
	}

	return rows.Scan(scanTargets...)
}
// structTagKey is the struct tag key that maps struct fields to column names.
const structTagKey = "db"

// fieldPosByName returns the index of the first field description whose name
// matches field case-insensitively (strings.EqualFold), or -1 if none does.
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) int {
	// Index instead of ranging by value to avoid copying each description.
	for i := range fldDescs {
		if strings.EqualFold(fldDescs[i].Name, field) {
			return i
		}
	}
	return -1
}
// appendScanTargets places a pointer to each scannable field of the struct
// dstElemValue into scanTargets, at the index of the column it maps to.
// scanTargets is allocated with one nil slot per column when nil.
//
// A field maps to the column named by its `db` tag (text before any comma) or,
// when untagged, its Go field name, compared case-insensitively. Embedded
// structs are flattened recursively. Unexported fields and fields tagged
// `db:"-"` are skipped. A field with no matching column is an error unless
// rs.lax is set. Columns that no field maps to stay nil for the caller to
// detect.
func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
	dstElemType := dstElemValue.Type()

	if scanTargets == nil {
		scanTargets = make([]any, len(fldDescs))
	}

	for i := 0; i < dstElemType.NumField(); i++ {
		sf := dstElemType.Field(i)

		// Embedded structs, exported or not, contribute their own fields.
		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
			var err error
			scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
			if err != nil {
				return nil, err
			}
			continue
		}

		// Skip every other unexported field, including embedded non-struct
		// ones: Addr().Interface() on those panics in reflect.
		if sf.PkgPath != "" {
			continue
		}

		dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
		if dbTagPresent {
			dbTag, _, _ = strings.Cut(dbTag, ",")
		}
		if dbTag == "-" {
			continue
		}

		colName := sf.Name
		if dbTagPresent {
			colName = dbTag
		}

		fpos := fieldPosByName(fldDescs, colName)
		// fpos is bounded by len(fldDescs); the second check guards callers
		// that pass a shorter scanTargets and would otherwise index out of range.
		if fpos == -1 || fpos >= len(scanTargets) {
			if rs.lax {
				continue
			}
			return nil, fmt.Errorf("cannot find field %s in returned row", colName)
		}

		scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
	}

	return scanTargets, nil
}
// These pages were generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// Pull requests and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news about Golds.