package gorm
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"time"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// initializeCallbacks builds the callbacks registry for db, creating one
// empty processor bound to db for each supported operation kind.
func initializeCallbacks(db *DB) *callbacks {
	kinds := [...]string{"create", "query", "update", "delete", "row", "raw"}

	registry := make(map[string]*processor, len(kinds))
	for _, kind := range kinds {
		registry[kind] = &processor{db: db}
	}
	return &callbacks{processors: registry}
}
// callbacks groups the processors of a DB, one per operation kind.
type callbacks struct {
// processors maps an operation name ("create", "query", "update",
// "delete", "row", "raw") to the processor handling it.
processors map [string ]*processor
}
// processor holds the callbacks registered for one operation kind together
// with the compiled, ordered list of handler functions that Execute runs.
type processor struct {
// db is the DB this processor was created for; compile passes it to
// callback match predicates and uses its Logger.
db *DB
// Clauses is used as the statement's BuildClauses by Execute when the
// statement has none of its own.
Clauses []string
// fns is the ordered handler list produced by compile via sortCallbacks.
fns []func (*DB )
// callbacks is the list of registered callback entries, including
// replace and remove markers.
callbacks []*callback
}
// callback is a single registration entry of a processor: a named handler
// plus the ordering constraints and flags that sortCallbacks interprets.
type callback struct {
// name identifies the callback; entries sharing a name refer to the same slot.
name string
// before is the name of the callback this one must precede ("*" is a wildcard).
before string
// after is the name of the callback this one must follow ("*" is a wildcard).
after string
// remove marks the entry as a removal of the callback called name.
remove bool
// replace marks the entry as a replacement of the callback called name.
replace bool
// match, when non-nil, is evaluated by compile with the processor's DB;
// the entry is dropped when it returns false.
match func (*DB ) bool
// handler is the function executed for this callback.
handler func (*DB )
// processor is the processor the entry is registered on.
processor *processor
}
// Create returns the processor stored under the "create" key.
func (cs *callbacks ) Create () *processor {
return cs .processors ["create" ]
}
// Query returns the processor stored under the "query" key.
func (cs *callbacks ) Query () *processor {
return cs .processors ["query" ]
}
// Update returns the processor stored under the "update" key.
func (cs *callbacks ) Update () *processor {
return cs .processors ["update" ]
}
// Delete returns the processor stored under the "delete" key.
func (cs *callbacks ) Delete () *processor {
return cs .processors ["delete" ]
}
// Row returns the processor stored under the "row" key.
func (cs *callbacks ) Row () *processor {
return cs .processors ["row" ]
}
// Raw returns the processor stored under the "raw" key.
func (cs *callbacks ) Raw () *processor {
return cs .processors ["raw" ]
}
// Execute prepares db's statement, runs the processor's compiled callback
// functions against db in order, traces the resulting SQL and returns db.
//
// Errors met while preparing the statement are recorded on db with AddError;
// the callback functions are still invoked afterwards.
func (p *processor ) Execute (db *DB ) *DB {
// Apply pending scopes until none remain; each pass replaces db with the
// value returned by executeScopes.
for len (db .Statement .scopes ) > 0 {
db = db .executeScopes ()
}
var (
// curTime is the start time handed to the logger's Trace call below.
curTime = time .Now ()
stmt = db .Statement
// resetBuildClauses records that BuildClauses was borrowed from the
// processor and must be cleared again before returning.
resetBuildClauses bool
)
// Fall back to the processor's clause list when the statement has none.
if len (stmt .BuildClauses ) == 0 {
stmt .BuildClauses = p .Clauses
resetBuildClauses = true
}
// Let the destination adjust the statement if it implements StatementModifier.
if optimizer , ok := db .Statement .Dest .(StatementModifier ); ok {
optimizer .ModifyStatement (stmt )
}
// Model and Dest default to each other when only one of them is set.
if stmt .Model == nil {
stmt .Model = stmt .Dest
} else if stmt .Dest == nil {
stmt .Dest = stmt .Model
}
// Parse the model. An ErrUnsupportedDataType failure is ignored when a
// table name, table expression or SQL text is already present; any other
// parse error is recorded on db.
if stmt .Model != nil {
if err := stmt .Parse (stmt .Model ); err != nil && (!errors .Is (err , schema .ErrUnsupportedDataType ) || (stmt .Table == "" && stmt .TableExpr == nil && stmt .SQL .Len () == 0 )) {
// Add a usage hint when the type is unsupported and no table is known.
if errors .Is (err , schema .ErrUnsupportedDataType ) && stmt .Table == "" && stmt .TableExpr == nil {
db .AddError (fmt .Errorf ("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")" , err ))
} else {
db .AddError (err )
}
}
}
// Resolve stmt.ReflectValue by dereferencing Dest through every pointer
// level; a nil pointer that is addressable is allocated on the way.
if stmt .Dest != nil {
stmt .ReflectValue = reflect .ValueOf (stmt .Dest )
for stmt .ReflectValue .Kind () == reflect .Ptr {
if stmt .ReflectValue .IsNil () && stmt .ReflectValue .CanAddr () {
stmt .ReflectValue .Set (reflect .New (stmt .ReflectValue .Type ().Elem ()))
}
stmt .ReflectValue = stmt .ReflectValue .Elem ()
}
// Dereferencing ended on an invalid value (e.g. a nil pointer that
// could not be allocated).
if !stmt .ReflectValue .IsValid () {
db .AddError (ErrInvalidValue )
}
}
// Run the compiled callback functions in order.
for _ , f := range p .fns {
f (db )
}
// Trace only when SQL text is present on the statement. The closure passed
// to Trace builds the explained SQL, routing it through the logger's
// ParamsFilter first when the logger implements that interface.
if stmt .SQL .Len () > 0 {
db .Logger .Trace (stmt .Context , curTime , func () (string , int64 ) {
sql , vars := stmt .SQL .String (), stmt .Vars
if filter , ok := db .Logger .(ParamsFilter ); ok {
sql , vars = filter .ParamsFilter (stmt .Context , stmt .SQL .String (), stmt .Vars ...)
}
return db .Dialector .Explain (sql , vars ...), db .RowsAffected
}, db .Error )
}
// Outside of dry-run mode, clear the SQL buffer and its variables.
if !stmt .DB .DryRun {
stmt .SQL .Reset ()
stmt .Vars = nil
}
// Undo the BuildClauses default installed above.
if resetBuildClauses {
stmt .BuildClauses = nil
}
return db
}
// Get returns the handler of the most recently listed callback called name
// that is not a removal entry, or nil when there is no such callback.
func (p *processor) Get(name string) func(*DB) {
	for pos := len(p.callbacks); pos > 0; pos-- {
		entry := p.callbacks[pos-1]
		if entry.remove || entry.name != name {
			continue
		}
		return entry.handler
	}
	return nil
}
// Before starts a new callback entry on p that is constrained to run before
// the callback called name. The entry takes effect once Register, Replace or
// Remove is called on the returned value.
func (p *processor ) Before (name string ) *callback {
return &callback {before : name , processor : p }
}
// After starts a new callback entry on p that is constrained to run after
// the callback called name. The entry takes effect once Register, Replace or
// Remove is called on the returned value.
func (p *processor ) After (name string ) *callback {
return &callback {after : name , processor : p }
}
// Match starts a new callback entry on p guarded by the predicate fc; compile
// drops the entry when fc returns false for the processor's DB.
func (p *processor ) Match (fc func (*DB ) bool ) *callback {
return &callback {match : fc , processor : p }
}
// Register adds fn to p under name without any ordering constraint and
// recompiles the processor, returning the compile error if any.
func (p *processor ) Register (name string , fn func (*DB )) error {
return (&callback {processor : p }).Register (name , fn )
}
// Remove adds a removal entry for the callback called name to p and
// recompiles the processor, returning the compile error if any.
func (p *processor ) Remove (name string ) error {
return (&callback {processor : p }).Remove (name )
}
// Replace adds a replacement entry that binds fn to name on p and recompiles
// the processor, returning the compile error if any.
func (p *processor ) Replace (name string , fn func (*DB )) error {
return (&callback {processor : p }).Replace (name , fn )
}
// compile rebuilds the processor's handler list. Entries whose match
// predicate rejects the processor's DB are discarded, the remaining entries
// are sorted, and the resulting handlers are stored in p.fns. A sorting error
// is logged through the DB's logger and also returned.
func (p *processor) compile() (err error) {
	var kept []*callback
	for _, entry := range p.callbacks {
		if entry.match != nil && !entry.match(p.db) {
			continue
		}
		kept = append(kept, entry)
	}
	p.callbacks = kept

	p.fns, err = sortCallbacks(p.callbacks)
	if err != nil {
		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
	}
	return err
}
// Before sets the name of the callback that c must run before and returns c
// so further calls can be chained.
func (c *callback ) Before (name string ) *callback {
c .before = name
return c
}
// After sets the name of the callback that c must run after and returns c
// so further calls can be chained.
func (c *callback ) After (name string ) *callback {
c .after = name
return c
}
// Register names the entry, binds fn as its handler, appends it to the owning
// processor and recompiles that processor.
func (c *callback) Register(name string, fn func(*DB)) error {
	c.name, c.handler = name, fn

	owner := c.processor
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
// Remove logs a warning, turns the entry into a removal marker for the
// callback called name, appends it to the owning processor and recompiles
// that processor.
func (c *callback) Remove(name string) error {
	owner := c.processor
	owner.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())

	c.name, c.remove = name, true
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
// Replace logs an informational message, turns the entry into a replacement
// that binds fn to name, appends it to the owning processor and recompiles
// that processor.
func (c *callback) Replace(name string, fn func(*DB)) error {
	owner := c.processor
	owner.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())

	c.name, c.handler, c.replace = name, fn, true
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
// getRIndex reports the index of the last element of strs equal to str,
// or -1 when str does not occur in strs.
func getRIndex(strs []string, str string) int {
	for pos := len(strs); pos > 0; pos-- {
		if strs[pos-1] == str {
			return pos - 1
		}
	}
	return -1
}
// sortCallbacks resolves the before/after constraints of cs and returns the
// handlers of the surviving callbacks in execution order.
//
// cs is stably reordered in place, and the before/after fields of its entries
// may be rewritten while constraints against not-yet-placed callbacks are
// resolved. When several entries share a name, the last one in cs supplies
// the handler; if that entry is marked remove, the name contributes no
// handler.
//
// An error is returned when a callback has already been placed on the wrong
// side of the callback it is constrained against; fns is nil in that case.
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
	var (
		names, sorted []string
		sortCallback  func(*callback) error
	)

	// Stable pre-pass on the "*" wildcard: an entry is ordered ahead of
	// another when the other carries before == "*" (or after == "*") and it
	// does not.
	sort.SliceStable(cs, func(i, j int) bool {
		if cs[j].before == "*" && cs[i].before != "*" {
			return true
		}
		if cs[j].after == "*" && cs[i].after != "*" {
			return true
		}
		return false
	})

	// names mirrors cs index-for-index. Warn about a name that appears twice
	// unless the newer entry is a replace/remove or the previous entry with
	// that name is a remove.
	for _, c := range cs {
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
		}
		names = append(names, c.name)
	}

	// sortCallback places c.name into sorted, honouring c.before and c.after.
	sortCallback = func(c *callback) error {
		if c.before != "" { // c must run before another callback
			if c.before == "*" && len(sorted) > 0 {
				// Before everything: put c at the front unless already placed.
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append([]string{c.name}, sorted...)
				}
			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// Target already placed, c is not: insert c right in front of it.
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
				} else if curIdx > sortedIdx {
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
				}
			} else if idx := getRIndex(names, c.before); idx != -1 {
				// Target exists but is not placed yet: make it follow c instead.
				cs[idx].after = c.name
			}
		}

		if c.after != "" { // c must run after another callback
			if c.after == "*" && len(sorted) > 0 {
				// After everything: put c at the end unless already placed.
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append(sorted, c.name)
				}
			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// Target already placed, c is not: append c behind it.
					sorted = append(sorted, c.name)
				} else if curIdx < sortedIdx {
					// Report the violated constraint as "after" (it is c.after
					// that is being checked here, not c.before).
					return fmt.Errorf("conflicting callback %s with after %s", c.name, c.after)
				}
			} else if idx := getRIndex(names, c.after); idx != -1 {
				// Target exists but is not placed yet: place it first, then
				// retry c.
				after := cs[idx]
				if after.before == "" {
					// Record the inverse constraint on the target when it has
					// no before constraint of its own.
					after.before = c.name
				}

				if err := sortCallback(after); err != nil {
					return err
				}

				if err := sortCallback(c); err != nil {
					return err
				}
			}
		}

		// No constraint placed c: append it at the end.
		if getRIndex(sorted, c.name) == -1 {
			sorted = append(sorted, c.name)
		}

		return nil
	}

	for _, c := range cs {
		if err = sortCallback(c); err != nil {
			return nil, err
		}
	}

	// Collect handlers in sorted order. The last entry with a given name
	// decides whether the callback survives and which handler it uses.
	for _, name := range sorted {
		if idx := getRIndex(names, name); !cs[idx].remove {
			fns = append(fns, cs[idx].handler)
		}
	}

	return fns, nil
}
The pages are generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds.