transaction.go
3.0 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package transaction
import (
"context"
"fmt"
"gorm.io/gorm"
"sync"
)
// Context pairs a gorm handle with an optional transaction begun on it.
// It embeds a sync.Mutex, so it must not be copied; always use *Context.
type Context struct {
	// beginTransFlag is set by Begin to mark that a transaction was started.
	beginTransFlag bool
	// db is the handle the Context was created with; session holds the
	// handle returned by db.Begin().
	db, session *gorm.DB
	// lock serializes Begin, Commit and Rollback.
	lock sync.Mutex
}
// Begin starts a transaction on the underlying db and stores it as the
// session returned by DB. If gorm fails to begin the transaction, its error
// is returned and the Context stays in the non-transactional state.
//
// NOTE(review): calling Begin again before Commit/Rollback replaces the stored
// session; the earlier transaction is then never finished by this Context.
func (transactionContext *Context) Begin() error {
	transactionContext.lock.Lock()
	defer transactionContext.lock.Unlock()
	tx := transactionContext.db.Begin()
	if tx.Error != nil {
		// Do not flag the Context as transactional: Commit/Rollback would
		// otherwise run against a session that never started.
		return tx.Error
	}
	transactionContext.session = tx
	transactionContext.beginTransFlag = true
	return nil
}
// Commit commits the transaction started by Begin and returns gorm's commit
// error. It is a no-op returning nil when no transaction is active.
//
// Afterwards the Context is back in the non-transactional state: DB returns
// the underlying db again, and a repeated Commit or Rollback is a no-op
// instead of failing on an already-finished transaction.
func (transactionContext *Context) Commit() error {
	transactionContext.lock.Lock()
	defer transactionContext.lock.Unlock()
	if !transactionContext.beginTransFlag {
		return nil
	}
	err := transactionContext.session.Commit().Error
	// database/sql does not let a Tx be used again once Commit has been
	// called, so clear the state even if the commit reported an error.
	transactionContext.beginTransFlag = false
	transactionContext.session = nil
	return err
}
// Rollback aborts the transaction started by Begin and returns gorm's
// rollback error. It is a no-op returning nil when no transaction is active.
//
// Afterwards the Context is back in the non-transactional state: DB returns
// the underlying db again, and a repeated Commit or Rollback is a no-op
// instead of failing on an already-finished transaction.
func (transactionContext *Context) Rollback() error {
	transactionContext.lock.Lock()
	defer transactionContext.lock.Unlock()
	if !transactionContext.beginTransFlag {
		return nil
	}
	err := transactionContext.session.Rollback().Error
	// database/sql does not let a Tx be used again once Rollback has been
	// called, so clear the state even if the rollback reported an error.
	transactionContext.beginTransFlag = false
	transactionContext.session = nil
	return err
}
// DB returns the handle queries should run on. That is the transaction
// session when beginTransFlag is set and a session is stored; otherwise it is
// the db the Context was created with.
//
// The lock is taken because Begin, Commit and Rollback write these fields
// under it. Reading them unlocked from another goroutine is a data race.
func (transactionContext *Context) DB() *gorm.DB {
	transactionContext.lock.Lock()
	defer transactionContext.lock.Unlock()
	if transactionContext.beginTransFlag && transactionContext.session != nil {
		return transactionContext.session
	}
	return transactionContext.db
}
// NewTransactionContext wraps db in a Context that has not begun a
// transaction. Until Begin is called, DB returns db itself.
func NewTransactionContext(db *gorm.DB) *Context {
	c := new(Context)
	c.db = db
	return c
}
// Conn is the connection handle handed to UseTrans callbacks. With *Context
// as the implementation, queries issued through DB run inside the
// transaction once Begin has been called.
type Conn interface {
	// DB returns the gorm handle to issue queries on.
	DB() *gorm.DB
	// Begin starts a transaction.
	Begin() error
	// Commit commits the transaction started by Begin.
	Commit() error
	// Rollback aborts the transaction started by Begin.
	Rollback() error
}
func MustUseTrans(ctx context.Context,
db *gorm.DB,
fn func(context.Context, Conn) error) error {
return UseTrans(ctx, db, fn, true)
}
// UseTrans runs fn with a Conn built on db. When beginTrans is true, a new
// transaction is begun first; a Begin error is returned without calling fn.
//
// Once fn finishes, its outcome decides the result:
//   - fn panicked: the panic is recovered, the transaction is rolled back,
//     and an error describing the panic value (and any rollback failure) is
//     returned;
//   - fn returned an error: the transaction is rolled back and fn's error is
//     returned, annotated with the rollback error if the rollback fails;
//   - fn succeeded: the transaction is committed and the commit error, if
//     any, is returned.
//
// When beginTrans is false, UseTrans starts no transaction, so the
// Rollback/Commit calls above have nothing to finish.
func UseTrans(ctx context.Context,
	db *gorm.DB,
	fn func(context.Context, Conn) error, beginTrans bool) (err error) {
	var tx Conn = NewTransactionContext(db)
	if beginTrans {
		if err = tx.Begin(); err != nil {
			return err
		}
	}
	defer func() {
		if p := recover(); p != nil {
			// Turn the panic into an error so the caller sees a failed
			// transaction instead of a crashed goroutine.
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
			} else {
				err = fmt.Errorf("recover from %#v", p)
			}
			return
		}
		if err != nil {
			// Keep fn's error; only mention the rollback when it also failed.
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
			}
			return
		}
		err = tx.Commit()
	}()
	return fn(ctx, tx)
}
// PaginationAndCount counts and/or loads the rows selected by tx, as chosen
// by params:
//   - "countOnly" (bool) true: only Count runs; dst is not touched.
//   - "findOnly" (bool) true (and countOnly not set): only Find runs into dst,
//     and the returned total is 0.
//   - otherwise: Count runs first, then Find into dst.
//
// "offset" (int) and "limit" (int), when present, are applied to the Find
// query. Values of any other dynamic type panic on the type assertion.
//
// It returns the total and the *gorm.DB of the last executed statement;
// callers must check its Error field. The total is 0 when Find fails.
// ctx is not used by the current implementation.
func PaginationAndCount(ctx context.Context, tx *gorm.DB, params map[string]interface{}, dst interface{}) (int64, *gorm.DB) {
	var total int64
	// Count only.
	if v, ok := params["countOnly"]; ok && v.(bool) {
		tx = tx.Count(&total)
		return total, tx
	}
	// Count as well as load records, unless only records were requested.
	if v, ok := params["findOnly"]; !ok || !v.(bool) {
		if tx = tx.Count(&total); tx.Error != nil {
			return total, tx
		}
	}
	tx = applyOffsetLimit(tx, params)
	if tx = tx.Find(dst); tx.Error != nil {
		return 0, tx
	}
	return total, tx
}

// applyOffsetLimit applies params["offset"] and params["limit"] (ints) to tx
// and returns the resulting chain. The returned value must be used: gorm
// chain methods may return a new instance rather than modifying tx in place.
func applyOffsetLimit(tx *gorm.DB, params map[string]interface{}) *gorm.DB {
	if v, ok := params["offset"]; ok {
		tx = tx.Offset(v.(int))
	}
	if v, ok := params["limit"]; ok {
		tx = tx.Limit(v.(int))
	}
	return tx
}