forked from ebhomengo/niki
179 lines
3.8 KiB
Go
179 lines
3.8 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// contextKey is an unexported key type for values this package stores in a
// context.Context. Because other packages cannot name this type, their keys
// can never compare equal to ours.
type contextKey string

// DBTxHolderContextKey is the context key under which
// GetDBTxHolderFromContextOrNew stores the request's *DBTxHolder.
const DBTxHolderContextKey = contextKey("txholder")
|
|
|
|
// conn is the set of operations a DBTxHolder hands out for its open
// transaction. *TxConn implements it: the first five methods come from the
// embedded *sql.Tx, and the Stmt* methods run an already-prepared *sql.Stmt
// inside the transaction.
type conn interface {
	// Transaction lifecycle.
	Commit() error
	Rollback() error

	// Ad-hoc SQL executed inside the transaction.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row

	// Pre-prepared statements executed inside the transaction.
	StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error)
	StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error)
	StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row
}
|
|
|
|
// TxConn wraps a *sql.Tx. The embedded transaction supplies Commit, Rollback,
// ExecContext, QueryContext and QueryRowContext. The Stmt* methods declared on
// *TxConn supply the rest of the conn interface.
type TxConn struct{ *sql.Tx }
|
|
|
|
type DBTxHolder struct {
|
|
conn conn
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (db *DBTxHolder) Conn() (conn, error) {
|
|
if db.conn == nil {
|
|
return nil, fmt.Errorf("Conn() called before BeginTx()")
|
|
}
|
|
return db.conn, nil
|
|
}
|
|
|
|
func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) {
|
|
db, ok := ctx.Value(DBTxHolderContextKey).(*DBTxHolder)
|
|
if !ok {
|
|
db = &DBTxHolder{
|
|
conn: nil,
|
|
}
|
|
ctx = context.WithValue(ctx, DBTxHolderContextKey, db)
|
|
}
|
|
|
|
return db, ctx
|
|
}
|
|
|
|
func (db *DBTxHolder) BeginTx(ctx context.Context, conn *sql.DB) (conn, error) {
|
|
dbTransactionConn, err := db.Continue(ctx, conn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return dbTransactionConn, nil
|
|
}
|
|
|
|
func (db *DBTxHolder) Continue(ctx context.Context, conn *sql.DB) (conn, error) {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
|
|
var err error
|
|
if db.conn == nil {
|
|
|
|
db.conn, err = db.txFactory(ctx, conn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
}
|
|
|
|
return db.conn, err
|
|
}
|
|
|
|
func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) {
|
|
|
|
tx, bErr := rConn.BeginTx(ctx, nil)
|
|
if bErr != nil {
|
|
|
|
return nil, bErr
|
|
}
|
|
|
|
return &TxConn{Tx: tx}, nil
|
|
}
|
|
|
|
func (db *DBTxHolder) Commit() error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
err := db.conn.Commit()
|
|
db.conn = nil
|
|
return err
|
|
}
|
|
|
|
func (db *DBTxHolder) Rollback() error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
err := db.conn.Rollback()
|
|
db.conn = nil
|
|
|
|
return err
|
|
}
|
|
|
|
func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
quoted := pq.QuoteIdentifier(key)
|
|
_, err := db.conn.ExecContext(ctx, "SAVEPOINT "+quoted)
|
|
return err
|
|
}
|
|
|
|
func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
|
|
quoted := pq.QuoteIdentifier(key)
|
|
|
|
_, err := db.conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+quoted)
|
|
|
|
if err != nil {
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error {
|
|
//db.mu.Lock()
|
|
//defer db.mu.Unlock()
|
|
if db.conn == nil {
|
|
return fmt.Errorf("no active transaction")
|
|
}
|
|
quoted := pq.QuoteIdentifier(key)
|
|
_, err := db.conn.ExecContext(ctx, "RELEASE SAVEPOINT "+quoted)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
|
|
txStmt := tx.StmtContext(ctx, stmt)
|
|
return txStmt.QueryContext(ctx, args...)
|
|
}
|
|
|
|
func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
|
|
txStmt := tx.StmtContext(ctx, stmt)
|
|
return txStmt.QueryRowContext(ctx, args...)
|
|
}
|
|
|
|
func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
|
|
txStmt := tx.StmtContext(ctx, stmt)
|
|
result, err := txStmt.ExecContext(ctx, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|