package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
	richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
	_ "github.com/jackc/pgx/v5/stdlib"
)

// Config holds the connection and pool settings for a PostgreSQL database.
// Each field is loaded from the koanf key named in its tag.
// MaxIdleConn, MaxOpenConn and ConnMaxLifetime are passed to the matching
// sql.DB pool setters; ConnMaxLifetime is expressed in seconds.
type Config struct {
	Host            string `koanf:"host"`
	Port            int    `koanf:"port"`
	User            string `koanf:"user"`
	Password        string `koanf:"password"`
	DbName          string `koanf:"dbName"`
	SSLMode         string `koanf:"sslMode"`
	MaxIdleConn     int    `koanf:"maxIdleConns"`
	MaxOpenConn     int    `koanf:"maxOpenConns"`
	ConnMaxLifetime int    `koanf:"connMaxLifetime"`
}

// DB wraps a *sql.DB connection pool together with a cache of prepared
// statements keyed by statementKey. The cache is guarded by mu.
type DB struct {
	config     Config
	db         *sql.DB
	mu         sync.Mutex
	statements map[statementKey]*sql.Stmt
}

// Conn returns the underlying *sql.DB connection pool.
func (db *DB) Conn() *sql.DB {
	return db.db
}

// dsnValueEscaper escapes the characters that are special inside a
// single-quoted keyword/value connection-string value.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quoteDSNValue single-quotes v for use in a keyword/value connection string.
func quoteDSNValue(v string) string {
	return "'" + dsnValueEscaper.Replace(v) + "'"
}

// buildDSN renders config as a keyword/value connection string. Every string
// value is quoted: unquoted, an empty value (e.g. "password=") makes the
// parser consume the following "key=value" pair as the value, and values
// containing spaces, quotes or backslashes would be split or mangled.
func buildDSN(config Config) string {
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(config.Host), config.Port, quoteDSNValue(config.User),
		quoteDSNValue(config.Password), quoteDSNValue(config.DbName), quoteDSNValue(config.SSLMode))
}

// New opens a connection pool for config through the pgx database/sql driver
// and applies the configured pool limits. sql.Open may not contact the server,
// so connectivity problems can surface on first use rather than here.
// New panics if sql.Open returns an error.
func New(config Config) *DB {
	db, err := sql.Open("pgx", buildDSN(config))
	if err != nil {
		panic(fmt.Errorf("can't open postgres db: %w", err))
	}

	db.SetMaxIdleConns(config.MaxIdleConn)
	db.SetMaxOpenConns(config.MaxOpenConn)
	db.SetConnMaxLifetime(time.Duration(config.ConnMaxLifetime) * time.Second)

	return &DB{
		config:     config,
		db:         db,
		statements: make(map[statementKey]*sql.Stmt),
	}
}

// PrepareStatement returns the cached prepared statement for key, preparing
// query and caching it on first use. The cache lock is not held while the
// statement is being prepared on the server, so a slow prepare does not block
// lookups of other keys. If two callers race on the same key, the statement
// stored first wins and the duplicate is closed.
func (db *DB) PrepareStatement(ctx context.Context, key statementKey, query string) (*sql.Stmt, error) {
	db.mu.Lock()
	cached, ok := db.statements[key]
	db.mu.Unlock()
	if ok {
		return cached, nil
	}

	stmt, err := db.db.PrepareContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("prepare statement %q: %w", key, err)
	}

	db.mu.Lock()
	existing, raced := db.statements[key]
	if !raced {
		db.statements[key] = stmt
	}
	db.mu.Unlock()

	if raced {
		// Best effort: the duplicate was never handed out, and the cached
		// statement is what every caller should share.
		_ = stmt.Close()
		return existing, nil
	}
	return stmt, nil
}

// CloseStatements closes every cached prepared statement and empties the
// cache. Close failures are collected and returned joined (nil if none).
func (db *DB) CloseStatements() error {
	db.mu.Lock()
	stmts := db.statements
	db.statements = make(map[statementKey]*sql.Stmt)
	db.mu.Unlock()

	var errs []error
	for key, stmt := range stmts {
		if err := stmt.Close(); err != nil {
			errs = append(errs, fmt.Errorf("close statement %q: %w", key, err))
		}
	}
	return errors.Join(errs...)
}

// Close closes all cached prepared statements and then the connection pool,
// returning any errors from both steps joined together.
func (db *DB) Close() error {
	stmtErr := db.CloseStatements()
	return errors.Join(stmtErr, db.db.Close())
}

// StmtQueryContext runs stmt as a query with args. The caller must close the
// returned rows.
func (db *DB) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
	return stmt.QueryContext(ctx, args...)
}

// StmtQueryRowContext runs stmt as a single-row query with args. Any error is
// deferred until Scan is called on the returned row.
func (db *DB) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
	return stmt.QueryRowContext(ctx, args...)
}

// StmtExecContext executes stmt with args. On failure the returned result is
// nil.
func (db *DB) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
	result, err := stmt.ExecContext(ctx, args...)
	if err != nil {
		return nil, err
	}
	return result, nil
}

///////////////////////// generic query

// ScannerFunc converts the row currently exposed by scanner into a T.
// InstantQueryContext calls it with *sql.Rows, InstantQueryRowContext with
// *sql.Row.
type ScannerFunc[T any] func(scanner Scanner) (T, error)

// InstantQueryContext prepares (or reuses the cached) statement for stmtKey,
// runs it with args and converts every returned row with scanner. It returns
// a nil slice when the query yields no rows. All failures are returned as
// rich errors wrapping the underlying cause.
func InstantQueryContext[T any](ctx context.Context, stmtKey statementKey, query string, conn *DB, scanner ScannerFunc[T], args ...any) ([]T, error) {
	const op = richerror.Op("postgres.InstantQueryContext")

	readyStmt, err := conn.PrepareStatement(ctx, stmtKey, query)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	rows, qErr := readyStmt.QueryContext(ctx, args...)
	if qErr != nil {
		return nil, richerror.New(op).WithErr(qErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}
	defer rows.Close()

	var itemsList []T
	for rows.Next() {
		item, sErr := scanner(rows)
		if sErr != nil {
			return nil, richerror.New(op).WithErr(sErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
		}
		itemsList = append(itemsList, item)
	}
	// rows.Next returning false may hide an iteration error; surface it.
	if rErr := rows.Err(); rErr != nil {
		return nil, richerror.New(op).WithErr(rErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	return itemsList, nil
}

// InstantQueryRowContext prepares (or reuses the cached) statement for
// stmtKey, runs it with args and converts the single resulting row with
// scanner. A scan error matching sql.ErrNoRows is reported with
// richerror.KindNotFound; every other failure uses richerror.KindUnexpected.
// On error the returned item is always the zero value of T.
func InstantQueryRowContext[T any](ctx context.Context, stmtKey statementKey, query string, conn *DB, scanner ScannerFunc[T], args ...any) (item T, err error) {
	const op = richerror.Op("postgres.InstantQueryRowContext")

	readyStmt, sErr := conn.PrepareStatement(ctx, stmtKey, query)
	if sErr != nil {
		return item, richerror.New(op).WithErr(sErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	row := readyStmt.QueryRowContext(ctx, args...)
	scanned, scErr := scanner(row)
	if scErr != nil {
		if errors.Is(scErr, sql.ErrNoRows) {
			return item, richerror.New(op).WithErr(scErr).WithKind(richerror.KindNotFound).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
		}
		return item, richerror.New(op).WithErr(scErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
	}
	if rErr := row.Err(); rErr != nil {
		return item, richerror.New(op).WithErr(rErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	return scanned, nil
}

// InstantExecContext prepares (or reuses the cached) statement for stmtKey and
// executes it with args. Failures are returned as rich errors that wrap the
// underlying cause and carry a generic query-failure message.
func InstantExecContext(ctx context.Context, stmtKey statementKey, query string, conn *DB, args ...any) (sql.Result, error) {
	const op = richerror.Op("postgres.InstantExecContext")

	readyStmt, err := conn.PrepareStatement(ctx, stmtKey, query)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	result, err := readyStmt.ExecContext(ctx, args...)
	if err != nil {
		return nil, richerror.New(op).WithErr(err).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
	}

	return result, nil
}