niki/pkg/database/postgres/transaction_handler.go

278 lines
6.5 KiB
Go

package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
"github.com/lib/pq"
)
// contextKey is an unexported key type, so context values stored by this
// package cannot collide with keys defined by other packages.
type contextKey string

const (
	// DBTxHolderContextKey is the context key under which a *DBTxHolder is stored.
	DBTxHolderContextKey contextKey = "txholder"
)
// conn is the set of operations DBTxHolder needs from the transaction it
// holds. *TxConn (produced by txFactory) is the implementation stored in
// DBTxHolder by this file.
type conn interface {
	// Transaction lifecycle.
	Commit() error
	Rollback() error

	// Ad-hoc SQL text.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row

	// Variants that execute an already-prepared *sql.Stmt.
	StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error)
	StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error)
	StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row
}
// TxConn wraps a *sql.Tx. Embedding promotes the transaction's own methods
// (Commit, Rollback, ExecContext, QueryContext, QueryRowContext, StmtContext,
// ...). The Stmt* methods defined on *TxConn in this file complete the conn
// interface.
type TxConn struct {
*sql.Tx
}
// DBTxHolder holds at most one open transaction. It is shared through a
// context.Context (see GetDBTxHolderFromContextOrNew). BeginTx/Continue
// start a transaction only when none is held. Commit and Rollback reset
// conn to nil afterwards.
type DBTxHolder struct {
// conn is the active transaction, or nil when none is open.
conn conn
// mu presumably exists to serialize access to conn.
// NOTE(review): every Lock/Unlock call on mu in this file is commented out,
// so DBTxHolder is currently not safe for concurrent use — confirm intent.
mu sync.Mutex
}
// Conn returns the transaction currently held by db. When no transaction
// has been started yet, it reports an error instead of a nil connection.
func (db *DBTxHolder) Conn() (conn, error) {
	if active := db.conn; active != nil {
		return active, nil
	}
	return nil, fmt.Errorf("Conn() called before BeginTx()")
}
// GetDBTxHolderFromContextOrNew returns the *DBTxHolder stored in ctx under
// DBTxHolderContextKey, together with ctx unchanged. If ctx carries no
// holder, it creates an empty one and returns it with a derived context
// that stores it. Callers must keep using the returned context for later
// lookups to find the same holder.
func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) {
	if existing, found := ctx.Value(DBTxHolderContextKey).(*DBTxHolder); found {
		return existing, ctx
	}
	fresh := &DBTxHolder{}
	return fresh, context.WithValue(ctx, DBTxHolderContextKey, fresh)
}
// BeginTx ensures db holds a transaction on conn and returns it. It
// delegates to Continue. If a transaction is already held, that transaction
// is returned; no nested one is started. On failure it returns a nil conn
// and the error.
func (db *DBTxHolder) BeginTx(ctx context.Context, conn *sql.DB) (conn, error) {
	active, err := db.Continue(ctx, conn)
	if err == nil {
		return active, nil
	}
	return nil, err
}
// Continue returns the transaction already held by db. If none is held, it
// begins one on conn via txFactory and stores it in db.
//
// If beginning the transaction fails, db is left without a transaction, so
// a later call can retry cleanly.
func (db *DBTxHolder) Continue(ctx context.Context, conn *sql.DB) (conn, error) {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn != nil {
		return db.conn, nil
	}

	tx, err := db.txFactory(ctx, conn)
	if err != nil {
		// Do not assign the result to db.conn here: txFactory returns a
		// typed-nil *TxConn on failure. Storing it would make the interface
		// non-nil, and later Continue/Commit/Rollback calls would dereference
		// a nil pointer.
		return nil, err
	}

	db.conn = tx
	return db.conn, nil
}
// txFactory begins a new transaction on rConn with default options and
// wraps it in a *TxConn. On failure it returns a nil *TxConn and the error
// from sql.DB.BeginTx.
func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) {
	sqlTx, beginErr := rConn.BeginTx(ctx, nil)
	if beginErr == nil {
		return &TxConn{Tx: sqlTx}, nil
	}
	return nil, beginErr
}
// Commit commits the held transaction and then clears it from db, so a
// later BeginTx/Continue starts a new one. The holder is cleared even when
// the commit itself fails. It returns an error if no transaction is active.
func (db *DBTxHolder) Commit() error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn != nil {
		commitErr := db.conn.Commit()
		db.conn = nil
		return commitErr
	}
	return fmt.Errorf("no active transaction")
}
// Rollback aborts the held transaction and then clears it from db, so a
// later BeginTx/Continue starts a new one. The holder is cleared even when
// the rollback itself fails. It returns an error if no transaction is active.
func (db *DBTxHolder) Rollback() error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn != nil {
		rollbackErr := db.conn.Rollback()
		db.conn = nil
		return rollbackErr
	}
	return fmt.Errorf("no active transaction")
}
// SavePoint issues SAVEPOINT <key> inside the active transaction. key is
// quoted as an SQL identifier with pq.QuoteIdentifier rather than spliced
// in raw. It returns an error if no transaction is active.
func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	sqlText := "SAVEPOINT " + pq.QuoteIdentifier(key)
	if _, execErr := db.conn.ExecContext(ctx, sqlText); execErr != nil {
		return execErr
	}
	return nil
}
// RollbackSavePoint issues ROLLBACK TO SAVEPOINT <key> inside the active
// transaction. key is quoted as an SQL identifier with pq.QuoteIdentifier.
// It returns an error if no transaction is active.
func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	_, execErr := db.conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+pq.QuoteIdentifier(key))
	return execErr
}
// ReleaseSavePoint issues RELEASE SAVEPOINT <key> inside the active
// transaction. key is quoted as an SQL identifier with pq.QuoteIdentifier.
// It returns an error if no transaction is active.
func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error {
	//db.mu.Lock()
	//defer db.mu.Unlock()
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	_, execErr := db.conn.ExecContext(ctx, "RELEASE SAVEPOINT "+pq.QuoteIdentifier(key))
	return execErr
}
// StmtQueryContext rebinds the prepared stmt to this transaction (via
// sql.Tx.StmtContext) and runs it as a query with args.
func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
	return tx.StmtContext(ctx, stmt).QueryContext(ctx, args...)
}
// StmtQueryRowContext rebinds the prepared stmt to this transaction (via
// sql.Tx.StmtContext) and runs it as a single-row query with args.
func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
	return tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
}
// StmtExecContext rebinds the prepared stmt to this transaction (via
// sql.Tx.StmtContext) and executes it with args. On failure the result is
// always nil.
func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
	res, execErr := tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
	if execErr != nil {
		return nil, execErr
	}
	return res, nil
}
// ---- Generic helpers: run prepared statements inside an existing *sql.Tx ----
// TXInstantQueryContext prepares query through conn.PrepareStatement under
// stmtKey and binds the statement to txConn. It then runs it with args and
// maps every returned row through scanner.
//
// It returns the scanned items in row order; the slice is nil when the
// query yields no rows. Every failure is returned as a rich error tagged
// with op, with the underlying cause attached via WithErr.
func TXInstantQueryContext[T AllowUseDBGenericFunc](ctx context.Context, txConn *sql.Tx, stmtKey statementKey, query string, conn *DB, scanner ScannerFunc[T], args ...any) ([]T, error) {
	const op = richerror.Op("postgres.TXInstantQueryContext")

	stmt, err := conn.PrepareStatement(ctx, stmtKey, query)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	// StmtContext returns a transaction-specific statement. database/sql
	// closes it when txConn is committed or rolled back.
	rows, err := txConn.StmtContext(ctx, stmt).QueryContext(ctx, args...)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithMessage(errmsg.ErrorMsgFailedQuery)
	}
	defer rows.Close()

	var itemsList []T
	for rows.Next() {
		item, sErr := scanner(rows)
		if sErr != nil {
			return nil, richerror.New(op).WithErr(sErr).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
		}
		itemsList = append(itemsList, item)
	}
	// rows.Next returning false may hide an iteration error; surface it.
	if rErr := rows.Err(); rErr != nil {
		return nil, richerror.New(op).WithErr(rErr).WithMessage(errmsg.ErrorMsgFailedQuery)
	}
	return itemsList, nil
}
// TXInstantQueryRowContext prepares query through conn.PrepareStatement
// under stmtKey and binds the statement to txConn. It then runs it as a
// single-row query with args and decodes the row with scanner.
//
// Errors are rich errors tagged with op, with the underlying cause attached:
//   - prepare failure: ErrorMsgFailedQuery
//   - sql.ErrNoRows from scanner: KindNotFound, ErrorMsgCantScanQueryResult
//   - any other scan failure: KindUnexpected, ErrorMsgCantScanQueryResult
//   - row.Err failure: KindUnexpected, ErrorMsgFailedQuery
//
// On any error the returned item is the zero value of T.
func TXInstantQueryRowContext[T AllowUseDBGenericFunc](ctx context.Context, txConn *sql.Tx, stmtKey statementKey, query string, conn *DB, scanner ScannerFunc[T], args ...any) (item T, err error) {
	const op = richerror.Op("postgres.TXInstantQueryRowContext")
	var zero T

	stmt, pErr := conn.PrepareStatement(ctx, stmtKey, query)
	if pErr != nil {
		return zero, richerror.New(op).WithErr(pErr).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	row := txConn.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
	scanned, scErr := scanner(row)
	if scErr != nil {
		if errors.Is(scErr, sql.ErrNoRows) {
			return zero, richerror.New(op).WithErr(scErr).WithKind(richerror.KindNotFound).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
		}
		return zero, richerror.New(op).WithErr(scErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
	}
	if rErr := row.Err(); rErr != nil {
		return zero, richerror.New(op).WithErr(rErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}
	return scanned, nil
}
// TXInstantExecContext prepares query through conn.PrepareStatement under
// stmtKey, binds the statement to txConn, and executes it with args.
//
// T does not appear in the parameters, so it cannot be inferred and callers
// instantiate it explicitly. It is kept for backward compatibility.
// Failures are rich errors tagged with op. They carry the underlying cause
// and the same ErrorMsgFailedQuery message the sibling helpers use.
func TXInstantExecContext[T AllowUseDBGenericFunc](ctx context.Context, txConn *sql.Tx, stmtKey statementKey, query string, conn *DB, args ...any) (sql.Result, error) {
	const op = richerror.Op("postgres.TXInstantExecContext")

	stmt, err := conn.PrepareStatement(ctx, stmtKey, query)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	// The transaction-specific statement is closed by database/sql when
	// txConn is committed or rolled back.
	result, err := txConn.StmtContext(ctx, stmt).ExecContext(ctx, args...)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithMessage(errmsg.ErrorMsgFailedQuery)
	}
	return result, nil
}