// Package transaction provides a gorm-backed transaction Context, the UseTrans helper and PaginationAndCount.
package transaction

import (

type Context struct {
	beginTransFlag bool
	db             *gorm.DB
	session        *gorm.DB
	lock           sync.Mutex

func (transactionContext *Context) Begin() error {
	defer transactionContext.lock.Unlock()
	transactionContext.beginTransFlag = true
	tx := transactionContext.db.Begin()
	transactionContext.session = tx
	return nil

func (transactionContext *Context) Commit() error {
	defer transactionContext.lock.Unlock()
	if !transactionContext.beginTransFlag {
		return nil
	tx := transactionContext.session.Commit()
	return tx.Error

func (transactionContext *Context) Rollback() error {
	defer transactionContext.lock.Unlock()
	if !transactionContext.beginTransFlag {
		return nil
	tx := transactionContext.session.Rollback()
	return tx.Error

func (transactionContext *Context) DB() *gorm.DB {
	if transactionContext.beginTransFlag && transactionContext.session != nil {
		return transactionContext.session
	return transactionContext.db

func NewTransactionContext(db *gorm.DB) *Context {
	return &Context{
		db: db,

type Conn interface {
	Begin() error
	Commit() error
	Rollback() error
	DB() *gorm.DB

// UseTrans when beginTrans is true , it will begin a new transaction
// to execute the function, recover when  panic happen
func UseTrans(ctx context.Context,
	db *gorm.DB,
	fn func(context.Context, Conn) error, beginTrans bool) (err error) {
	var tx Conn
	tx = NewTransactionContext(db)
	if beginTrans {
		if err = tx.Begin(); err != nil {
	defer func() {
		if p := recover(); p != nil {
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
			} else {
				err = fmt.Errorf("recoveer from %#v", p)
		} else if err != nil {
			if e := tx.Rollback(); e != nil {
				err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
		} else {
			err = tx.Commit()

	return fn(ctx, tx)

func PaginationAndCount(ctx context.Context, tx *gorm.DB, params map[string]interface{}, dst interface{}) (int64, *gorm.DB) {
	var total int64
	// 只返回数量
	if v, ok := params["countOnly"]; ok && v.(bool) {
		tx = tx.Count(&total)
		return total, tx
	// 只返回记录
	if v, ok := params["findOnly"]; ok && v.(bool) {
		if v, ok := params["offset"]; ok {
		if v, ok := params["limit"]; ok {
		if tx = tx.Find(dst); tx.Error != nil {
			return 0, tx
		return total, tx
	tx = tx.Count(&total)
	if tx.Error != nil {
		return total, tx
	if v, ok := params["offset"]; ok {
	if v, ok := params["limit"]; ok {
	if tx = tx.Find(dst); tx.Error != nil {
		return 0, tx
	return total, tx