package querier

import (
	"context"
	"database/sql"
	"errors"
	"sync"
	"sync/atomic"
)

// contextKey is the unexported key type for values this package stores in a
// context.Context. Using a distinct type (rather than a bare string) keeps
// these keys from colliding with keys defined by other packages.
type contextKey string

// QuerierContextKey is the context key under which a *Querier is looked up by
// GetQuerierFromContextOrNew.
const QuerierContextKey = contextKey("querier")

type conn interface {
	Commit() error
	Rollback() error
	Begin() (*sql.Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	QueryRow(query string, args ...any) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// Querier is a lazily bound query handle. Callers optionally mark it as
// transactional with Begin, then bind it to a database handle with Continue;
// afterwards Conn returns either the plain database handle or an open
// transaction.
//
// Querier contains a sync.Once, which must not be copied after first use, so
// always pass *Querier.
type Querier struct {
	// txRequested records whether Begin was called; Continue reads it once,
	// during initialization.
	txRequested atomic.Bool
	// initOnce guards the one-time binding of conn performed by Continue.
	initOnce sync.Once
	// conn is nil until Continue has successfully initialized the Querier.
	conn conn
}

// GetQuerierFromContextOrNew returns the *Querier stored in ctx under
// QuerierContextKey. If ctx holds no *Querier under that key, it returns a
// fresh, uninitialized Querier. The new Querier is not attached to ctx.
func GetQuerierFromContextOrNew(ctx context.Context) *Querier {
	if existing, found := ctx.Value(QuerierContextKey).(*Querier); found {
		return existing
	}

	// Every field's zero value is the intended initial state: no transaction
	// requested, not yet initialized, no connection bound.
	return &Querier{}
}

// Begin marks q as transactional and returns q for chaining. The flag only
// matters if it is set before the first Continue call, because Continue
// consults it exactly once, during initialization.
func (q *Querier) Begin() *Querier {
	const wantTransaction = true
	q.txRequested.Store(wantTransaction)
	return q
}

// Continue binds q to db the first time it is called and returns q.
//
// If Begin was called beforehand, Continue starts a transaction on db using
// ctx and wraps it in an *SQLTx. Per database/sql, ctx is then used until the
// transaction is committed or rolled back. Otherwise q uses db directly.
//
// Only the first call performs this initialization. Concurrent or later calls
// wait for it to finish (sync.Once semantics) and then return q unchanged.
// If starting the transaction failed, the first call returns that error, and
// every later call also returns a non-nil error. Later calls therefore never
// report success for a Querier that has no connection.
func (q *Querier) Continue(ctx context.Context, db *SQLDB) (*Querier, error) {
	var beginErr error
	q.initOnce.Do(func() {
		if !q.txRequested.Load() {
			q.conn = db
			return
		}
		tx, err := db.BeginTx(ctx, nil)
		if err != nil {
			beginErr = err
			return
		}
		q.conn = &SQLTx{Tx: tx}
	})

	if beginErr != nil {
		return q, beginErr
	}
	// sync.Once never re-runs a failed initialization, so a nil conn here means
	// an earlier call's BeginTx failed. Report it rather than handing back a
	// Querier whose Commit/Rollback would panic.
	if q.conn == nil {
		return q, errors.New("querier: not initialized: an earlier attempt to begin the transaction failed")
	}

	return q, nil
}

// Commit forwards to the bound connection's Commit. For a transactional
// Querier this commits the *sql.Tx. For a plain one it calls SQLDB.Commit,
// which returns nil.
//
// If q has not been successfully initialized by Continue, Commit returns an
// error instead of panicking on the nil connection.
func (q *Querier) Commit() error {
	if q.conn == nil {
		return errors.New("querier: commit called before a successful Continue")
	}

	return q.conn.Commit()
}

// Rollback forwards to the bound connection's Rollback. For a transactional
// Querier this rolls back the *sql.Tx. Per database/sql, that returns
// sql.ErrTxDone if the transaction was already committed. For a plain Querier
// it calls SQLDB.Rollback, which returns nil.
//
// If q has not been successfully initialized by Continue, Rollback returns an
// error instead of panicking. That matters for the `defer q.Rollback()` pattern.
func (q *Querier) Rollback() error {
	if q.conn == nil {
		return errors.New("querier: rollback called before a successful Continue")
	}

	return q.conn.Rollback()
}

// Conn returns the handle q was bound to by Continue: an *SQLTx when a
// transaction was requested, otherwise the *SQLDB passed to Continue. It is
// nil until Continue has successfully initialized q.
func (q *Querier) Conn() conn {
	current := q.conn
	return current
}

// ErrNestedTransaction is returned when a new transaction is requested from a
// handle that is already a transaction. *sql.Tx has no Begin method, and
// database/sql does not support nested transactions.
var ErrNestedTransaction = errors.New("querier: nested transactions are not supported")

// SQLTx adapts an open *sql.Tx to the querier's connection interface. Commit,
// Rollback, QueryRow(Context) and Exec(Context) are promoted from the embedded
// *sql.Tx.
type SQLTx struct {
	*sql.Tx
}

// Begin always fails with ErrNestedTransaction. It previously returned a
// zero-value *sql.Tx with a nil error. Every method on that value panics,
// because its internal state is unset, so callers had no way to detect the
// problem.
func (tx *SQLTx) Begin() (*sql.Tx, error) {
	return nil, ErrNestedTransaction
}

// BeginTx always fails with ErrNestedTransaction, for the same reason as Begin.
// Its context and options are ignored.
func (tx *SQLTx) BeginTx(_ context.Context, _ *sql.TxOptions) (*sql.Tx, error) {
	return nil, ErrNestedTransaction
}

// SQLDB adapts a plain *sql.DB to the querier's connection interface. Query,
// Exec and Begin methods are promoted from the embedded *sql.DB. Commit and
// Rollback exist only to satisfy the interface and do nothing.
type SQLDB struct {
	*sql.DB
}

// Commit does nothing and always returns nil: a bare database handle has no
// transaction to commit.
func (db *SQLDB) Commit() error { return nil }

// Rollback does nothing and always returns nil: a bare database handle has no
// transaction to roll back.
func (db *SQLDB) Rollback() error { return nil }