// Forked from ebhomengo/niki.

package querier
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// QuerierContextKey is the context key GetQuerierFromContextOrNew uses to look
// up a *querier in a context.Context.
//
// NOTE(review): staticcheck SA1029 discourages built-in string keys for
// context values (collisions across packages); an unexported key type would be
// safer, but changing it may break callers that treat this key as a string.
const QuerierContextKey = "querier"
|
|
|
|
// conn is the database handle a querier works through. *SqlDB (wrapping a
// plain *sql.DB) and *SqlTx (wrapping an open *sql.Tx) both satisfy it, which
// lets the same querier code run with or without a transaction.
type conn interface {
	// Transaction control.
	Begin() (*sql.Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	Commit() error
	Rollback() error

	// Statement execution.
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryRow(query string, args ...any) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
|
|
|
|
type querier struct {
|
|
txRequested atomic.Bool
|
|
initOnce sync.Once
|
|
conn conn
|
|
}
|
|
|
|
func GetQuerierFromContextOrNew(ctx context.Context) *querier {
|
|
q, ok := ctx.Value(QuerierContextKey).(*querier)
|
|
if !ok {
|
|
q = &querier{
|
|
txRequested: atomic.Bool{},
|
|
initOnce: sync.Once{},
|
|
conn: nil,
|
|
}
|
|
}
|
|
return q
|
|
}
|
|
func (q *querier) Begin() *querier {
|
|
q.txRequested.Store(true)
|
|
return q
|
|
}
|
|
func (q *querier) Continue(ctx context.Context, conn conn) (*querier, error) {
|
|
var iErr error
|
|
q.initOnce.Do(func() {
|
|
if q.txRequested.Load() {
|
|
tx, bErr := conn.BeginTx(ctx, nil)
|
|
if bErr != nil {
|
|
iErr = bErr
|
|
return
|
|
}
|
|
q.conn = &SqlTx{tx}
|
|
} else {
|
|
q.conn = conn.(*SqlDB)
|
|
}
|
|
})
|
|
return q, iErr
|
|
}
|
|
func (q *querier) Commit() error {
|
|
return q.conn.Commit()
|
|
}
|
|
func (q *querier) Rollback() error {
|
|
return q.conn.Rollback()
|
|
}
|
|
func (q *querier) Conn() conn {
|
|
return q.conn
|
|
}
|
|
|
|
// SqlTx adapts an open *sql.Tx to the conn interface. Commit, Rollback, Exec
// and QueryRow (and their Context variants) are promoted from the embedded
// *sql.Tx.
type SqlTx struct {
	*sql.Tx
}

// nestedTxError reports an attempt to start a transaction inside an SqlTx.
type nestedTxError struct{}

// Error implements the error interface.
func (nestedTxError) Error() string { return "querier: nested transactions are not supported" }

// errNestedTx is returned by SqlTx.Begin and SqlTx.BeginTx.
var errNestedTx error = nestedTxError{}

// Begin always fails with errNestedTx. database/sql has no way to open a
// transaction inside an existing *sql.Tx, and a zero-value *sql.Tx is not a
// usable transaction (its methods dereference unset internal fields), so
// returning one would only defer the failure to a panic.
func (tx *SqlTx) Begin() (*sql.Tx, error) {
	return nil, errNestedTx
}

// BeginTx behaves like Begin; ctx and opts are ignored.
func (tx *SqlTx) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	return nil, errNestedTx
}
|
|
|
|
// SqlDB adapts a *sql.DB to the conn interface so non-transactional work goes
// through the same API as SqlTx. Begin, BeginTx, Exec and QueryRow (and their
// Context variants) are promoted from the embedded *sql.DB.
type SqlDB struct {
	*sql.DB
}

// Commit is a no-op that always returns nil: statements run directly on a
// *sql.DB are not part of an explicit transaction, so there is nothing to
// commit.
func (*SqlDB) Commit() error {
	var noErr error
	return noErr
}

// Rollback is a no-op that always returns nil, for the same reason as Commit.
func (*SqlDB) Rollback() error {
	var noErr error
	return noErr
}
|