From 02c67a185e5b926b1917cde51c2b0288a8dbf752 Mon Sep 17 00:00:00 2001 From: weloe <1345895607@qq.com> Date: Wed, 12 Apr 2023 18:28:09 +0800 Subject: [PATCH 1/3] fix: use transaction to SavePolicy --- adapter.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/adapter.go b/adapter.go index 1fab5f8..bfa2ae0 100755 --- a/adapter.go +++ b/adapter.go @@ -579,8 +579,17 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { // SavePolicy saves policy to database. func (a *Adapter) SavePolicy(model model.Model) error { - err := a.truncateTable() + var err error + tx := a.db.Begin() + + if a.db.Config.Name() == sqlite.DriverName { + err = tx.Exec(fmt.Sprintf("delete from %s", a.getFullTableName())).Error + } else { + err = tx.Exec(fmt.Sprintf("truncate table %s", a.getFullTableName())).Error + } + if err != nil { + tx.Rollback() return err } @@ -590,7 +599,8 @@ func (a *Adapter) SavePolicy(model model.Model) error { for _, rule := range ast.Policy { lines = append(lines, a.savePolicyLine(ptype, rule)) if len(lines) > flushEvery { - if err := a.db.Create(&lines).Error; err != nil { + if err := tx.Create(&lines).Error; err != nil { + tx.Rollback() return err } lines = nil @@ -602,7 +612,8 @@ func (a *Adapter) SavePolicy(model model.Model) error { for _, rule := range ast.Policy { lines = append(lines, a.savePolicyLine(ptype, rule)) if len(lines) > flushEvery { - if err := a.db.Create(&lines).Error; err != nil { + if err := tx.Create(&lines).Error; err != nil { + tx.Rollback() return err } lines = nil @@ -610,11 +621,13 @@ func (a *Adapter) SavePolicy(model model.Model) error { } } if len(lines) > 0 { - if err := a.db.Create(&lines).Error; err != nil { + if err := tx.Create(&lines).Error; err != nil { + tx.Rollback() return err } } + tx.Commit() return nil } From bda8c8916735aac71dc4270e50b70b83a074035d Mon Sep 17 00:00:00 2001 From: weloe <1345895607@qq.com> Date: Thu, 13 Apr 2023 03:01:55 +0800 Subject: [PATCH 2/3] fix: fix use error db to execute transaction before --- adapter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adapter.go b/adapter.go index bfa2ae0..edd55dc 100755 --- a/adapter.go +++ b/adapter.go @@ -580,7 +580,7 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { // SavePolicy saves policy to database. func (a *Adapter) SavePolicy(model model.Model) error { var err error - tx := a.db.Begin() + tx := a.db.Clauses(dbresolver.Write).Begin() if a.db.Config.Name() == sqlite.DriverName { err = tx.Exec(fmt.Sprintf("delete from %s", a.getFullTableName())).Error From 91b097e4c4102b92b583087246b5763c4a45d5c7 Mon Sep 17 00:00:00 2001 From: weloe <1345895607@qq.com> Date: Thu, 13 Apr 2023 03:32:20 +0800 Subject: [PATCH 3/3] update: fix return error --- adapter.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/adapter.go b/adapter.go index edd55dc..48cc07e 100755 --- a/adapter.go +++ b/adapter.go @@ -627,8 +627,8 @@ func (a *Adapter) SavePolicy(model model.Model) error { } } - tx.Commit() - return nil + err = tx.Commit().Error + return err } // AddPolicy adds a policy rule to the storage.