package gorm
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// preparedStmtDBKey is the cacheStore key under which Open and Session store
// and look up the shared *PreparedStmtDB.
const preparedStmtDBKey = "preparedStmt"
// Config holds the settings of a DB. It is embedded in DB, and because
// *Config has Apply and AfterInitialize methods it can be passed to Open as
// an Option.
type Config struct {
// SkipDefaultTransaction can also be switched on per session (see Session).
// NOTE(review): presumably disables the implicit transaction around write
// operations; the consuming code is outside this file — confirm.
SkipDefaultTransaction bool
// NamingStrategy is the schema namer. Open defaults it to
// schema.NamingStrategy{IdentifierMaxLength: 64} when nil.
NamingStrategy schema .Namer
// FullSaveAssociations can also be switched on per session.
// NOTE(review): presumably makes saves write full association data — confirm.
FullSaveAssociations bool
// Logger receives log output. Open defaults it to logger.Default when nil.
Logger logger .Interface
// NowFunc supplies the current time. Open defaults it to a function
// returning time.Now().Local() when nil.
NowFunc func () time .Time
// DryRun is enabled by ToSQL for the session it builds.
// NOTE(review): presumably builds SQL without executing it — confirm.
DryRun bool
// PrepareStmt makes Open wrap ConnPool in a PreparedStmtDB and cache that
// wrapper in cacheStore.
PrepareStmt bool
// DisableAutomaticPing makes Open skip the Ping call on the connection pool.
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating is not read in this file.
// NOTE(review): by name, consumed by the migrator — confirm.
DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating is not read in this file.
// NOTE(review): by name, consumed by the migrator — confirm.
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction can also be switched on per session.
// NOTE(review): consumed by transaction code outside this file — confirm.
DisableNestedTransaction bool
// AllowGlobalUpdate can also be switched on per session.
// NOTE(review): presumably permits update/delete without conditions — confirm.
AllowGlobalUpdate bool
// QueryFields can also be switched on per session.
// NOTE(review): presumably selects columns by explicit name — confirm.
QueryFields bool
// CreateBatchSize can be overridden per session when the session value is > 0.
// NOTE(review): presumably the default batch size for batch inserts — confirm.
CreateBatchSize int
// TranslateError makes AddError pass errors through the Dialector's
// Translate method when the Dialector implements ErrorTranslator.
TranslateError bool
// ClauseBuilders maps names to clause builders. Open defaults it to an
// empty map when nil.
ClauseBuilders map [string ]clause .ClauseBuilder
// ConnPool is the connection pool used by the DB; Open replaces it with a
// PreparedStmtDB wrapper when PrepareStmt is set.
ConnPool ConnPool
// Dialector is the embedded database dialector. Open sets it from its
// dialector argument when that argument is non-nil.
Dialector
// Plugins holds registered plugins keyed by Plugin.Name (see Use). Open
// defaults it to an empty map when nil.
Plugins map [string ]Plugin
// callbacks is set by Open via initializeCallbacks and exposed by Callback.
callbacks *callbacks
// cacheStore holds shared state such as the PreparedStmtDB stored under
// preparedStmtDBKey. Open defaults it to a new sync.Map when nil.
cacheStore *sync .Map
}
// Apply overwrites *config with a copy of the receiver's settings. When
// config and the receiver are the same object nothing is copied. It always
// returns nil.
func (c *Config) Apply(config *Config) error {
	if config == c {
		// Same object: there is nothing to copy.
		return nil
	}
	*config = *c
	return nil
}
// AfterInitialize calls Initialize on every plugin in c.Plugins with db and
// returns the first error encountered. A nil db is a no-op that returns nil.
func (c *Config) AfterInitialize(db *DB) error {
	if db == nil {
		return nil
	}
	for _, p := range c.Plugins {
		err := p.Initialize(db)
		if err != nil {
			return err
		}
	}
	return nil
}
// Option is a configuration option accepted by Open.
type Option interface {
// Apply is called by Open with the Config under construction; a non-nil
// error aborts Open.
Apply (*Config ) error
// AfterInitialize is invoked by Open from a deferred call registered for
// each applied option, so it runs as Open returns.
AfterInitialize (*DB ) error
}
// DB is the database handle. It embeds *Config, so configuration fields and
// the Dialector are accessible directly on a DB value.
type DB struct {
*Config
// Error accumulates errors recorded through AddError.
Error error
// RowsAffected is not written in this file.
// NOTE(review): presumably set by statement execution — confirm.
RowsAffected int64
// Statement is the statement state associated with this handle.
Statement *Statement
// clone controls getInstance: 0 returns the receiver itself, 1 creates a
// new DB with a fresh Statement, and any other positive value creates a
// new DB with a clone of the current Statement.
clone int
}
// Session lists the overrides applied by (*DB).Session when deriving a new
// DB. Boolean fields can only switch an option on; a false value leaves the
// inherited setting untouched.
type Session struct {
// DryRun sets Config.DryRun on the derived DB.
DryRun bool
// PrepareStmt routes the derived DB through a prepared-statement wrapper
// that shares the PreparedStmtDB cached under preparedStmtDBKey.
PrepareStmt bool
// NewDB, when false, sets clone to 2 so getInstance clones the current
// Statement; when true, clone stays 1 and getInstance starts from a fresh
// Statement.
NewDB bool
// Initialized makes Session return the result of getInstance on the
// derived DB instead of the derived DB itself.
Initialized bool
// SkipHooks sets Statement.SkipHooks on the derived DB's statement.
SkipHooks bool
// SkipDefaultTransaction sets Config.SkipDefaultTransaction.
SkipDefaultTransaction bool
// DisableNestedTransaction sets Config.DisableNestedTransaction.
DisableNestedTransaction bool
// AllowGlobalUpdate sets Config.AllowGlobalUpdate.
AllowGlobalUpdate bool
// FullSaveAssociations sets Config.FullSaveAssociations.
FullSaveAssociations bool
// QueryFields sets Config.QueryFields.
QueryFields bool
// Context, when non-nil, replaces the statement's Context.
Context context .Context
// Logger, when non-nil, replaces Config.Logger.
Logger logger .Interface
// NowFunc, when non-nil, replaces Config.NowFunc.
NowFunc func () time .Time
// CreateBatchSize, when greater than zero, replaces Config.CreateBatchSize.
CreateBatchSize int
}
// Open builds a *DB from the given dialector and options.
//
// Options are applied to a fresh Config, with *Config options applied before
// all others; the relative order of the remaining options is preserved. If
// the dialector itself has an Apply(*Config) method it is applied afterwards.
// Missing settings (NamingStrategy, Logger, NowFunc, Plugins, cacheStore,
// ClauseBuilders) are then defaulted, the dialector is initialized, the
// connection pool is optionally wrapped for prepared statements, and unless
// DisableAutomaticPing is set the pool is pinged.
//
// Each applied option's AfterInitialize runs as Open returns. These hooks are
// skipped when no DB was constructed or when the dialector failed to
// initialize, so that they never see an unusable handle and cannot replace
// the root-cause error with their own.
//
// On failure the error is logged through config.Logger and returned; db may
// still be non-nil in that case.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	config := &Config{}

	// Move *Config options to the front. A stable sort keeps the caller's
	// relative ordering of all other options intact.
	sort.SliceStable(opts, func(i, j int) bool {
		_, isConfig := opts[i].(*Config)
		_, isConfig2 := opts[j].(*Config)
		return isConfig && !isConfig2
	})

	// initFailed is set when Dialector.Initialize fails; the deferred hooks
	// below consult it when Open returns.
	var initFailed bool

	for _, opt := range opts {
		if opt != nil {
			if applyErr := opt.Apply(config); applyErr != nil {
				return nil, applyErr
			}
			defer func(opt Option) {
				// Without a constructed or initialized DB there is nothing for
				// the hook to work on, and its error must not mask the original.
				if db == nil || initFailed {
					return
				}
				if errr := opt.AfterInitialize(db); errr != nil {
					err = errr
				}
			}(opt)
		}
	}

	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
		if err = d.Apply(config); err != nil {
			return nil, err
		}
	}

	if config.NamingStrategy == nil {
		config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64}
	}

	if config.Logger == nil {
		config.Logger = logger.Default
	}

	if config.NowFunc == nil {
		config.NowFunc = func() time.Time { return time.Now().Local() }
	}

	if dialector != nil {
		config.Dialector = dialector
	}

	if config.Plugins == nil {
		config.Plugins = map[string]Plugin{}
	}

	if config.cacheStore == nil {
		config.cacheStore = &sync.Map{}
	}

	db = &DB{Config: config, clone: 1}

	db.callbacks = initializeCallbacks(db)

	if config.ClauseBuilders == nil {
		config.ClauseBuilders = map[string]clause.ClauseBuilder{}
	}

	if config.Dialector != nil {
		err = config.Dialector.Initialize(db)
		if err != nil {
			initFailed = true
			// Best effort: release any connection the dialector opened before
			// it failed. The Close error is deliberately ignored because the
			// initialization error is the one worth reporting.
			if sqlDB, _ := db.DB(); sqlDB != nil {
				_ = sqlDB.Close()
			}
		}
	}

	if config.PrepareStmt {
		preparedStmt := NewPreparedStmtDB(db.ConnPool)
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		db.ConnPool = preparedStmt
	}

	db.Statement = &Statement{
		DB:       db,
		ConnPool: db.ConnPool,
		Context:  context.Background(),
		Clauses:  map[string]clause.Clause{},
	}

	if err == nil && !config.DisableAutomaticPing {
		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
			err = pinger.Ping()
		}
	}

	if err != nil {
		config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
	}

	return db, err
}
// Session derives a new *DB from db with the overrides in config applied.
//
// The derived DB gets its own copy of the Config, so overrides never leak
// back into db. The Statement is shared with db unless a Context,
// PrepareStmt or SkipHooks override requires a private clone. When
// config.Initialized is set, the result of getInstance on the derived DB is
// returned instead.
func (db *DB) Session(config *Session) *DB {
	// Private copy of the configuration; tx.Config points at it.
	cfg := *db.Config
	tx := &DB{
		Config:    &cfg,
		Statement: db.Statement,
		Error:     db.Error,
		clone:     1,
	}

	// Plain configuration overrides. Boolean options can only be enabled.
	if config.CreateBatchSize > 0 {
		cfg.CreateBatchSize = config.CreateBatchSize
	}
	if config.SkipDefaultTransaction {
		cfg.SkipDefaultTransaction = true
	}
	if config.AllowGlobalUpdate {
		cfg.AllowGlobalUpdate = true
	}
	if config.FullSaveAssociations {
		cfg.FullSaveAssociations = true
	}
	if config.DisableNestedTransaction {
		cfg.DisableNestedTransaction = true
	}
	if config.DryRun {
		cfg.DryRun = true
	}
	if config.QueryFields {
		cfg.QueryFields = true
	}
	if config.Logger != nil {
		cfg.Logger = config.Logger
	}
	if config.NowFunc != nil {
		cfg.NowFunc = config.NowFunc
	}

	// Any override that mutates the statement needs a private clone first so
	// the parent's statement stays untouched.
	needsOwnStatement := config.Context != nil || config.PrepareStmt || config.SkipHooks
	if needsOwnStatement {
		tx.Statement = tx.Statement.clone()
		tx.Statement.DB = tx
	}

	if config.Context != nil {
		tx.Statement.Context = config.Context
	}

	if config.PrepareStmt {
		// Reuse the PreparedStmtDB cached on the parent, creating it on first use.
		var shared *PreparedStmtDB
		if cached, found := db.cacheStore.Load(preparedStmtDBKey); found {
			shared = cached.(*PreparedStmtDB)
		} else {
			shared = NewPreparedStmtDB(db.ConnPool)
			db.cacheStore.Store(preparedStmtDBKey, shared)
		}

		if sqlTx, inTx := tx.Statement.ConnPool.(Tx); inTx {
			tx.Statement.ConnPool = &PreparedStmtTX{
				Tx:             sqlTx,
				PreparedStmtDB: shared,
			}
		} else {
			tx.Statement.ConnPool = &PreparedStmtDB{
				ConnPool: db.Config.ConnPool,
				Mux:      shared.Mux,
				Stmts:    shared.Stmts,
			}
		}
		cfg.ConnPool = tx.Statement.ConnPool
		cfg.PrepareStmt = true
	}

	if config.SkipHooks {
		tx.Statement.SkipHooks = true
	}

	// Unless a brand-new DB was requested, later getInstance calls clone the
	// current statement instead of starting from a fresh one.
	if !config.NewDB {
		tx.clone = 2
	}

	if config.Initialized {
		return tx.getInstance()
	}
	return tx
}
// WithContext returns a session of db whose statement uses ctx. It is
// shorthand for db.Session(&Session{Context: ctx}).
func (db *DB) WithContext(ctx context.Context) *DB {
	session := Session{Context: ctx}
	return db.Session(&session)
}
// Debug returns a session whose logger is db's logger switched to
// logger.Info mode.
func (db *DB) Debug() (tx *DB) {
	instance := db.getInstance()
	session := &Session{
		Logger: db.Logger.LogMode(logger.Info),
	}
	return instance.Session(session)
}
// Set stores value under key in the statement settings of the instance
// obtained from getInstance and returns that instance.
func (db *DB) Set(key string, value interface{}) *DB {
	instance := db.getInstance()
	settings := &instance.Statement.Settings
	settings.Store(key, value)
	return instance
}
// Get looks up key in the statement settings of db and reports whether it
// was present.
func (db *DB) Get(key string) (interface{}, bool) {
	value, found := db.Statement.Settings.Load(key)
	return value, found
}
// InstanceSet stores value in the statement settings under key prefixed with
// the statement's pointer address, which scopes the entry to that particular
// *Statement. It returns the instance obtained from getInstance.
func (db *DB) InstanceSet(key string, value interface{}) *DB {
	instance := db.getInstance()
	scopedKey := fmt.Sprintf("%p", instance.Statement) + key
	instance.Statement.Settings.Store(scopedKey, value)
	return instance
}
// InstanceGet looks up a value stored by InstanceSet: key is prefixed with
// the pointer address of db's current statement before the lookup.
func (db *DB) InstanceGet(key string) (interface{}, bool) {
	scopedKey := fmt.Sprintf("%p", db.Statement) + key
	value, found := db.Statement.Settings.Load(scopedKey)
	return value, found
}
// Callback returns the callbacks held in db's Config.
func (db *DB) Callback() *callbacks {
	return db.Config.callbacks
}
// AddError records err on db and returns the accumulated db.Error.
//
// A nil err leaves db.Error unchanged. When Config.TranslateError is set and
// the Dialector implements ErrorTranslator, err is translated first. If an
// error is already recorded, the new one is appended to its message and
// wrapped with %w.
func (db *DB) AddError(err error) error {
	if err == nil {
		return db.Error
	}

	if db.Config.TranslateError {
		if translator, ok := db.Dialector.(ErrorTranslator); ok {
			err = translator.Translate(err)
		}
	}

	switch {
	case db.Error == nil:
		db.Error = err
	default:
		db.Error = fmt.Errorf("%v; %w", db.Error, err)
	}
	return db.Error
}
// DB returns the *sql.DB behind db.
//
// The statement's ConnPool takes precedence over the Config's ConnPool when
// it is set. The pool is resolved in this order: a *sql.Tx (its unexported
// "db" field is read via reflection), a GetDBConnector (used when it yields a
// non-nil *sql.DB or an error), and finally a plain *sql.DB. ErrInvalidDB is
// returned when none of these apply.
func (db *DB) DB() (*sql.DB, error) {
	pool := db.ConnPool
	if stmt := db.Statement; stmt != nil && stmt.ConnPool != nil {
		pool = stmt.ConnPool
	}

	if sqlTx, isTx := pool.(*sql.Tx); isTx && sqlTx != nil {
		// sql.Tx does not expose its parent DB, so read the private field.
		dbField := reflect.ValueOf(sqlTx).Elem().FieldByName("db")
		return (*sql.DB)(dbField.UnsafePointer()), nil
	}

	if connector, isConnector := pool.(GetDBConnector); isConnector && connector != nil {
		sqlDB, err := connector.GetDBConn()
		if sqlDB != nil || err != nil {
			return sqlDB, err
		}
	}

	if sqlDB, isDB := pool.(*sql.DB); isDB && sqlDB != nil {
		return sqlDB, nil
	}

	return nil, ErrInvalidDB
}
// getInstance returns the DB to operate on, depending on db.clone:
//   - clone <= 0: db itself;
//   - clone == 1: a new DB with a fresh Statement that inherits only the
//     ConnPool, Context and SkipHooks of db's statement;
//   - otherwise:  a new DB with a clone of db's statement.
//
// New instances share db's Config and carry over db.Error.
func (db *DB) getInstance() *DB {
	if db.clone <= 0 {
		return db
	}

	tx := &DB{Config: db.Config, Error: db.Error}

	if db.clone != 1 {
		// Continue from the current statement's state.
		tx.Statement = db.Statement.clone()
		tx.Statement.DB = tx
		return tx
	}

	// Start from a clean statement.
	parent := db.Statement
	tx.Statement = &Statement{
		DB:        tx,
		ConnPool:  parent.ConnPool,
		Context:   parent.Context,
		Clauses:   map[string]clause.Clause{},
		Vars:      make([]interface{}, 0, 8),
		SkipHooks: parent.SkipHooks,
	}
	return tx
}
// Expr builds a clause.Expr whose SQL is expr and whose Vars are args.
func Expr(expr string, args ...interface{}) clause.Expr {
	var e clause.Expr
	e.SQL = expr
	e.Vars = args
	return e
}
// SetupJoinTable replaces the join table of model's relation named field
// with the schema parsed from joinTable.
//
// Both model and joinTable are parsed with the statement's Parse method. An
// error is returned when parsing fails, when field is not a relation with a
// join table, or when the join schema lacks a field for one of the
// relation's foreign keys. Each matching join-schema field takes over the
// foreign key's DataType and GORMDataType (and its Size when the field's own
// Size is zero) and then becomes the reference's foreign key. Relationships
// of the previous join table that the new schema does not define are carried
// over to it.
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
	stmt := db.getInstance().Statement

	if err := stmt.Parse(model); err != nil {
		return err
	}
	modelSchema := stmt.Schema

	if err := stmt.Parse(joinTable); err != nil {
		return err
	}
	joinSchema := stmt.Schema

	relation, found := modelSchema.Relationships.Relations[field]
	if !found || relation.JoinTable == nil {
		return fmt.Errorf("failed to find relation: %s", field)
	}

	// Point every reference at the corresponding field of the custom join schema.
	for _, ref := range relation.References {
		joinField := joinSchema.LookUpField(ref.ForeignKey.DBName)
		if joinField == nil {
			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
		}

		joinField.DataType = ref.ForeignKey.DataType
		joinField.GORMDataType = ref.ForeignKey.GORMDataType
		if joinField.Size == 0 {
			joinField.Size = ref.ForeignKey.Size
		}
		ref.ForeignKey = joinField
	}

	// Carry over relationships the custom join schema does not define itself.
	for name, rel := range relation.JoinTable.Relationships.Relations {
		if _, exists := joinSchema.Relationships.Relations[name]; exists {
			continue
		}
		rel.Schema = joinSchema
		joinSchema.Relationships.Relations[name] = rel
	}

	relation.JoinTable = joinSchema
	return nil
}
// Use initializes plugin with db and registers it in db.Plugins under
// plugin.Name(). It returns ErrRegistered when a plugin with that name is
// already registered, or the error from plugin.Initialize, in which case the
// plugin is not registered.
func (db *DB) Use(plugin Plugin) error {
	key := plugin.Name()
	if _, registered := db.Plugins[key]; registered {
		return ErrRegistered
	}

	err := plugin.Initialize(db)
	if err == nil {
		db.Plugins[key] = plugin
	}
	return err
}
// ToSQL runs queryFn against a session with DryRun and
// SkipDefaultTransaction enabled, then returns the resulting statement's SQL
// and variables as formatted by the Dialector's Explain method.
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
	dryRun := db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})
	stmt := queryFn(dryRun).Statement
	sqlText := stmt.SQL.String()
	return db.Dialector.Explain(sqlText, stmt.Vars...)
}
// The pages are generated with Golds v0.6.7. (GOOS=linux GOARCH=amd64)
// Golds is a Go 101 project developed by Tapir Liu.
// PR and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.