// File: niki/pkg/database/postgres/transaction_handler.go (179 lines, 3.8 KiB, Go)

package postgres
import (
"context"
"database/sql"
"fmt"
"sync"
"github.com/lib/pq"
)
// contextKey is an unexported key type for values this package stores in a
// context.Context, so its keys are distinct from keys defined by other packages.
type contextKey string

// DBTxHolderContextKey is the context key under which
// GetDBTxHolderFromContextOrNew stores the request's *DBTxHolder.
const DBTxHolderContextKey = contextKey("txholder")
// conn is the set of operations DBTxHolder exposes for its active
// transaction. *TxConn implements it: the lifecycle and ad-hoc SQL methods are
// promoted from the embedded *sql.Tx, and the Stmt* methods are defined on
// *TxConn.
type conn interface {
	// Transaction lifecycle.
	Commit() error
	Rollback() error

	// Ad-hoc SQL executed inside the transaction.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)

	// Execution of an existing prepared statement inside the transaction.
	StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error)
	StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row
	StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error)
}
// TxConn adapts a *sql.Tx to the conn interface. Commit, Rollback,
// QueryContext, QueryRowContext and ExecContext come from the embedded
// *sql.Tx; StmtQueryContext, StmtQueryRowContext and StmtExecContext are
// declared on *TxConn.
type TxConn struct{ *sql.Tx }
// DBTxHolder owns at most one transaction at a time. The transaction is
// started lazily by BeginTx/Continue and cleared by Commit/Rollback.
type DBTxHolder struct {
	// conn is the active transaction, or nil when none has been started or
	// the previous one was committed or rolled back.
	conn conn
	// mu is intended to guard conn, but every Lock/Unlock in the methods of
	// this file is commented out. As written, access to conn is not
	// synchronized.
	mu sync.Mutex
}
// Conn returns the holder's active transaction. It returns an error if no
// transaction is active, either because BeginTx/Continue has not run yet or
// because the last one was committed or rolled back.
func (db *DBTxHolder) Conn() (conn, error) {
	if active := db.conn; active != nil {
		return active, nil
	}
	return nil, fmt.Errorf("Conn() called before BeginTx()")
}
// GetDBTxHolderFromContextOrNew returns the *DBTxHolder stored in ctx under
// DBTxHolderContextKey, together with ctx unchanged.
//
// If ctx holds no holder, or holds a nil *DBTxHolder, it creates an empty
// holder with no active transaction. It returns that holder together with a
// derived context that carries it. Callers should continue with the returned
// context so later lookups find the same holder.
func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) {
	if db, ok := ctx.Value(DBTxHolderContextKey).(*DBTxHolder); ok && db != nil {
		return db, ctx
	}
	// A typed nil *DBTxHolder passes the type assertion above. Treat it as
	// absent so callers never receive a nil holder: every method reads
	// db.conn and would panic on a nil receiver.
	db := &DBTxHolder{}
	return db, context.WithValue(ctx, DBTxHolderContextKey, db)
}
// BeginTx returns the holder's transaction. It starts one on sqlDB only when
// none is active; an already-open transaction is reused rather than a second
// one being started. It behaves exactly like Continue.
func (db *DBTxHolder) BeginTx(ctx context.Context, sqlDB *sql.DB) (conn, error) {
	return db.Continue(ctx, sqlDB)
}
// Continue returns the holder's active transaction. If none is active, it
// starts a new one on sqlDB and stores it in the holder.
//
// If starting the transaction fails, the error is returned and the holder is
// left with no active transaction, so a later call can retry.
func (db *DBTxHolder) Continue(ctx context.Context, sqlDB *sql.DB) (conn, error) {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn != nil {
		return db.conn, nil
	}
	tx, err := db.txFactory(ctx, sqlDB)
	if err != nil {
		// Do not assign tx to db.conn here. A nil *TxConn stored in the conn
		// interface is a non-nil interface value. Later calls would then
		// treat it as an active transaction and hand out a nil pointer.
		return nil, err
	}
	db.conn = tx
	return db.conn, nil
}
// txFactory starts a new transaction on rConn and wraps it in a *TxConn. It
// passes nil transaction options, so database/sql's defaults apply.
func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) {
	var opts *sql.TxOptions // nil: default options
	sqlTx, err := rConn.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	return &TxConn{sqlTx}, nil
}
// Commit commits the active transaction and then clears it from the holder.
// The holder is cleared even when the commit itself fails, and the commit
// error is returned. It returns an error if no transaction is active.
func (db *DBTxHolder) Commit() error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	active := db.conn
	if active == nil {
		return fmt.Errorf("no active transaction")
	}
	commitErr := active.Commit()
	db.conn = nil
	return commitErr
}
// Rollback rolls back the active transaction and then clears it from the
// holder. The holder is cleared even when the rollback itself fails, and the
// rollback error is returned. It returns an error if no transaction is active.
func (db *DBTxHolder) Rollback() error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	active := db.conn
	if active == nil {
		return fmt.Errorf("no active transaction")
	}
	rollbackErr := active.Rollback()
	db.conn = nil
	return rollbackErr
}
// SavePoint creates a savepoint named key inside the active transaction.
// key is quoted with pq.QuoteIdentifier, so it is sent as an identifier
// rather than raw SQL. It returns an error if no transaction is active.
func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	query := "SAVEPOINT " + pq.QuoteIdentifier(key)
	_, execErr := db.conn.ExecContext(ctx, query)
	return execErr
}
// RollbackSavePoint rolls the active transaction back to the savepoint named
// key. key is quoted with pq.QuoteIdentifier before it is placed in the
// statement. It returns an error if no transaction is active.
func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	query := "ROLLBACK TO SAVEPOINT " + pq.QuoteIdentifier(key)
	_, execErr := db.conn.ExecContext(ctx, query)
	return execErr
}
// ReleaseSavePoint releases the savepoint named key in the active
// transaction. key is quoted with pq.QuoteIdentifier before it is placed in
// the statement. It returns an error if no transaction is active.
func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	query := "RELEASE SAVEPOINT " + pq.QuoteIdentifier(key)
	_, execErr := db.conn.ExecContext(ctx, query)
	return execErr
}
// StmtQueryContext binds the prepared statement stmt to this transaction
// using sql.Tx.StmtContext, then runs it as a query with args.
func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
	return tx.StmtContext(ctx, stmt).QueryContext(ctx, args...)
}
// StmtQueryRowContext binds the prepared statement stmt to this transaction
// using sql.Tx.StmtContext, then runs it as a single-row query with args.
func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
	return tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
}
// StmtExecContext binds the prepared statement stmt to this transaction
// using sql.Tx.StmtContext, then executes it with args. On error it returns a
// nil sql.Result.
func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
	bound := tx.StmtContext(ctx, stmt)
	res, execErr := bound.ExecContext(ctx, args...)
	if execErr != nil {
		return nil, execErr
	}
	return res, nil
}