forked from ebhomengo/niki
278 lines
6.5 KiB
Go
278 lines
6.5 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
|
|
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// contextKey is a package-private key type for context values. Using a
// distinct named type (instead of a bare string) prevents collisions with
// context keys defined by other packages.
type contextKey string

// DBTxHolderContextKey is the context key under which the shared
// *DBTxHolder is stored by GetDBTxHolderFromContextOrNew.
const DBTxHolderContextKey = contextKey("txholder")
|
|
|
|
// conn is the set of operations this package needs from an active database
// transaction. *TxConn (an embedded *sql.Tx plus the Stmt* helpers) is the
// implementation that txFactory creates.
type conn interface {
	// Transaction control.
	Commit() error
	Rollback() error

	// Ad-hoc SQL executed inside the transaction.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row

	// Already-prepared statements; TxConn's implementations rebind the
	// statement to the transaction via (*sql.Tx).StmtContext.
	StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error)
	StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error)
	StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row
}
|
|
|
|
// TxConn adapts *sql.Tx to the package's conn interface. The embedded Tx
// supplies Commit, Rollback and the *Context query methods; TxConn adds the
// Stmt* helpers that run a prepared *sql.Stmt within this transaction.
type TxConn struct{ *sql.Tx }
|
|
|
|
type DBTxHolder struct {
|
|
conn conn
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (db *DBTxHolder) Conn() (conn, error) {
|
|
if db.conn == nil {
|
|
return nil, fmt.Errorf("Conn() called before BeginTx()")
|
|
}
|
|
return db.conn, nil
|
|
}
|
|
|
|
func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) {
|
|
db, ok := ctx.Value(DBTxHolderContextKey).(*DBTxHolder)
|
|
if !ok {
|
|
db = &DBTxHolder{
|
|
conn: nil,
|
|
}
|
|
ctx = context.WithValue(ctx, DBTxHolderContextKey, db)
|
|
}
|
|
|
|
return db, ctx
|
|
}
|
|
|
|
func (db *DBTxHolder) BeginTx(ctx context.Context, conn *sql.DB) (conn, error) {
|
|
dbTransactionConn, err := db.Continue(ctx, conn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return dbTransactionConn, nil
|
|
}
|
|
|
|
func (db *DBTxHolder) Continue(ctx context.Context, conn *sql.DB) (conn, error) {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
|
|
var err error
|
|
if db.conn == nil {
|
|
|
|
db.conn, err = db.txFactory(ctx, conn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
}
|
|
|
|
return db.conn, err
|
|
}
|
|
|
|
func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) {
|
|
|
|
tx, bErr := rConn.BeginTx(ctx, nil)
|
|
if bErr != nil {
|
|
|
|
return nil, bErr
|
|
}
|
|
|
|
return &TxConn{Tx: tx}, nil
|
|
}
|
|
|
|
func (db *DBTxHolder) Commit() error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
err := db.conn.Commit()
|
|
db.conn = nil
|
|
return err
|
|
}
|
|
|
|
func (db *DBTxHolder) Rollback() error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
err := db.conn.Rollback()
|
|
db.conn = nil
|
|
|
|
return err
|
|
}
|
|
|
|
func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
quoted := pq.QuoteIdentifier(key)
|
|
_, err := db.conn.ExecContext(ctx, "SAVEPOINT "+quoted)
|
|
return err
|
|
}
|
|
|
|
func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
|
|
quoted := pq.QuoteIdentifier(key)
|
|
|
|
_, err := db.conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+quoted)
|
|
|
|
if err != nil {
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
quoted := pq.QuoteIdentifier(key)
|
|
_, err := db.conn.ExecContext(ctx, "RELEASE SAVEPOINT "+quoted)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
|
|
txStmt := tx.StmtContext(ctx, stmt)
|
|
return txStmt.QueryContext(ctx, args...)
|
|
}
|
|
|
|
func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
|
|
txStmt := tx.StmtContext(ctx, stmt)
|
|
return txStmt.QueryRowContext(ctx, args...)
|
|
}
|
|
|
|
func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
|
|
txStmt := tx.StmtContext(ctx, stmt)
|
|
result, err := txStmt.ExecContext(ctx, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ----- Generic query helpers: run a prepared statement inside an existing *sql.Tx -----
|
|
|
|
func TXInstantQueryContext[T AllowUseDBGenericFunc](ctx context.Context, txConn *sql.Tx, stmtKey statementKey, query string, conn *DB, scanner ScannerFunc[T], args ...any) ([]T, error) {
|
|
const op = richerror.Op("postgres.TXInstantQueryContext")
|
|
|
|
stmt, err := conn.PrepareStatement(ctx, stmtKey, query)
|
|
|
|
if err != nil {
|
|
return nil, richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery)
|
|
}
|
|
|
|
txStmt := txConn.StmtContext(ctx, stmt)
|
|
|
|
rows, qErr := txStmt.QueryContext(ctx, args...)
|
|
|
|
if qErr != nil {
|
|
return nil, richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var itemsList []T
|
|
|
|
for rows.Next() {
|
|
|
|
item, sErr := scanner(rows)
|
|
if sErr != nil {
|
|
return nil, richerror.New(op).WithErr(sErr).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
|
|
}
|
|
|
|
itemsList = append(itemsList, item)
|
|
|
|
}
|
|
|
|
if rErr := rows.Err(); rErr != nil {
|
|
return nil, richerror.New(op).WithErr(rErr).WithMessage(errmsg.ErrorMsgFailedQuery)
|
|
|
|
}
|
|
|
|
return itemsList, nil
|
|
|
|
}
|
|
|
|
func TXInstantQueryRowContext[T AllowUseDBGenericFunc](ctx context.Context, txConn *sql.Tx, stmtKey statementKey, query string, conn *DB, scanner ScannerFunc[T], args ...any) (item T, err error) {
|
|
const op = richerror.Op("postgres.TXInstantQueryRowContext")
|
|
|
|
stmt, sErr := conn.PrepareStatement(ctx, stmtKey, query)
|
|
|
|
if sErr != nil {
|
|
err = richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery)
|
|
return
|
|
}
|
|
txStmt := txConn.StmtContext(ctx, stmt)
|
|
|
|
row := txStmt.QueryRowContext(ctx, args...)
|
|
|
|
item, scErr := scanner(row)
|
|
if scErr != nil {
|
|
if errors.Is(scErr, sql.ErrNoRows) {
|
|
err = richerror.New(op).WithErr(scErr).WithKind(richerror.KindNotFound).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
|
|
return
|
|
|
|
}
|
|
err = richerror.New(op).WithErr(scErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if rErr := row.Err(); rErr != nil {
|
|
err = richerror.New(op).WithErr(rErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
|
|
return
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
func TXInstantExecContext[T AllowUseDBGenericFunc](ctx context.Context, txConn *sql.Tx, stmtKey statementKey, query string, conn *DB, args ...any) (sql.Result, error) {
|
|
const op = richerror.Op("postgres.TXInstantExecContext")
|
|
|
|
stmt, err := conn.PrepareStatement(ctx, stmtKey, query)
|
|
|
|
if err != nil {
|
|
return nil, richerror.New(op)
|
|
}
|
|
txStmt := txConn.StmtContext(ctx, stmt)
|
|
|
|
result, err := txStmt.ExecContext(ctx, args...)
|
|
if err != nil {
|
|
return nil, richerror.New(op)
|
|
|
|
}
|
|
|
|
return result, nil
|
|
}
|