package pgx
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgconn"
)
// TxIsoLevel names a transaction isolation level. Its string value is written
// verbatim after "isolation level" in the begin statement built by
// TxOptions.beginSQL.
type TxIsoLevel string

// Transaction isolation levels.
const (
	Serializable    = TxIsoLevel("serializable")
	RepeatableRead  = TxIsoLevel("repeatable read")
	ReadCommitted   = TxIsoLevel("read committed")
	ReadUncommitted = TxIsoLevel("read uncommitted")
)
// TxAccessMode names a transaction access mode. Its string value is appended
// verbatim to the begin statement built by TxOptions.beginSQL.
type TxAccessMode string

// Transaction access modes.
const (
	ReadWrite = TxAccessMode("read write")
	ReadOnly  = TxAccessMode("read only")
)
// TxDeferrableMode names a transaction deferrable mode. Its string value is
// appended verbatim to the begin statement built by TxOptions.beginSQL.
type TxDeferrableMode string

// Transaction deferrable modes.
const (
	Deferrable    = TxDeferrableMode("deferrable")
	NotDeferrable = TxDeferrableMode("not deferrable")
)
// TxOptions selects the modes a transaction is started with. The zero value
// produces a plain "begin" statement.
type TxOptions struct {
	// IsoLevel, when non-empty, is emitted as "isolation level <IsoLevel>".
	IsoLevel TxIsoLevel
	// AccessMode, when non-empty, is appended to the begin statement.
	AccessMode TxAccessMode
	// DeferrableMode, when non-empty, is appended to the begin statement.
	DeferrableMode TxDeferrableMode

	// BeginQuery, when non-empty, is used verbatim as the begin statement and
	// the other fields are ignored.
	BeginQuery string
}
// emptyTxOptions is the zero TxOptions. beginSQL compares against it to
// short-circuit to a plain "begin".
var emptyTxOptions TxOptions

// beginSQL returns the SQL statement that opens a transaction with these
// options:
//   - the zero value yields "begin";
//   - a non-empty BeginQuery is returned verbatim;
//   - otherwise "begin" is followed by the isolation level, access mode and
//     deferrable mode, in that order, skipping any that are empty.
func (txOptions TxOptions) beginSQL() string {
	switch {
	case txOptions == emptyTxOptions:
		return "begin"
	case txOptions.BeginQuery != "":
		return txOptions.BeginQuery
	}

	var sb strings.Builder
	sb.Grow(64) // reserve room up front for the statement being assembled
	sb.WriteString("begin")

	if level := string(txOptions.IsoLevel); level != "" {
		sb.WriteString(" isolation level ")
		sb.WriteString(level)
	}

	// The remaining modes are both emitted as " <mode>" when set.
	trailing := [...]string{
		string(txOptions.AccessMode),
		string(txOptions.DeferrableMode),
	}
	for _, mode := range trailing {
		if mode == "" {
			continue
		}
		sb.WriteByte(' ')
		sb.WriteString(mode)
	}

	return sb.String()
}
var (
	// ErrTxClosed is returned when an operation is attempted on a transaction
	// that has already been committed or rolled back.
	ErrTxClosed = errors.New("tx is closed")

	// ErrTxCommitRollback is returned by Commit when the commit statement
	// succeeded but its command tag was "ROLLBACK".
	ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
)
// Begin starts a transaction on c with default (zero-value) TxOptions.
// It is shorthand for BeginTx with an empty options struct.
func (c *Conn) Begin(ctx context.Context) (Tx, error) {
	var defaults TxOptions
	return c.BeginTx(ctx, defaults)
}
// BeginTx starts a transaction on c using the begin statement derived from
// txOptions and returns a Tx bound to c.
//
// If executing the begin statement fails, c.die is called and the error from
// Exec is returned unchanged; no Tx is returned in that case.
func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
	if _, err := c.Exec(ctx, txOptions.beginSQL()); err != nil {
		// Wrap the cause (as dbTx.Rollback does) instead of discarding it, so
		// the error handed to die still carries the underlying failure.
		c.die(fmt.Errorf("failed to begin transaction: %w", err))
		return nil, err
	}

	return &dbTx{conn: c}, nil
}
// Tx is the set of operations available on a database transaction. In this
// file it is implemented by dbTx (a transaction on a Conn) and by
// dbSimulatedNestedTx (a savepoint inside another Tx).
//
// Both implementations return ErrTxClosed from Commit and Rollback once the
// transaction has been closed, and report ErrTxClosed from the statement
// methods as well.
type Tx interface {
	// Begin starts a nested transaction. The implementations in this file do
	// so by creating a savepoint.
	Begin(ctx context.Context) (Tx, error)

	// Commit commits the transaction and closes it.
	Commit(ctx context.Context) error

	// Rollback rolls the transaction back and closes it.
	Rollback(ctx context.Context) error

	// CopyFrom forwards to the CopyFrom of the underlying connection.
	CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
	// SendBatch forwards to the SendBatch of the underlying connection.
	SendBatch(ctx context.Context, b *Batch) BatchResults
	// LargeObjects returns a LargeObjects value bound to this transaction.
	LargeObjects() LargeObjects

	// Prepare forwards to the Prepare of the underlying connection.
	Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error)

	// Exec forwards to the Exec of the underlying connection.
	Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error)
	// Query forwards to the Query of the underlying connection.
	Query(ctx context.Context, sql string, args ...any) (Rows, error)
	// QueryRow runs Query and returns the result as a single Row.
	QueryRow(ctx context.Context, sql string, args ...any) Row

	// Conn returns the connection the transaction runs on.
	Conn() *Conn
}
// dbTx is the Tx implementation returned by Conn.BeginTx.
type dbTx struct {
	// conn is the connection every statement of this transaction is sent on.
	conn *Conn
	// savepointNum is bumped by each Begin call and used to build unique
	// savepoint names ("sp_<n>").
	savepointNum int64
	// closed is set once Commit or Rollback has been attempted.
	closed bool
}
// Begin starts a pseudo nested transaction by creating a savepoint named
// "sp_<n>", where n is this transaction's incremented savepoint counter.
// It returns ErrTxClosed if the transaction is already closed.
func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
	if tx.closed {
		return nil, ErrTxClosed
	}

	// The counter advances even if the savepoint statement fails below.
	tx.savepointNum++
	savepointSQL := "savepoint sp_" + strconv.FormatInt(tx.savepointNum, 10)
	if _, err := tx.conn.Exec(ctx, savepointSQL); err != nil {
		return nil, err
	}

	nested := &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}
	return nested, nil
}
// Commit sends "commit" on the transaction's connection and marks the
// transaction closed regardless of the outcome.
//
// It returns ErrTxClosed if the transaction was already closed, the Exec error
// if the statement failed, and ErrTxCommitRollback if the statement succeeded
// but the server answered with a "ROLLBACK" command tag.
func (tx *dbTx) Commit(ctx context.Context) error {
	if tx.closed {
		return ErrTxClosed
	}

	tag, execErr := tx.conn.Exec(ctx, "commit")
	tx.closed = true

	switch {
	case execErr != nil:
		// NOTE(review): 'I' is presumably the "idle / not in a transaction"
		// status; any other status after a failed commit closes the connection.
		if tx.conn.PgConn().TxStatus() != 'I' {
			_ = tx.conn.Close(ctx) // execErr is what gets reported
		}
		return execErr
	case tag.String() == "ROLLBACK":
		return ErrTxCommitRollback
	default:
		return nil
	}
}
// Rollback sends "rollback" on the transaction's connection and marks the
// transaction closed regardless of the outcome.
//
// It returns ErrTxClosed if the transaction was already closed. If the
// statement fails, the connection's die method is called with the wrapped
// error and the original error is returned.
func (tx *dbTx) Rollback(ctx context.Context) error {
	if tx.closed {
		return ErrTxClosed
	}

	_, err := tx.conn.Exec(ctx, "rollback")
	tx.closed = true
	if err == nil {
		return nil
	}

	tx.conn.die(fmt.Errorf("rollback failed: %w", err))
	return err
}
// Exec runs sql with arguments on the transaction's connection.
// It returns a zero CommandTag and ErrTxClosed if the transaction is closed.
func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
	if tx.closed {
		var none pgconn.CommandTag
		return none, ErrTxClosed
	}

	commandTag, err = tx.conn.Exec(ctx, sql, arguments...)
	return commandTag, err
}
// Prepare forwards to the connection's Prepare with the given statement name
// and sql. It returns ErrTxClosed if the transaction is closed.
func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
	if !tx.closed {
		return tx.conn.Prepare(ctx, name, sql)
	}
	return nil, ErrTxClosed
}
// Query forwards to the connection's Query. If the transaction is closed it
// returns ErrTxClosed together with an already-closed Rows carrying that same
// error, so the returned Rows is never nil.
func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
	if !tx.closed {
		return tx.conn.Query(ctx, sql, args...)
	}

	cause := ErrTxClosed
	closedRows := &baseRows{closed: true, err: cause}
	return closedRows, cause
}
// QueryRow runs Query and returns the result as a Row. The error from Query is
// dropped here; in the closed-transaction case it is also stored inside the
// returned rows value.
func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row {
	result, _ := tx.Query(ctx, sql, args...)
	// Query is expected to always hand back a *baseRows; the assertion panics otherwise.
	base := result.(*baseRows)
	return (*connRow)(base)
}
// CopyFrom forwards to the connection's CopyFrom.
// It returns 0 and ErrTxClosed if the transaction is closed.
func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
	if tx.closed {
		return 0, ErrTxClosed
	}

	copied, err := tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
	return copied, err
}
// SendBatch forwards b to the connection's SendBatch. If the transaction is
// closed it returns a BatchResults whose error is ErrTxClosed instead.
func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
	if !tx.closed {
		return tx.conn.SendBatch(ctx, b)
	}
	return &batchResults{err: ErrTxClosed}
}
// LargeObjects returns a LargeObjects value bound to this transaction.
// No closed check is performed here.
func (tx *dbTx) LargeObjects() LargeObjects {
	var lo LargeObjects
	lo.tx = tx
	return lo
}
// Conn returns the connection this transaction runs on.
func (tx *dbTx) Conn() *Conn {
	underlying := tx.conn
	return underlying
}
// dbSimulatedNestedTx is the Tx returned by dbTx.Begin. It represents a
// savepoint inside another transaction rather than a real transaction.
type dbSimulatedNestedTx struct {
	// tx is the transaction the savepoint was created in; every statement is
	// forwarded to it.
	tx Tx
	// savepointNum identifies the savepoint as "sp_<savepointNum>".
	savepointNum int64
	// closed is set once Commit or Rollback has been attempted.
	closed bool
}
// Begin starts a further nested transaction by delegating to the parent
// transaction's Begin. It returns ErrTxClosed if this savepoint is closed.
func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) {
	if !sp.closed {
		return sp.tx.Begin(ctx)
	}
	return nil, ErrTxClosed
}
// Commit releases the savepoint ("release savepoint sp_<n>") and marks this
// nested transaction closed whether or not the statement succeeded.
// It returns ErrTxClosed if it was already closed.
func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error {
	if sp.closed {
		return ErrTxClosed
	}

	releaseSQL := "release savepoint sp_" + strconv.FormatInt(sp.savepointNum, 10)
	// The statement must run before closed is set: sp.Exec refuses once closed.
	_, err := sp.Exec(ctx, releaseSQL)
	sp.closed = true
	return err
}
// Rollback rolls back to the savepoint ("rollback to savepoint sp_<n>") and
// marks this nested transaction closed whether or not the statement succeeded.
// It returns ErrTxClosed if it was already closed.
func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error {
	if sp.closed {
		return ErrTxClosed
	}

	rollbackSQL := "rollback to savepoint sp_" + strconv.FormatInt(sp.savepointNum, 10)
	// The statement must run before closed is set: sp.Exec refuses once closed.
	_, err := sp.Exec(ctx, rollbackSQL)
	sp.closed = true
	return err
}
// Exec forwards to the parent transaction's Exec.
// It returns a zero CommandTag and ErrTxClosed if this savepoint is closed.
func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
	if sp.closed {
		var none pgconn.CommandTag
		return none, ErrTxClosed
	}

	commandTag, err = sp.tx.Exec(ctx, sql, arguments...)
	return commandTag, err
}
// Prepare forwards to the parent transaction's Prepare.
// It returns ErrTxClosed if this savepoint is closed.
func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
	if !sp.closed {
		return sp.tx.Prepare(ctx, name, sql)
	}
	return nil, ErrTxClosed
}
// Query forwards to the parent transaction's Query. If this savepoint is
// closed it returns ErrTxClosed together with an already-closed Rows carrying
// that same error, so the returned Rows is never nil.
func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
	if !sp.closed {
		return sp.tx.Query(ctx, sql, args...)
	}

	cause := ErrTxClosed
	closedRows := &baseRows{closed: true, err: cause}
	return closedRows, cause
}
// QueryRow runs Query and returns the result as a Row. The error from Query is
// dropped here; in the closed case it is also stored inside the returned rows
// value.
func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...any) Row {
	result, _ := sp.Query(ctx, sql, args...)
	// Query is expected to always hand back a *baseRows; the assertion panics otherwise.
	base := result.(*baseRows)
	return (*connRow)(base)
}
// CopyFrom forwards to the parent transaction's CopyFrom.
// It returns 0 and ErrTxClosed if this savepoint is closed.
func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
	if sp.closed {
		return 0, ErrTxClosed
	}

	copied, err := sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)
	return copied, err
}
// SendBatch forwards b to the parent transaction's SendBatch. If this
// savepoint is closed it returns a BatchResults whose error is ErrTxClosed.
func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
	if !sp.closed {
		return sp.tx.SendBatch(ctx, b)
	}
	return &batchResults{err: ErrTxClosed}
}
// LargeObjects returns a LargeObjects value bound to this nested transaction.
// No closed check is performed here.
func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects {
	var lo LargeObjects
	lo.tx = sp
	return lo
}
// Conn returns the connection of the parent transaction.
func (sp *dbSimulatedNestedTx) Conn() *Conn {
	parent := sp.tx
	return parent.Conn()
}
// BeginFunc calls db.Begin and then runs fn with the resulting Tx.
//
// If Begin fails its error is returned and fn is not called. Otherwise the
// transaction is committed when fn returns nil and rolled back when fn returns
// an error; the returned error is whatever that commit/rollback handling
// (beginFuncExec) reports.
func BeginFunc(
	ctx context.Context,
	db interface {
		Begin(ctx context.Context) (Tx, error)
	},
	fn func(Tx) error,
) (err error) {
	tx, beginErr := db.Begin(ctx)
	if beginErr != nil {
		return beginErr
	}

	return beginFuncExec(ctx, tx, fn)
}
// BeginTxFunc calls db.BeginTx with txOptions and then runs fn with the
// resulting Tx.
//
// If BeginTx fails its error is returned and fn is not called. Otherwise the
// transaction is committed when fn returns nil and rolled back when fn returns
// an error; the returned error is whatever that commit/rollback handling
// (beginFuncExec) reports.
func BeginTxFunc(
	ctx context.Context,
	db interface {
		BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error)
	},
	txOptions TxOptions,
	fn func(Tx) error,
) (err error) {
	tx, beginErr := db.BeginTx(ctx, txOptions)
	if beginErr != nil {
		return beginErr
	}

	return beginFuncExec(ctx, tx, fn)
}
// beginFuncExec runs fn inside tx and finishes the transaction.
//
// When fn returns an error, tx is rolled back and that error is returned.
// When fn returns nil, tx is committed and the commit result is returned.
// In every case (including a panic in fn) a deferred Rollback runs last; if
// it fails with anything other than ErrTxClosed, that failure replaces the
// returned error.
func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) {
	defer func() {
		rbErr := tx.Rollback(ctx)
		if rbErr == nil || errors.Is(rbErr, ErrTxClosed) {
			// Nothing to report: the rollback worked, or tx was already finished.
			return
		}
		err = rbErr
	}()

	if err = fn(tx); err != nil {
		_ = tx.Rollback(ctx) // fn's error takes precedence over a rollback failure here
		return err
	}

	return tx.Commit(ctx)
}
The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.