interface

package repository

import "context"

// Transaction abstracts transaction control so that use cases do not depend
// on a concrete database driver and can be tested with a simple fake.
type Transaction interface {
	// Begin starts a transaction and returns a context carrying it. Work that
	// should join the transaction must use txCtx.
	Begin(ctx context.Context) (txCtx context.Context, err error)
	// Commit commits the transaction carried by ctx.
	Commit(ctx context.Context) (err error)
	// Rollback aborts the transaction carried by ctx.
	Rollback(ctx context.Context) (err error)
}

usecase

// InTx runs fn inside a transaction started with tx.Begin.
//
// fn receives the context returned by Begin. Repositories must use that
// context so their queries join the transaction. When fn returns nil,
// tx.Commit is called. When fn returns an error, fn panics, or Commit fails,
// tx.Rollback is called and the error is returned. A panic in fn is
// recovered and converted into an error prefixed with "panic recover:"; it is
// not re-raised.
//
// If Begin fails, its error is returned and neither Commit nor Rollback is
// called.
func InTx(ctx context.Context, tx repository.Transaction, fn func(ctx context.Context) error) (err error) {
	tCtx, err := tx.Begin(ctx)
	if err != nil {
		return err
	}
	defer func() {
		if re := recover(); re != nil {
			if err == nil {
				err = fmt.Errorf("panic recover: %v", re)
			} else {
				err = fmt.Errorf("panic recover: %v: %w", re, err)
			}
		}
		if err != nil {
			// Roll back with tCtx, not ctx: implementations that store the
			// transaction in the context (e.g. TransactionMySQL) can only find
			// it in the context returned by Begin.
			// A rollback failure is handed to util.HandleError and does not
			// replace err.
			// NOTE(review): after a failed Commit, database/sql returns
			// sql.ErrTxDone from Rollback. Confirm HandleError tolerates that.
			rbErr := tx.Rollback(tCtx)
			util.HandleError(rbErr)
			util.LoggingDebug(tCtx, "Transaction.Rollback", logrus.Fields{})
		}
	}()
	if err = fn(tCtx); err != nil {
		return err
	}
	if err = tx.Commit(tCtx); err != nil {
		return err
	}
	util.LoggingDebug(tCtx, "Transaction.Commit", logrus.Fields{})
	return nil
}

独自の interface としてトランザクションを定義しておくことで, テストを書く際にトランザクションを簡単にモック化できる

// tmpTransaction is a test double for repository.Transaction. It never fails
// and only records, in order, which transaction methods were called.
type tmpTransaction struct {
	stateList []tmpTransactionState
}

// record appends state to the call log.
func (tx *tmpTransaction) record(state tmpTransactionState) {
	tx.stateList = append(tx.stateList, state)
}

// Begin records begin and hands back ctx unchanged.
func (tx *tmpTransaction) Begin(ctx context.Context) (context.Context, error) {
	tx.record(begin)
	return ctx, nil
}

// Commit records commit; it always succeeds.
func (tx *tmpTransaction) Commit(ctx context.Context) error {
	tx.record(commit)
	return nil
}

// Rollback records rollback; it always succeeds.
func (tx *tmpTransaction) Rollback(ctx context.Context) error {
	tx.record(rollback)
	return nil
}

func TestInTx(t *testing.T) {
	ctx := context.Background()
	t.Run("commit", func(t *testing.T) {
		tx := &tmpTransaction{
			stateList: []tmpTransactionState{},
		}
		err := usecase.InTx(ctx, tx, func(ctx context.Context) error {
			return nil
		})
		expected := []tmpTransactionState{
			begin,
			commit,
		}
		assert.Nil(t, err)
		assert.Equal(t, expected, tx.stateList)
	})
	t.Run("error 発生で rollback", func(t *testing.T) {
		tx := &tmpTransaction{
			stateList: []tmpTransactionState{},
		}
		err := usecase.InTx(ctx, tx, func(ctx context.Context) error {
			return errors.New("test error")
		})
		expected := []tmpTransactionState{
			begin,
			rollback,
		}
		assert.EqualError(t, err, "test error")
		assert.Equal(t, expected, tx.stateList)
	})
	t.Run("panic 発生で rollback", func(t *testing.T) {
		tx := &tmpTransaction{
			stateList: []tmpTransactionState{},
		}
		err := usecase.InTx(ctx, tx, func(ctx context.Context) error {
			panic("test error")
		})
		expected := []tmpTransactionState{
			begin,
			rollback,
		}
		assert.EqualError(t, err, "panic recover: test error")
		assert.Equal(t, expected, tx.stateList)
	})
}

肝心の sqlboiler でのトランザクション実装

package mysql

import (
	"context"
	"database/sql"
	"errors"
)

// TransactionMySQL implements transaction control on top of the package's
// *sql.DB, storing the active *sql.Tx in the context returned by Begin.
type TransactionMySQL struct{}

// NewTransactionMySQL returns a ready-to-use TransactionMySQL. It holds no
// state of its own.
func NewTransactionMySQL() *TransactionMySQL {
	return new(TransactionMySQL)
}

// Begin opens a transaction on the package-level db at
// sql.LevelSerializable isolation and returns a child of ctx that carries it.
// On failure it returns a nil context together with the driver error.
func (t *TransactionMySQL) Begin(ctx context.Context) (context.Context, error) {
	opts := &sql.TxOptions{Isolation: sql.LevelSerializable}
	sqlTx, err := db.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	txCtx := setTransaction(ctx, sqlTx)
	return txCtx, nil
}

// Commit commits the *sql.Tx that Begin stored in ctx.
//
// ctx must be the context returned by Begin, or one derived from it. If ctx
// carries no transaction, an error is returned instead of panicking on the
// type assertion in getTransaction.
func (t *TransactionMySQL) Commit(ctx context.Context) error {
	if !isStartedTransaction(ctx) {
		return errors.New("mysql: commit: no transaction in context")
	}
	return getTransaction(ctx).Commit()
}

// Rollback aborts the *sql.Tx that Begin stored in ctx.
//
// ctx must be the context returned by Begin, or one derived from it. If ctx
// carries no transaction, an error is returned instead of panicking on the
// type assertion in getTransaction.
func (t *TransactionMySQL) Rollback(ctx context.Context) error {
	if !isStartedTransaction(ctx) {
		return errors.New("mysql: rollback: no transaction in context")
	}
	return getTransaction(ctx).Rollback()
}

// transactionKey is the context key under which the active *sql.Tx is kept.
// Being an unexported struct type, it cannot collide with keys defined in
// other packages.
type transactionKey struct{}

// setTransaction returns a child of ctx carrying tx; ctx itself is unchanged.
func setTransaction(ctx context.Context, tx *sql.Tx) context.Context {
	key := transactionKey{}
	return context.WithValue(ctx, key, tx)
}

// getTransaction returns the *sql.Tx stored by setTransaction. It panics when
// ctx carries no transaction, so guard with isStartedTransaction if unsure.
func getTransaction(ctx context.Context) *sql.Tx {
	stored := ctx.Value(transactionKey{})
	tx := stored.(*sql.Tx)
	return tx
}

// isStartedTransaction reports whether ctx carries a transaction value.
func isStartedTransaction(ctx context.Context) bool {
	return ctx.Value(transactionKey{}) != nil
}

sqlmock を使ったテストで, 実際に BEGIN → COMMIT (または ROLLBACK) のクエリが発行されていることを確認できる

// TestTransaction checks, via sqlmock, that TransactionMySQL issues
// BEGIN followed by COMMIT or ROLLBACK on the underlying connection.
func TestTransaction(t *testing.T) {
	db, mock, err := sqlmock.New()
	util.HandleError(err)
	defer db.Close()

	mysql.SetDB(db)

	t.Run("commit", func(t *testing.T) {
		mock.ExpectBegin()
		mock.ExpectCommit()

		tx := mysql.NewTransactionMySQL()
		// Local err: do not share error state with other subtests.
		ctx, err := tx.Begin(context.Background())
		if !assert.NoError(t, err) {
			return
		}
		// Commit's own error was previously ignored.
		assert.NoError(t, tx.Commit(ctx))
		assert.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("rollback", func(t *testing.T) {
		mock.ExpectBegin()
		mock.ExpectRollback()

		tx := mysql.NewTransactionMySQL()
		ctx, err := tx.Begin(context.Background())
		if !assert.NoError(t, err) {
			return
		}
		// Rollback's own error was previously ignored.
		assert.NoError(t, tx.Rollback(ctx))
		assert.NoError(t, mock.ExpectationsWereMet())
	})
}

他のインターフェース実装例