package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// SetupUpdateReflectValue re-targets db.Statement.ReflectValue at the fully
// dereferenced Model. It does this whenever the current reflect value is not
// addressable or the update values (Dest) are not the model itself.
//
// When Dest is a map[string]interface{}, entries keyed by the name of a
// belongs-to relationship are also copied onto the model. Any error from that
// copy is recorded on db. Nothing happens if db already has an error or the
// statement has no parsed schema.
func SetupUpdateReflectValue(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}

	stmt := db.Statement
	if stmt.ReflectValue.CanAddr() && stmt.Model == stmt.Dest {
		// Already pointing at a writable model that is also the source of the
		// update values; nothing to re-target.
		return
	}

	target := reflect.ValueOf(stmt.Model)
	for target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	stmt.ReflectValue = target

	updates, isMap := stmt.Dest.(map[string]interface{})
	if !isMap {
		return
	}
	for _, rel := range stmt.Schema.Relationships.BelongsTo {
		if relValue, present := updates[rel.Name]; present {
			db.AddError(rel.Field.Set(stmt.Context, stmt.ReflectValue, relValue))
		}
	}
}
// BeforeUpdate invokes the model's BeforeSave hook and then its BeforeUpdate
// hook, for each hook the schema declares. It does nothing when db already
// carries an error, the statement has no schema, or hooks are skipped. Errors
// returned by the hooks are recorded on db.
func BeforeUpdate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks {
		return
	}
	if !db.Statement.Schema.BeforeSave && !db.Statement.Schema.BeforeUpdate {
		return
	}

	callMethod(db, func(value interface{}, tx *gorm.DB) bool {
		invoked := false

		// BeforeSave runs first so it applies to both saves and updates.
		if hook, ok := value.(BeforeSaveInterface); ok && db.Statement.Schema.BeforeSave {
			invoked = true
			db.AddError(hook.BeforeSave(tx))
		}
		if hook, ok := value.(BeforeUpdateInterface); ok && db.Statement.Schema.BeforeUpdate {
			invoked = true
			db.AddError(hook.BeforeUpdate(tx))
		}

		return invoked
	})
}
// Update returns the callback that builds and executes the UPDATE statement.
//
// config.UpdateClauses is searched for "RETURNING" once, when the callback is
// created. Each run passes that flag to hasReturning, which picks one of two paths:
//   - RETURNING: run the statement with QueryContext and scan the returned
//     rows back into the model.
//   - otherwise: run it with ExecContext and record db.RowsAffected.
//
// Nothing is executed when db already has an error, when no SET assignments
// can be derived, or in DryRun mode.
func Update (config *Config ) func (db *gorm .DB ) {
supportReturning := utils .Contains (config .UpdateClauses , "RETURNING" )
return func (db *gorm .DB ) {
if db .Error != nil {
return
}
// Attach the schema's own UpdateClauses to the statement.
if db .Statement .Schema != nil {
for _ , c := range db .Statement .Schema .UpdateClauses {
db .Statement .AddClause (c )
}
}
// Only build SQL when nothing has been written to the buffer yet
// (presumably non-empty when SQL was supplied directly — TODO confirm).
if db .Statement .SQL .Len () == 0 {
db .Statement .SQL .Grow (180 )
db .Statement .AddClauseIfNotExists (clause .Update {})
if _ , ok := db .Statement .Clauses ["SET" ]; !ok {
if set := ConvertToAssignments (db .Statement ); len (set ) != 0 {
// The generated SET clause is removed when this callback returns,
// i.e. after execution, so it does not stay on the statement.
defer delete (db .Statement .Clauses , "SET" )
db .Statement .AddClause (set )
} else {
// Nothing to assign: skip building and executing entirely.
return
}
}
db .Statement .Build (db .Statement .BuildClauses ...)
}
// Defined elsewhere in the package; it may set db.Error, which is
// checked immediately below.
checkMissingWhereConditions (db )
if !db .DryRun && db .Error == nil {
if ok , mode := hasReturning (db , supportReturning ); ok {
if rows , err := db .Statement .ConnPool .QueryContext (db .Statement .Context , db .Statement .SQL .String (), db .Statement .Vars ...); db .AddError (err ) == nil {
// Temporarily point Dest at the model so gorm.Scan writes the
// returned columns into it, then restore the caller's Dest.
// NOTE(review): Addr() panics if ReflectValue is not addressable,
// and rows are not closed if Scan panics — consider a defer.
dest := db .Statement .Dest
db .Statement .Dest = db .Statement .ReflectValue .Addr ().Interface ()
gorm .Scan (rows , db , mode )
db .Statement .Dest = dest
db .AddError (rows .Close ())
}
} else {
result , err := db .Statement .ConnPool .ExecContext (db .Statement .Context , db .Statement .SQL .String (), db .Statement .Vars ...)
if db .AddError (err ) == nil {
// The RowsAffected error is discarded; RowsAffected keeps
// whatever value the driver returns.
db .RowsAffected , _ = result .RowsAffected ()
}
}
}
}
}
// AfterUpdate invokes the model's AfterUpdate hook and then its AfterSave hook,
// for each hook the schema declares. It does nothing when db already carries
// an error, the statement has no schema, or hooks are skipped. Errors returned
// by the hooks are recorded on db.
func AfterUpdate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks {
		return
	}
	if !db.Statement.Schema.AfterSave && !db.Statement.Schema.AfterUpdate {
		return
	}

	callMethod(db, func(value interface{}, tx *gorm.DB) bool {
		invoked := false

		// AfterUpdate runs before AfterSave (mirror image of BeforeUpdate).
		if hook, ok := value.(AfterUpdateInterface); ok && db.Statement.Schema.AfterUpdate {
			invoked = true
			db.AddError(hook.AfterUpdate(tx))
		}
		if hook, ok := value.(AfterSaveInterface); ok && db.Statement.Schema.AfterSave {
			invoked = true
			db.AddError(hook.AfterSave(tx))
		}

		return invoked
	})
}
// ConvertToAssignments builds the SET clause for an UPDATE from stmt.Dest. It
// also mirrors the values being written onto the in-memory model
// (stmt.ReflectValue).
//
// Dest handling:
//   - map[string]interface{}: keys are processed in sorted order, so the
//     generated SQL is deterministic. Each key is handled as follows:
//       - A key naming a schema field that has a column is emitted under that
//         field's DBName, if it passes the Select/Omit filter.
//       - A key naming a field without a column is only copied onto the model.
//       - Any other key is emitted as a raw column name.
//     *gorm.DB values are wrapped in a one-element slice. Unless hooks are
//     skipped, AutoUpdateTime columns that are absent (or nil) in the map are
//     stamped with stmt.DB.NowFunc().
//   - struct: every column in stmt.Schema.DBNames is looked up in Dest's
//     schema. It is emitted when it passes the Select/Omit filter, is
//     Updatable, and is either explicitly selected or non-zero. Unless hooks
//     are skipped, AutoUpdateTime columns are stamped with the current time.
//   - anything else records gorm.ErrInvalidData on stmt.
//
// Primary keys:
//   - When Dest is not the model itself (or is not addressable), WHERE
//     conditions are added for the model's non-zero primary keys.
//   - When Dest is the addressable model, its non-zero primary keys become
//     WHERE conditions rather than assignments.
//
// Writes back to the model are best-effort. They only happen when
// stmt.ReflectValue is addressable, and errors from Field.Set are ignored.
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
	var (
		selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
		assignValue               func(field *schema.Field, value interface{})
	)

	// autoUpdateTime reads the statement clock once and returns two values:
	// the raw time, used to refresh the in-memory model, and the value to store
	// in the column, converted according to the field's AutoUpdateTime type.
	autoUpdateTime := func(timeType schema.TimeType) (now, columnValue interface{}) {
		t := stmt.DB.NowFunc()
		switch timeType {
		case schema.UnixNanosecond:
			return t, t.UnixNano()
		case schema.UnixMillisecond:
			return t, t.UnixNano() / 1e6
		case schema.UnixSecond:
			return t, t.Unix()
		default:
			return t, t
		}
	}

	// assignValue mirrors an accepted value onto the model; it is a no-op when
	// the model cannot be written to.
	switch stmt.ReflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		assignValue = func(field *schema.Field, value interface{}) {
			// Addressability belongs to the slice/array value itself, so check
			// it once rather than once per element.
			if !stmt.ReflectValue.CanAddr() {
				return
			}
			for i := 0; i < stmt.ReflectValue.Len(); i++ {
				_ = field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
			}
		}
	case reflect.Struct:
		assignValue = func(field *schema.Field, value interface{}) {
			if stmt.ReflectValue.CanAddr() {
				_ = field.Set(stmt.Context, stmt.ReflectValue, value)
			}
		}
	default:
		assignValue = func(field *schema.Field, value interface{}) {}
	}

	updatingValue := reflect.ValueOf(stmt.Dest)
	for updatingValue.Kind() == reflect.Ptr {
		updatingValue = updatingValue.Elem()
	}

	// When the update values are not the model itself, the model's primary
	// keys scope the update.
	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
		switch stmt.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			// Scope the update to every record that carries a primary key.
			// GetIdentityFieldValuesMap is expected to omit records whose keys are
			// all zero, so an empty result means no record has a key and no clause
			// is added. Checking only the last record would let a trailing
			// key-less record silently drop the WHERE clause.
			if _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields); len(primaryValues) > 0 {
				column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
			}
		case reflect.Struct:
			for _, field := range stmt.Schema.PrimaryFields {
				if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
				}
			}
		}
	}

	switch value := updatingValue.Interface().(type) {
	case map[string]interface{}:
		set = make([]clause.Assignment, 0, len(value))

		// Sorted keys give a deterministic column order in the generated SQL.
		keys := make([]string, 0, len(value))
		for k := range value {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		for _, k := range keys {
			kv := value[k]
			// Sub-queries are wrapped in a slice, presumably so they render as a
			// parenthesized expression. TODO confirm against clause.Set building.
			if _, ok := kv.(*gorm.DB); ok {
				kv = []interface{}{kv}
			}

			if stmt.Schema != nil {
				if field := stmt.Schema.LookUpField(k); field != nil {
					if field.DBName != "" {
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
							assignValue(field, value[k])
						}
					} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
						// Fields without a column are only mirrored onto the model.
						assignValue(field, value[k])
					}
					continue
				}
			}

			// Keys that do not resolve to a schema field are used as raw columns.
			if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
			}
		}

		if !stmt.SkipHooks && stmt.Schema != nil {
			for _, dbName := range stmt.Schema.DBNames {
				field := stmt.Schema.LookUpField(dbName)
				if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
					// Only an explicit false entry in selectColumns suppresses
					// the timestamp. Unlike the keys above, `restricted` is not
					// consulted here.
					if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
						now, columnValue := autoUpdateTime(field.AutoUpdateTime)
						assignValue(field, now)
						set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: columnValue})
					}
				}
			}
		}
	default:
		updatingSchema := stmt.Schema
		var isDiffSchema bool
		if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
			// Dest may be a different type from the model, so its fields are
			// read through its own schema. If parsing fails, fall back to the
			// model's schema.
			updatingStmt := &gorm.Statement{DB: stmt.DB}
			if err := updatingStmt.Parse(stmt.Dest); err == nil {
				updatingSchema = updatingStmt.Schema
				isDiffSchema = true
			}
		}

		switch updatingValue.Kind() {
		case reflect.Struct:
			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
			for _, dbName := range stmt.Schema.DBNames {
				if field := updatingSchema.LookUpField(dbName); field != nil {
					if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
							value, isZero := field.ValueOf(stmt.Context, updatingValue)
							if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
								// The auto-update timestamp replaces whatever Dest holds.
								_, value = autoUpdateTime(field.AutoUpdateTime)
								isZero = false
							}

							if (ok || !isZero) && field.Updatable {
								set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
								// Write back through the model's own field when Dest
								// was parsed with a different schema.
								assignField := field
								if isDiffSchema {
									if originField := stmt.Schema.LookUpField(dbName); originField != nil {
										assignField = originField
									}
								}
								assignValue(assignField, value)
							}
						}
					} else {
						// Dest is the addressable model itself: its primary keys
						// identify the row instead of being assigned.
						if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
							stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
						}
					}
				}
			}
		default:
			stmt.AddError(gorm.ErrInvalidData)
		}
	}

	return set
}
// These pages were generated with Golds v0.6.7 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @Go100and1 (reachable from the left QR code) to get the latest news about Golds.