forked from ebhomengo/niki
1
0
Fork 0
niki/pkg/query_transaction/sql/querier.go

94 lines
1.7 KiB
Go

package querier
import (
	"context"
	"database/sql"
	"errors"
	"sync"
	"sync/atomic"
)
// QuerierContextKey is the context.Value key that GetQuerierFromContextOrNew
// uses to look up an existing *querier.
//
// NOTE(review): a plain string key can collide with keys set by other
// packages (staticcheck SA1029). An unexported key type would be safer, but
// switching would silently break any caller that stores or reads the value
// under the literal "querier" instead of this constant.
const QuerierContextKey = "querier"
// conn is the minimal database surface a querier needs. Both a plain
// connection (SqlDB) and an open transaction (SqlTx) satisfy it, so callers
// can issue statements without knowing which one they hold.
type conn interface {
	// Transaction lifecycle.
	Begin() (*sql.Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	Commit() error
	Rollback() error

	// Single-row queries.
	QueryRow(query string, args ...any) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row

	// Statements that return no rows.
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// querier defers the choice between a plain connection and a transaction
// until Continue is first called.
//
// A querier must not be copied after first use: it contains a sync.Once and
// an atomic.Bool.
type querier struct {
	// txRequested is set by Begin; Continue reads it to decide whether to
	// open a transaction.
	txRequested atomic.Bool
	// initOnce ensures the binding inside Continue runs at most once.
	initOnce sync.Once
	// conn is the bound connection (a *SqlTx when a transaction was opened);
	// nil until Continue binds one.
	conn conn
}
// GetQuerierFromContextOrNew returns the *querier stored in ctx under
// QuerierContextKey. If ctx holds no *querier there, it returns a fresh one
// with no transaction requested and no connection bound. The fresh querier is
// NOT stored back into ctx; callers that want to share it must do so themselves.
func GetQuerierFromContextOrNew(ctx context.Context) *querier {
	if existing, found := ctx.Value(QuerierContextKey).(*querier); found {
		return existing
	}
	// The zero value is the intended initial state (no tx requested, Once unused, nil conn).
	return new(querier)
}
// Begin asks q to open a transaction when Continue binds it, and returns q
// so calls can be chained (e.g. q.Begin().Continue(ctx, db)).
//
// Continue reads this flag only on its first call (it is guarded by a
// sync.Once), so calling Begin after q has been bound does not change the binding.
func (q *querier) Begin() *querier {
	q.txRequested.Store(true)

	return q
}
// ErrNoConn is returned when a querier is used without a bound connection:
// Continue was never called, was given a nil conn, or its earlier attempt to
// open a transaction failed.
var ErrNoConn = errors.New("querier: no connection bound")

// Continue binds q to a connection on its first call and returns q.
//
// If Begin was called beforehand, a transaction is opened with
// conn.BeginTx(ctx, nil) and q runs on it (wrapped as *SqlTx). Otherwise q
// runs directly on conn, whatever conn implementation that is.
//
// Binding happens at most once (guarded by initOnce). Later calls ignore ctx
// and conn and return ErrNoConn if the first binding did not succeed. q is
// returned even when the error is non-nil.
func (q *querier) Continue(ctx context.Context, conn conn) (*querier, error) {
	var initErr error
	q.initOnce.Do(func() {
		if conn == nil {
			initErr = ErrNoConn
			return
		}
		if !q.txRequested.Load() {
			// q.conn already has interface type conn, so any implementation
			// (e.g. *SqlDB or *SqlTx) is stored as-is rather than asserted.
			q.conn = conn
			return
		}
		tx, err := conn.BeginTx(ctx, nil)
		if err != nil {
			initErr = err
			return
		}
		q.conn = &SqlTx{tx}
	})
	if initErr != nil {
		return q, initErr
	}
	// initOnce is spent even when binding failed, so a later call must not
	// report success while q.conn is still nil.
	if q.conn == nil {
		return q, ErrNoConn
	}
	return q, nil
}

// Commit commits the bound connection. For a querier bound to a transaction,
// this commits it. Otherwise it delegates to the connection's Commit (a no-op
// for *SqlDB). It returns ErrNoConn instead of panicking when nothing is bound.
func (q *querier) Commit() error {
	if q.conn == nil {
		return ErrNoConn
	}
	return q.conn.Commit()
}

// Rollback rolls back the bound connection. It returns ErrNoConn instead of
// panicking when nothing is bound, which keeps a deferred Rollback safe after a
// failed Continue.
func (q *querier) Rollback() error {
	if q.conn == nil {
		return ErrNoConn
	}
	return q.conn.Rollback()
}

// Conn returns the connection bound by Continue, or nil if none is bound.
func (q *querier) Conn() conn {
	return q.conn
}
// ErrNestedTx is returned by SqlTx.Begin and SqlTx.BeginTx: the connection is
// already a transaction, and nested transactions are not supported.
var ErrNestedTx = errors.New("querier: nested transactions are not supported")

// SqlTx adapts *sql.Tx to the conn interface. Commit, Rollback, QueryRow,
// QueryRowContext, Exec and ExecContext are promoted from the embedded *sql.Tx.
type SqlTx struct {
	*sql.Tx
}

// Begin always fails with ErrNestedTx. A zero-value sql.Tx is not attached to
// any database and cannot be used, so returning one would only postpone the
// failure to the first statement run on it.
func (tx *SqlTx) Begin() (*sql.Tx, error) {
	return nil, ErrNestedTx
}

// BeginTx always fails with ErrNestedTx; ctx and opts are ignored.
func (tx *SqlTx) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	return nil, ErrNestedTx
}
// SqlDB adapts *sql.DB to the conn interface. Begin, BeginTx, QueryRow,
// QueryRowContext, Exec and ExecContext are promoted from the embedded *sql.DB.
// Commit and Rollback are no-ops, presumably so that a querier bound to a plain
// connection can be committed or rolled back the same way as a transactional one.
type SqlDB struct {
	*sql.DB
}

// Commit does nothing and always reports success.
func (*SqlDB) Commit() error { return nil }

// Rollback does nothing and always reports success.
func (*SqlDB) Rollback() error { return nil }