package postgres import ( "context" "database/sql" "fmt" "sync" "github.com/lib/pq" ) type contextKey string const DBTxHolderContextKey contextKey = "txholder" type conn interface { Commit() error Rollback() error QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) } type TxConn struct { *sql.Tx } type DBTxHolder struct { conn conn mu sync.Mutex } func (db *DBTxHolder) Conn() (conn, error) { if db.conn == nil { return nil, fmt.Errorf("Conn() called before BeginTx()") } return db.conn, nil } func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) { db, ok := ctx.Value(DBTxHolderContextKey).(*DBTxHolder) if !ok { db = &DBTxHolder{ conn: nil, } ctx = context.WithValue(ctx, DBTxHolderContextKey, db) } return db, ctx } func (db *DBTxHolder) BeginTx(ctx context.Context, conn *sql.DB) (conn, error) { dbTransactionConn, err := db.Continue(ctx, conn) if err != nil { return nil, err } return dbTransactionConn, nil } func (db *DBTxHolder) Continue(ctx context.Context, conn *sql.DB) (conn, error) { //db.mu.Lock() //defer db.mu.Unlock() var err error if db.conn == nil { db.conn, err = db.txFactory(ctx, conn) if err != nil { return nil, err } } return db.conn, err } func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) { tx, bErr := rConn.BeginTx(ctx, nil) if bErr != nil { return nil, bErr } return &TxConn{Tx: tx}, nil } func (db *DBTxHolder) Commit() error { //db.mu.Lock() //defer db.mu.Unlock() if db.conn == nil { return fmt.Errorf("no active transaction") } err := db.conn.Commit() db.conn = nil return err } func (db *DBTxHolder) Rollback() error { //db.mu.Lock() //defer db.mu.Unlock() if db.conn == nil { return fmt.Errorf("no active transaction") } err := db.conn.Rollback() db.conn = nil return err } func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error { //db.mu.Lock() //defer db.mu.Unlock() if db.conn == nil { return fmt.Errorf("no active transaction") } quoted := pq.QuoteIdentifier(key) _, err := db.conn.ExecContext(ctx, "SAVEPOINT "+quoted) return err } func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error { //db.mu.Lock() //defer db.mu.Unlock() if db.conn == nil { return fmt.Errorf("no active transaction") } quoted := pq.QuoteIdentifier(key) _, err := db.conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+quoted) if err != nil { return err } return nil } func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error { //db.mu.Lock() //defer db.mu.Unlock() if db.conn == nil { return fmt.Errorf("no active transaction") } quoted := pq.QuoteIdentifier(key) _, err := db.conn.ExecContext(ctx, "RELEASE SAVEPOINT "+quoted) if err != nil { return err } return nil } func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) { txStmt := tx.StmtContext(ctx, stmt) return txStmt.QueryContext(ctx, args...) } func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row { txStmt := tx.StmtContext(ctx, stmt) return txStmt.QueryRowContext(ctx, args...) } func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) { txStmt := tx.StmtContext(ctx, stmt) result, err := txStmt.ExecContext(ctx, args...) if err != nil { return nil, err } return result, nil }