package querier import ( "context" "database/sql" "sync" "sync/atomic" ) const ( QuerierContextKey = "querier" ) type conn interface { Commit() error Rollback() error Begin() (*sql.Tx, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) QueryRow(query string, args ...any) *sql.Row QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row Exec(query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) } type querier struct { txRequested atomic.Bool initOnce sync.Once conn conn } func GetQuerierFromContextOrNew(ctx context.Context) *querier { q, ok := ctx.Value(QuerierContextKey).(*querier) if !ok { q = &querier{ txRequested: atomic.Bool{}, initOnce: sync.Once{}, conn: nil, } } return q } func (q *querier) Begin() *querier { q.txRequested.Store(true) return q } func (q *querier) Continue(ctx context.Context, conn *SqlDB) (*querier, error) { var iErr error q.initOnce.Do(func() { if q.txRequested.Load() { tx, bErr := conn.BeginTx(ctx, nil) if bErr != nil { iErr = bErr return } q.conn = &SqlTx{tx} } else { q.conn = conn } }) return q, iErr } func (q *querier) Commit() error { return q.conn.Commit() } func (q *querier) Rollback() error { return q.conn.Rollback() } func (q *querier) Conn() conn { return q.conn } type SqlTx struct { *sql.Tx } func (tx *SqlTx) Begin() (*sql.Tx, error) { return &sql.Tx{}, nil } func (tx *SqlTx) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { return &sql.Tx{}, nil } type SqlDB struct { *sql.DB } func (db *SqlDB) Commit() error { return nil } func (db *SqlDB) Rollback() error { return nil }