diff --git a/domain/wallet/entity/transaction.go b/domain/wallet/entity/transaction.go index 09c98888..6a9a7de7 100644 --- a/domain/wallet/entity/transaction.go +++ b/domain/wallet/entity/transaction.go @@ -9,6 +9,7 @@ type Transaction struct { Currency Currency ActionType TransactionType Timestamp time.Time + CreatedAt time.Time } type TransactionType string diff --git a/domain/wallet/entity/wallet.go b/domain/wallet/entity/wallet.go index 175a638d..648ba19a 100644 --- a/domain/wallet/entity/wallet.go +++ b/domain/wallet/entity/wallet.go @@ -7,8 +7,8 @@ type Wallet struct { UserID uint64 // user unique ID Balance float64 Currency Currency - UpdatedAt time.Time Status WalletStatus // "active", "frozen", "closed" + UpdatedAt time.Time } type WalletStatus string diff --git a/domain/wallet/param/create_transaction.go b/domain/wallet/param/create_transaction.go index 1341ed46..9d73e6b1 100644 --- a/domain/wallet/param/create_transaction.go +++ b/domain/wallet/param/create_transaction.go @@ -1,6 +1,8 @@ package param import ( + "time" + "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" ) @@ -9,7 +11,9 @@ type CreateTransactionRequest struct { Amount float64 `json:"amount"` Currency entity.Currency `json:"currency"` ActionType entity.TransactionType `json:"action_type"` + Timestamp time.Time `json:"timestamp"` } type InsertTransactionResponse struct { + Balance float64 `json:"balance"` } diff --git a/domain/wallet/param/transaction_history.go b/domain/wallet/param/transaction_history.go index fe720528..0a4a3332 100644 --- a/domain/wallet/param/transaction_history.go +++ b/domain/wallet/param/transaction_history.go @@ -4,14 +4,17 @@ import ( "time" "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" ) type TransactionRequest struct { - UserID uint64 `json:"user_id"` + UserID uint64 `json:"user_id"` + Pagination postgres.RequestPagination `json:"pagination"` } type TransactionResponse struct { - Transaction []TransactionInfo `json:"transaction"` + Transaction []TransactionInfo `json:"transactions"` + Pagination postgres.ResponsePagination `json:"pagination"` } type TransactionInfo struct { diff --git a/domain/wallet/repository/postgres/create_transaction.go b/domain/wallet/repository/postgres/create_transaction.go new file mode 100644 index 00000000..5841d5af --- /dev/null +++ b/domain/wallet/repository/postgres/create_transaction.go @@ -0,0 +1,99 @@ +package postgres + +import ( + "context" + + "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" + errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg" + richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error" +) + +func (db *DB) InsertTransaction(ctx context.Context, transaction entity.Transaction) (balance float64, err error) { + const op = richerror.Op("wallet.repo.InsertTransaction") + + query := `INSERT INTO Transactions (user_id, amount ,currency ,action_type, timestamp) values ($1, $2, $3, $4, $5)` + + stmt, stErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyWalletInsertTransaction, query) + if stErr != nil { + err = richerror.New(op).WithErr(stErr).WithKind(richerror.KindUnexpected) + return + } + + txHolder, newCtx := postgres.GetDBTxHolderFromContextOrNew(ctx) + tx, txErr := txHolder.BeginTx(newCtx, db.conn.Conn()) + if txErr != nil { + err = richerror.New(op).WithErr(txErr).WithKind(richerror.KindUnexpected) + return + } + + defer func() { + if tx != nil { + if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + // log rbErr + } + } else if cErr := tx.Commit(); cErr != nil { + // log cErr + } + + } + + }() + + params := []any{transaction.UserID, transaction.Currency, transaction.Amount, transaction.ActionType, transaction.Timestamp} + + _, execErr := tx.StmtExecContext(newCtx, stmt, params...) + if execErr != nil { + err = richerror.New(op).WithErr(execErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantInsertTransaction) + + return + } + + newBalance, blanceErr := db.UpsertBalance(newCtx, transaction.UserID, transaction.Amount) + if blanceErr != nil { + err = richerror.New(op).WithErr(blanceErr) + return + + } + + balance = newBalance + return +} + +func (db *DB) UpsertBalance(newCtx context.Context, userID uint64, amount float64) (balance float64, err error) { + const op = richerror.Op("wallet.repo.UpdateBalance") + + txHolder, _ := postgres.GetDBTxHolderFromContextOrNew(newCtx) + tx, txErr := txHolder.Conn() + + if txErr != nil { + err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantUpsertBalance).WithKind(richerror.KindUnexpected) + return + } + + if tx == nil { + err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantUpsertBalance).WithKind(richerror.KindUnexpected) + return + } + + upsertQuery := `INSERT INTO wallets (user_id, balance) VALUES ($1, $2) + ON CONFLICT (user_id) DO UPDATE SET balance = wallets.balance + $2 + RETURNING balance` + + upsertStmt, stErr := db.conn.PrepareStatement(newCtx, postgres.StatementKeyWalletUpsertBalance, upsertQuery) + if stErr != nil { + err = richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery).WithKind(richerror.KindUnexpected) + return + } + + row := tx.StmtQueryRowContext(newCtx, upsertStmt, userID, amount) + + sErr := row.Scan(&balance) + if sErr != nil { + err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantScanQueryResult).WithKind(richerror.KindUnexpected) + return + } + + return balance, nil +} diff --git a/domain/wallet/repository/postgres/db.go b/domain/wallet/repository/postgres/db.go index 787efa54..61367094 100644 --- a/domain/wallet/repository/postgres/db.go +++ b/domain/wallet/repository/postgres/db.go @@ -2,6 +2,9 @@ package postgres import "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" +type Config struct { +} + type DB struct { conn *postgres.DB } diff --git a/domain/wallet/repository/postgres/get_transaction_list.go b/domain/wallet/repository/postgres/get_transaction_list.go new file mode 100644 index 00000000..d83a54e7 --- /dev/null +++ b/domain/wallet/repository/postgres/get_transaction_list.go @@ -0,0 +1,72 @@ +package postgres + +import ( + "context" + + "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" + errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg" + richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error" +) + +func (db *DB) GetTransactionListByUserID(ctx context.Context, userID uint64, dbPagination postgres.DBPagination) ([]entity.Transaction, int64, error) { + const op = richerror.Op("Wallet.repo.GetTransactionListByUserID") + + query := `SELECT * FROM Transactions WHERE user_id = $1 AND transaction_timestamp < $2 ORDER BY transaction_id DESC LIMIT $3` + + stmt, StErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyAWalletGetTransactionHistory, query) + + if StErr != nil { + + return nil, 0, richerror.New(op).WithErr(StErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery) + + } + + totalRetrieveRecord := (dbPagination.MaxNextPages-1)*dbPagination.PageSize + 1 + + lastTimeStamp := dbPagination.LastTimeStamp + + queryRows, err := stmt.QueryContext(ctx, userID, lastTimeStamp, totalRetrieveRecord) + + if err != nil { + return nil, 0, richerror.New(op).WithErr(StErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery) + } + + defer queryRows.Close() + + var transactions []entity.Transaction + + var lenCounter int64 + + for queryRows.Next() { + lenCounter++ + + if lenCounter < dbPagination.PageSize+1 { + transaction, err := scanTransaction(queryRows) + + if err != nil { + return nil, 0, richerror.New(op).WithErr(err).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult) + } + + transactions = append(transactions, transaction) + + } + + } + + if qErr := queryRows.Err(); qErr != nil { + + return nil, 0, richerror.New(op).WithErr(qErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult) + } + + return transactions, lenCounter, nil + +} + +func scanTransaction(scanner postgres.Scanner) (transaction entity.Transaction, err error) { + err = scanner.Scan(&transaction.UserID, &transaction.Currency, &transaction.Amount, &transaction.ActionType, &transaction.Timestamp) + if err != nil { + return + } + return +} diff --git a/domain/wallet/repository/postgres/wallet.go b/domain/wallet/repository/postgres/wallet.go new file mode 100644 index 00000000..747ce266 --- /dev/null +++ b/domain/wallet/repository/postgres/wallet.go @@ -0,0 +1,40 @@ +package postgres + +import ( + "context" + + "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" + richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error" +) + +func (db *DB) GetWalletByUserID(ctx context.Context, UserID uint64) (entity.Wallet, error) { + const op = richerror.Op("Wallet.repo.GetWalletByUserID") + query := `SELECT * FROM wallets WHERE user_id = $1` + stmt, stErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyWalletGetUserWallet, query) + if stErr != nil { + return entity.Wallet{}, richerror.New(op).WithErr(stErr) + } + walletRow := db.conn.StmtQueryRowContext(ctx, stmt, UserID) + + wallet, sErr := scanWallet(walletRow) + + if sErr != nil { + return entity.Wallet{}, richerror.New(op).WithErr(sErr) + } + + return wallet, nil + +} + +func scanWallet(scanner postgres.Scanner) (entity.Wallet, error) { + + var wallet entity.Wallet + + err := scanner.Scan(&wallet.ID, &wallet.Balance, &wallet.Currency, &wallet.Status, &wallet.UpdatedAt) + if err != nil { + return entity.Wallet{}, err + } + + return wallet, nil +} diff --git a/domain/wallet/service/create_transaction.go b/domain/wallet/service/create_transaction.go index cfd98ac6..9a2c5e03 100644 --- a/domain/wallet/service/create_transaction.go +++ b/domain/wallet/service/create_transaction.go @@ -19,11 +19,12 @@ func (s Service) CreateTransaction(ctx context.Context, request param.CreateTran Amount: request.Amount, Currency: request.Currency, ActionType: request.ActionType, - Timestamp: time.Now(), + Timestamp: request.Timestamp, + CreatedAt: time.Now(), } err := s.repo.InsertTransaction(ctx, transaction) if err != nil { - return param.InsertTransactionResponse{}, err + return param.InsertTransactionResponse{}, richerror.New(op).WithErr(err) } return param.InsertTransactionResponse{}, nil diff --git a/domain/wallet/service/service.go b/domain/wallet/service/service.go index 3f949184..38e97cb2 100644 --- a/domain/wallet/service/service.go +++ b/domain/wallet/service/service.go @@ -4,15 +4,18 @@ import ( "context" "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" ) type Repository interface { - GetTransactionListByUserID(ctx context.Context, UserID uint64) ([]entity.Transaction, error) + GetTransactionListByUserID(ctx context.Context, UserID uint64, DBPagination postgres.DBPagination) ([]entity.Transaction, int64, error) GetWalletByUserID(ctx context.Context, UserID uint64) (entity.Wallet, error) InsertTransaction(ctx context.Context, transaction entity.Transaction) error } type Config struct { + PageSize int64 `koanf:"page_size"` + MaxNextPages int64 `koanf:"max_next_pages"` } type Service struct { diff --git a/domain/wallet/service/transaction_history.go b/domain/wallet/service/transaction_history.go index 941b8da5..9dc81f90 100644 --- a/domain/wallet/service/transaction_history.go +++ b/domain/wallet/service/transaction_history.go @@ -2,22 +2,46 @@ package service import ( "context" + "time" "git.gocasts.ir/ebhomengo/niki/domain/wallet/entity" "git.gocasts.ir/ebhomengo/niki/domain/wallet/param" + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error" ) func (s Service) GetUserTransactionHistory(ctx context.Context, request param.TransactionRequest) (param.TransactionResponse, error) { const op = richerror.Op("wallet.service.GetUserTransactionHistory") - - transactionList, err := s.repo.GetTransactionListByUserID(ctx, request.UserID) - - if err != nil { - return param.TransactionResponse{}, err + lastTimeStamp := request.Pagination.LastTimeStamp + if lastTimeStamp.IsZero() { + lastTimeStamp = time.Now() } - return param.TransactionResponse{Transaction: transactionEntityToTransactionInfo(transactionList)}, nil + dbPagination := postgres.DBPagination{ + LastTimeStamp: lastTimeStamp, + PageNumber: request.Pagination.PageNumber, + MaxNextPages: s.cfg.MaxNextPages, + PageSize: s.cfg.PageSize, + } + + transactionList, listLen, err := s.repo.GetTransactionListByUserID(ctx, request.UserID, dbPagination) + + showableNextPagesNum := postgres.ComputeNextPages(listLen, s.cfg.PageSize, s.cfg.MaxNextPages) + + if err != nil { + return param.TransactionResponse{}, richerror.New(op).WithErr(err) + } + + paginationInfo := postgres.ResponsePagination{ + PageNumber: request.Pagination.PageNumber, + PageSize: s.cfg.PageSize, + ShowableNextPagesNum: showableNextPagesNum, + } + + return param.TransactionResponse{ + Transaction: transactionEntityToTransactionInfo(transactionList), + Pagination: paginationInfo, + }, nil } diff --git a/domain/wallet/service/user_wallet.go b/domain/wallet/service/user_wallet.go index 12ac49b7..714f77f0 100644 --- a/domain/wallet/service/user_wallet.go +++ b/domain/wallet/service/user_wallet.go @@ -13,7 +13,7 @@ func (s Service) GetUserWallet(ctx context.Context, request param.WalletRequest) wallet, err := s.repo.GetWalletByUserID(ctx, request.UserID) if err != nil { - return param.WalletResponse{}, err + return param.WalletResponse{}, richerror.New(op).WithErr(err) } return param.WalletResponse{ diff --git a/pkg/database/postgres/db.go b/pkg/database/postgres/db.go index b843be3d..9d99f803 100644 --- a/pkg/database/postgres/db.go +++ b/pkg/database/postgres/db.go @@ -7,31 +7,30 @@ import ( "sync" "time" - querier "git.gocasts.ir/ebhomengo/niki/pkg/query_transaction/sql" _ "github.com/jackc/pgx/v5/stdlib" ) - + type Config struct { - Host string `koanf:"host"` - Port int `koanf:"port"` - User string `koanf:"user"` - Password string `koanf:"password"` - DbName string `koanf:"dbName"` - SSLMode string `koanf:"sslMode"` - MaxIdleConn int `koanf:"maxIdleConns"` - MaxOpenConn int `koanf:"maxOpenConns"` - ConnMaxLifetime int `koanf:"connMaxLifetime"` - PathOfMigrations string `koanf:"pathOfMigrations"` + Host string `koanf:"host"` + Port int `koanf:"port"` + User string `koanf:"user"` + Password string `koanf:"password"` + DbName string `koanf:"dbName"` + SSLMode string `koanf:"sslMode"` + MaxIdleConn int `koanf:"maxIdleConns"` + MaxOpenConn int `koanf:"maxOpenConns"` + ConnMaxLifetime int `koanf:"connMaxLifetime"` } type DB struct { config Config - db *querier.SQLDB + db *sql.DB mu sync.Mutex statements map[statementKey]*sql.Stmt } -func (db *DB) Conn() *querier.SQLDB { +func (db *DB) Conn() *sql.DB { + return db.db } @@ -56,7 +55,7 @@ func New(config Config) *DB { return &DB{ config: config, - db: &querier.SQLDB{DB: db}, + db: db, statements: make(map[statementKey]*sql.Stmt), } } @@ -93,5 +92,22 @@ func (db *DB) CloseStatements() error { } func (db *DB) Close() error { - return db.db.DB.Close() + return db.db.Close() +} + +func (db *DB) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) { + return stmt.QueryContext(ctx, args...) +} + +func (db *DB) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row { + return stmt.QueryRowContext(ctx, args...) +} + +func (db *DB) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) { + result, err := stmt.ExecContext(ctx, args...) + if err != nil { + return nil, err + } + + return result, nil } diff --git a/domain/wallet/repository/postgres/dbconfig.yml b/pkg/database/postgres/dbconfig.yml similarity index 87% rename from domain/wallet/repository/postgres/dbconfig.yml rename to pkg/database/postgres/dbconfig.yml index af051ab0..c216fd7f 100644 --- a/domain/wallet/repository/postgres/dbconfig.yml +++ b/pkg/database/postgres/dbconfig.yml @@ -2,4 +2,4 @@ production: dialect: postgres datasource: "host=127.0.0.1 port=5432 user=wallet password=wallet2123 dbname=wallet_db sslmode=disable" dir: domain/wallet/repository/postgres/migrations - table: wallet_migrationsns \ No newline at end of file + table: gorp_migrationsns \ No newline at end of file diff --git a/pkg/database/postgres/migrator/migrator.go b/pkg/database/postgres/migrator/migrator.go index 0cbadd99..e8893c74 100644 --- a/pkg/database/postgres/migrator/migrator.go +++ b/pkg/database/postgres/migrator/migrator.go @@ -1 +1,113 @@ package migrator + +import ( + "database/sql" + "fmt" + "os" + "text/tabwriter" + + "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres" + _ "github.com/jackc/pgx/v5/stdlib" + migrate "github.com/rubenv/sql-migrate" +) + +type Config struct { + migrationDBName string `koanf:"migration_db_name"` + pathOfMigrations string `koanf:"pathOfMigrations"` + dbConfig postgres.Config +} + +type Migrator struct { + cfg Config + dialect string + migrations *migrate.FileMigrationSource +} + +func New(cfg Config) *Migrator { + + return &Migrator{ + cfg: cfg, + dialect: "psx", + migrations: &migrate.FileMigrationSource{Dir: cfg.pathOfMigrations}, + } + +} + +func (m *Migrator) getDsn() string { + + return fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=disable", + m.cfg.dbConfig.Host, m.cfg.dbConfig.User, m.cfg.dbConfig.Password, m.cfg.dbConfig.DbName) + +} + +func (m *Migrator) Up() { + db, err := sql.Open(m.dialect, m.getDsn()) + + if err != nil { + panic(fmt.Errorf("can't open postgres db: %v", err)) + + } + + defer db.Close() + if err != nil { + return + } + + n, err := migrate.Exec(db, m.dialect, m.migrations, migrate.Up) + if err != nil { + panic(fmt.Errorf("cant apply migrations : %v ", err)) + + } + fmt.Printf("Applied %d migrations\n", n) + +} + +func (m *Migrator) Down() { + migrate.SetTable(m.cfg.migrationDBName) + + db, err := sql.Open(m.dialect, m.getDsn()) + + if err != nil { + panic(fmt.Errorf("can't open postgres db: %v", err)) + + } + defer db.Close() + + n, err := migrate.Exec(db, m.dialect, m.migrations, migrate.Down) + + if err != nil { + panic(fmt.Errorf("cant rollback migrations : %v ", err)) + + } + fmt.Printf("Applied %d migrations\n", n) + +} + +func (m *Migrator) Status() { + migrate.SetTable(m.cfg.migrationDBName) + + db, err := sql.Open(m.dialect, m.getDsn()) + if err != nil { + panic(fmt.Errorf("can't open postgres db: %v", err)) + } + defer db.Close() + + migrations, _, err := migrate.PlanMigration(db, m.dialect, m.migrations, migrate.Up, 0) + if err != nil { + panic(fmt.Errorf("can't plan migrations: %v", err)) + } + + if len(migrations) == 0 { + fmt.Println("✅ No pending migrations.") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintln(w, "PENDING MIGRATIONS") + fmt.Fprintln(w, "ID\tSTATUS") + for _, migration := range migrations { + fmt.Fprintf(w, "%s\tPending\n", migration.Id) + } + w.Flush() + +} diff --git a/pkg/database/postgres/pagination.go b/pkg/database/postgres/pagination.go new file mode 100644 index 00000000..a79ab45d --- /dev/null +++ b/pkg/database/postgres/pagination.go @@ -0,0 +1,35 @@ +package postgres + +import "time" + +type RequestPagination struct { + PageNumber int64 `json:"page_number"` + LastTimeStamp time.Time `json:"last_time_stamp"` +} + +type ResponsePagination struct { + PageNumber int64 `json:"page_number"` + PageSize int64 `json:"page_size"` + ShowableNextPagesNum int64 `json:"showable_next_pages_num"` +} + +type DBPagination struct { + LastTimeStamp time.Time + PageNumber int64 + MaxNextPages int64 + PageSize int64 +} + +func ComputeNextPages(listLen int64, pageSize int64, maxNextPages int64) int64 { + + pages := float64(listLen) / float64(pageSize) + + for i := maxNextPages - 1; i >= 0; i-- { + if pages > float64(i) { + return i + 1 + } + } + + return 0 + +} diff --git a/pkg/database/postgres/prepared_statement.go b/pkg/database/postgres/prepared_statement.go index cebc2703..5b5382f8 100644 --- a/pkg/database/postgres/prepared_statement.go +++ b/pkg/database/postgres/prepared_statement.go @@ -3,7 +3,8 @@ package postgres type statementKey uint const ( - StatementKeyAWalletGetTransactionHistory statementKey = iota + 1 - StatementKeyWalletInsertTransaction - StatementKeyWalletGetUserWallet + StatementKeyAWalletGetTransactionHistory statementKey = iota + 1 //wallet + StatementKeyWalletInsertTransaction //wallet + StatementKeyWalletGetUserWallet //wallet + StatementKeyWalletUpsertBalance //wallet ) diff --git a/pkg/database/postgres/scanner.go b/pkg/database/postgres/scanner.go new file mode 100644 index 00000000..12fed5a1 --- /dev/null +++ b/pkg/database/postgres/scanner.go @@ -0,0 +1,5 @@ +package postgres + +type Scanner interface { + Scan(dest ...any) error +} diff --git a/pkg/database/postgres/transaction_handler.go b/pkg/database/postgres/transaction_handler.go new file mode 100644 index 00000000..35fae0fb --- /dev/null +++ b/pkg/database/postgres/transaction_handler.go @@ -0,0 +1,178 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "sync" + + "github.com/lib/pq" +) + +type contextKey string + +const DBTxHolderContextKey contextKey = "txholder" + +type conn interface { + Commit() error + Rollback() error + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) + StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row + StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) +} + +type TxConn struct { + *sql.Tx +} + +type DBTxHolder struct { + conn conn + mu sync.Mutex +} + +func (db *DBTxHolder) Conn() (conn, error) { + if db.conn == nil { + return nil, fmt.Errorf("Conn() called before BeginTx()") + } + return db.conn, nil +} + +func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) { + db, ok := ctx.Value(DBTxHolderContextKey).(*DBTxHolder) + if !ok { + db = &DBTxHolder{ + conn: nil, + } + ctx = context.WithValue(ctx, DBTxHolderContextKey, db) + } + + return db, ctx +} + +func (db *DBTxHolder) BeginTx(ctx context.Context, conn *sql.DB) (conn, error) { + dbTransactionConn, err := db.Continue(ctx, conn) + if err != nil { + return nil, err + } + return dbTransactionConn, nil +} + +func (db *DBTxHolder) Continue(ctx context.Context, conn *sql.DB) (conn, error) { + //db.mu.Lock() + //defer db.mu.Unlock() + + var err error + if db.conn == nil { + + db.conn, err = db.txFactory(ctx, conn) + if err != nil { + return nil, err + } + + } + + return db.conn, err +} + +func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) { + + tx, bErr := rConn.BeginTx(ctx, nil) + if bErr != nil { + + return nil, bErr + } + + return &TxConn{Tx: tx}, nil +} + +func (db *DBTxHolder) Commit() error { + //db.mu.Lock() + //defer db.mu.Unlock() + if db.conn == nil { + return fmt.Errorf("no active transaction") + } + err := db.conn.Commit() + db.conn = nil + return err +} + +func (db *DBTxHolder) Rollback() error { + //db.mu.Lock() + //defer db.mu.Unlock() + if db.conn == nil { + return fmt.Errorf("no active transaction") + } + err := db.conn.Rollback() + db.conn = nil + + return err +} + +func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error { + //db.mu.Lock() + //defer db.mu.Unlock() + if db.conn == nil { + return fmt.Errorf("no active transaction") + } + quoted := pq.QuoteIdentifier(key) + _, err := db.conn.ExecContext(ctx, "SAVEPOINT "+quoted) + return err +} + +func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error { + //db.mu.Lock() + //defer db.mu.Unlock() + if db.conn == nil { + return fmt.Errorf("no active transaction") + } + + quoted := pq.QuoteIdentifier(key) + + _, err := db.conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+quoted) + + if err != nil { + return err + + } + + return nil + +} + +func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error { + //db.mu.Lock() + //defer db.mu.Unlock() + if db.conn == nil { + return fmt.Errorf("no active transaction") + } + quoted := pq.QuoteIdentifier(key) + _, err := db.conn.ExecContext(ctx, "RELEASE SAVEPOINT "+quoted) + + if err != nil { + return err + } + return nil +} + +func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) { + txStmt := tx.StmtContext(ctx, stmt) + return txStmt.QueryContext(ctx, args...) +} + +func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row { + txStmt := tx.StmtContext(ctx, stmt) + return txStmt.QueryRowContext(ctx, args...) +} + +func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) { + txStmt := tx.StmtContext(ctx, stmt) + result, err := txStmt.ExecContext(ctx, args...) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/pkg/err_msg/message.go b/pkg/err_msg/message.go index c6c4fb81..e12a018c 100644 --- a/pkg/err_msg/message.go +++ b/pkg/err_msg/message.go @@ -58,4 +58,9 @@ const ( ErrorMsgInvalidRefreshToken = "invalid refresh token" ErrorMsgInvalidBenefactorStatus = "invalid benefactor status" ErrorMsgInvalidAction = "action invalid" + ErrorMsgCantUpsertBalance = "cant update balance" // wallet + ErrorMsgCantGetBalance = "cant update balance" // wallet + ErrorMsgCantInsertTransaction = "cant insert transaction" // wallet + ErrorMsgFailedQuery = "query failed" // wallet + ) diff --git a/vendor/github.com/jackc/pgx/v5/CLAUDE.md b/vendor/github.com/jackc/pgx/v5/CLAUDE.md new file mode 100644 index 00000000..e3ed1a2e --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/CLAUDE.md @@ -0,0 +1,73 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +pgx is a PostgreSQL driver and toolkit for Go (`github.com/jackc/pgx/v5`). It provides both a native PostgreSQL interface and a `database/sql` compatible driver. Requires Go 1.25+ and supports PostgreSQL 14+ and CockroachDB. + +## Build & Test Commands + +```bash +# Run all tests (requires PGX_TEST_DATABASE to be set) +go test ./... + +# Run a specific test +go test -run TestFunctionName ./... + +# Run tests for a specific package +go test ./pgconn/... + +# Run tests with race detector +go test -race ./... + +# DevContainer: run tests against specific PostgreSQL versions +./test.sh pg18 # Default: PostgreSQL 18 +./test.sh pg16 -run TestConnect # Specific test against PG16 +./test.sh crdb # CockroachDB +./test.sh all # All targets (pg14-18 + crdb) + +# Format (always run after making changes) +goimports -w . + +# Lint +golangci-lint run ./... +``` + +## Test Database Setup + +Tests require `PGX_TEST_DATABASE` environment variable. In the devcontainer, `test.sh` handles this. For local development: + +```bash +export PGX_TEST_DATABASE="host=localhost user=postgres password=postgres dbname=pgx_test" +``` + +The test database needs extensions: `hstore`, `ltree`, and a `uint64` domain. See `testsetup/postgresql_setup.sql` for full setup. Many tests are skipped unless additional `PGX_TEST_*` env vars are set (for TLS, SCRAM, MD5, unix socket, PgBouncer testing). + +## Architecture + +The codebase is a layered architecture, bottom-up: + +- **pgproto3/** — PostgreSQL wire protocol v3 encoder/decoder. Defines `FrontendMessage` and `BackendMessage` types for every protocol message. +- **pgconn/** — Low-level connection layer (roughly libpq-equivalent). Handles authentication, TLS, query execution, COPY protocol, and notifications. `PgConn` is the core type. +- **pgx** (root package) — High-level query interface built on `pgconn`. Provides `Conn`, `Rows`, `Tx`, `Batch`, `CopyFrom`, and generic helpers like `CollectRows`/`ForEachRow`. Includes automatic statement caching (LRU). +- **pgtype/** — Type system mapping between Go and PostgreSQL types (70+ types). Key interfaces: `Codec`, `Type`, `TypeMap`. Custom types (enums, composites, domains) are registered through `TypeMap`. +- **pgxpool/** — Concurrency-safe connection pool built on `puddle/v2`. `Pool` is the main type; wraps `pgx.Conn`. +- **stdlib/** — `database/sql` compatibility adapter. + +Supporting packages: +- **internal/stmtcache/** — Prepared statement cache with LRU eviction +- **internal/sanitize/** — SQL query sanitization +- **tracelog/** — Logging adapter that implements tracer interfaces +- **multitracer/** — Composes multiple tracers into one +- **pgxtest/** — Test helpers for running tests across connection types + +## Key Design Conventions + +- **Semantic versioning** — strictly followed. Do not break the public API (no removing or renaming exported types, functions, methods, or fields; no changing function signatures). +- **Minimal dependencies** — adding new dependencies is strongly discouraged (see CONTRIBUTING.md). +- **Context-based** — all blocking operations take `context.Context`. +- **Tracer interfaces** — observability via `QueryTracer`, `BatchTracer`, `CopyFromTracer`, `PrepareTracer` on `ConnConfig.Tracer`. +- **Formatting** — always run `goimports -w .` after making changes to ensure code is properly formatted. CI checks formatting via `gofmt -l -s -w . && git diff --exit-code`. `gofumpt` with extra rules is also enforced via `golangci-lint`. +- **Linters** — `govet` and `ineffassign` only (configured in `.golangci.yml`). +- **CI matrix** — tests run against Go 1.25/1.26 × PostgreSQL 14-18 + CockroachDB, on Linux and Windows. Race detector enabled on Linux only. diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh new file mode 100644 index 00000000..b4ee3fe7 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmark.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 ... " + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commit message + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. -count=10 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +# go install golang.org/x/perf/cmd/benchstat[@latest] +benchstat "${bench_files[@]}" diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go new file mode 100644 index 00000000..991f6585 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go @@ -0,0 +1,67 @@ +package pgconn + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgproto3" +) + +func (c *PgConn) oauthAuth(ctx context.Context) error { + if c.config.OAuthTokenProvider == nil { + return errors.New("OAuth authentication required but no token provider configured") + } + + token, err := c.config.OAuthTokenProvider(ctx) + if err != nil { + return fmt.Errorf("failed to obtain OAuth token: %w", err) + } + + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1 + initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01") + + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "OAUTHBEARER", + Data: initialResponse, + } + c.frontend.Send(saslInitialResponse) + err = c.flushWithPotentialWriteReadDeadlock() + if err != nil { + return err + } + + msg, err := c.receiveMessage() + if err != nil { + return err + } + + switch m := msg.(type) { + case *pgproto3.AuthenticationOk: + return nil + case *pgproto3.AuthenticationSASLContinue: + // Server sent error response in SASL continue + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2 + // https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3 + errResponse := struct { + Status string `json:"status"` + Scope string `json:"scope"` + OpenIDConfiguration string `json:"openid-configuration"` + }{} + err := json.Unmarshal(m.Data, &errResponse) + if err != nil { + return fmt.Errorf("invalid OAuth error response from server: %w", err) + } + + // Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01. + // However, since the connection will be closed anyway, we can skip this + return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status) + + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(m) + + default: + return fmt.Errorf("unexpected message type during OAuth auth: %T", msg) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go b/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go new file mode 100644 index 00000000..43bd7ec6 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go @@ -0,0 +1,93 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type NegotiateProtocolVersion struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NegotiateProtocolVersion) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NegotiateProtocolVersion) Decode(src []byte) error { + if len(src) < 8 { + return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)} + } + + dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4]) + optionCount := int(binary.BigEndian.Uint32(src[4:8])) + + rp := 8 + + // Use the remaining message size as an upper bound for capacity to prevent + // malicious optionCount values from causing excessive memory allocation. + capHint := optionCount + if remaining := len(src) - rp; capHint > remaining { + capHint = remaining + } + dst.UnrecognizedOptions = make([]string, 0, capHint) + for i := 0; i < optionCount; i++ { + if rp >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + end := rp + for end < len(src) && src[end] != 0 { + end++ + } + if end >= len(src) { + return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"} + } + dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end])) + rp = end + 1 + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'v') + dst = pgio.AppendUint32(dst, src.NewestMinorProtocol) + dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions))) + for _, option := range src.UnrecognizedOptions { + dst = append(dst, option...) + dst = append(dst, 0) + } + return finishMessage(dst, sp) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + NewestMinorProtocol uint32 + UnrecognizedOptions []string + }{ + Type: "NegotiateProtocolVersion", + NewestMinorProtocol: src.NewestMinorProtocol, + UnrecognizedOptions: src.UnrecognizedOptions, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error { + var msg struct { + NewestMinorProtocol uint32 + UnrecognizedOptions []string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.NewestMinorProtocol = msg.NewestMinorProtocol + dst.UnrecognizedOptions = msg.UnrecognizedOptions + return nil +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go b/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go new file mode 100644 index 00000000..b357948a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go @@ -0,0 +1,507 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TSVectorScanner interface { + ScanTSVector(TSVector) error +} + +type TSVectorValuer interface { + TSVectorValue() (TSVector, error) +} + +// TSVector represents a PostgreSQL tsvector value. +type TSVector struct { + Lexemes []TSVectorLexeme + Valid bool +} + +// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions. +type TSVectorLexeme struct { + Word string + Positions []TSVectorPosition +} + +// ScanTSVector implements the [TSVectorScanner] interface. +func (t *TSVector) ScanTSVector(v TSVector) error { + *t = v + return nil +} + +// TSVectorValue implements the [TSVectorValuer] interface. +func (t TSVector) TSVectorValue() (TSVector, error) { + return t, nil +} + +func (t TSVector) String() string { + buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil) + return string(buf) +} + +// Scan implements the [database/sql.Scanner] interface. +func (t *TSVector) Scan(src any) error { + if src == nil { + *t = TSVector{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTSVectorScanner{}.scanString(src, t) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (t TSVector) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} + +// TSVectorWeight represents the weight label of a lexeme position in a tsvector. +type TSVectorWeight byte + +const ( + TSVectorWeightA = TSVectorWeight('A') + TSVectorWeightB = TSVectorWeight('B') + TSVectorWeightC = TSVectorWeight('C') + TSVectorWeightD = TSVectorWeight('D') +) + +// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL. +func tsvectorWeightToBinary(w TSVectorWeight) uint16 { + switch w { + case TSVectorWeightA: + return 3 + case TSVectorWeightB: + return 2 + case TSVectorWeightC: + return 1 + default: + return 0 // D or unset + } +} + +// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight. +func tsvectorWeightFromBinary(b uint16) TSVectorWeight { + switch b { + case 3: + return TSVectorWeightA + case 2: + return TSVectorWeightB + case 1: + return TSVectorWeightC + default: + return TSVectorWeightD + } +} + +// TSVectorPosition represents a lexeme position and its optional weight within a tsvector. +type TSVectorPosition struct { + Position uint16 + Weight TSVectorWeight +} + +func (p TSVectorPosition) String() string { + s := strconv.FormatUint(uint64(p.Position), 10) + if p.Weight != 0 && p.Weight != TSVectorWeightD { + s += string(p.Weight) + } + return s +} + +type TSVectorCodec struct{} + +func (TSVectorCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TSVectorCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TSVectorValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTSVectorCodecBinary{} + case TextFormatCode: + return encodePlanTSVectorCodecText{} + } + + return nil +} + +type encodePlanTSVectorCodecBinary struct{} + +func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes))) + + for _, entry := range tsv.Lexemes { + buf = append(buf, entry.Word...) + buf = append(buf, 0x00) + buf = pgio.AppendUint16(buf, uint16(len(entry.Positions))) + + // Each position is a uint16: weight (2 bits) | position (14 bits) + for _, pos := range entry.Positions { + packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF + buf = pgio.AppendUint16(buf, packed) + } + } + + return buf, nil +} + +type scanPlanBinaryTSVectorToTSVectorScanner struct{} + +func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + rp := 0 + + const ( + uint16Len = 2 + uint32Len = 4 + ) + + if len(src[rp:]) < uint32Len { + return fmt.Errorf("tsvector incomplete %v", src) + } + entryCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += uint32Len + + var tsv TSVector + if entryCount > 0 { + tsv.Lexemes = make([]TSVectorLexeme, entryCount) + } + + for i := range entryCount { + nullIndex := bytes.IndexByte(src[rp:], 0x00) + if nullIndex == -1 { + return fmt.Errorf("invalid tsvector binary format: missing null terminator") + } + + lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])} + rp += nullIndex + 1 // skip past null terminator + + // Read position count. + if len(src[rp:]) < uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete position count") + } + + numPositions := int(binary.BigEndian.Uint16(src[rp:])) + rp += uint16Len + + // Read each packed position: weight (2 bits) | position (14 bits) + if len(src[rp:]) < numPositions*uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete positions") + } + + if numPositions > 0 { + lexeme.Positions = make([]TSVectorPosition, numPositions) + for pos := range numPositions { + packed := binary.BigEndian.Uint16(src[rp:]) + rp += uint16Len + lexeme.Positions[pos] = TSVectorPosition{ + Position: packed & 0x3FFF, + Weight: tsvectorWeightFromBinary(packed >> 14), + } + } + } + + tsv.Lexemes[i] = lexeme + } + tsv.Valid = true + + return scanner.ScanTSVector(tsv) +} + +var tsvectorLexemeReplacer = strings.NewReplacer( + `\`, `\\`, + `'`, `\'`, +) + +type encodePlanTSVectorCodecText struct{} + +func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + if buf == nil { + buf = []byte{} + } + + for i, lex := range tsv.Lexemes { + if i > 0 { + buf = append(buf, ' ') + } + + buf = append(buf, '\'') + buf = append(buf, tsvectorLexemeReplacer.Replace(lex.Word)...) + buf = append(buf, '\'') + + sep := byte(':') + for _, p := range lex.Positions { + buf = append(buf, sep) + buf = append(buf, p.String()...) + sep = ',' + } + } + + return buf, nil +} + +func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanBinaryTSVectorToTSVectorScanner{} + } + case TextFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanTextAnyToTSVectorScanner{} + } + } + + return nil +} + +type scanPlanTextAnyToTSVectorScanner struct{} + +func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + return s.scanString(string(src), scanner) +} + +func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error { + tsv, err := parseTSVector(src) + if err != nil { + return err + } + return scanner.ScanTSVector(tsv) +} + +func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tsv TSVector + err := codecScan(c, m, oid, format, src, &tsv) + if err != nil { + return nil, err + } + return tsv, nil +} + +type tsvectorParser struct { + str string + pos int +} + +func (p *tsvectorParser) atEnd() bool { + return p.pos >= len(p.str) +} + +func (p *tsvectorParser) peek() byte { + return p.str[p.pos] +} + +func (p *tsvectorParser) consume() (byte, bool) { + if p.pos >= len(p.str) { + return 0, true + } + b := p.str[p.pos] + p.pos++ + return b, false +} + +func (p *tsvectorParser) consumeSpaces() { + for !p.atEnd() && p.peek() == ' ' { + p.consume() + } +} + +// consumeLexeme consumes a single-quoted lexeme, handling single quotes and backslash escapes. +func (p *tsvectorParser) consumeLexeme() (string, error) { + ch, end := p.consume() + if end || ch != '\'' { + return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote") + } + + var buf strings.Builder + for { + ch, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme") + } + + switch ch { + case '\'': + // Escaped quote ('') — write a literal single quote + if !p.atEnd() && p.peek() == '\'' { + p.consume() + buf.WriteByte('\'') + } else { + // Closing quote — lexeme is complete + return buf.String(), nil + } + case '\\': + next, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash") + } + buf.WriteByte(next) + default: + buf.WriteByte(ch) + } + } +} + +// consumePositions consumes a comma-separated list of position[weight] values. +func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) { + var positions []TSVectorPosition + + for { + pos, err := p.consumePosition() + if err != nil { + return nil, err + } + positions = append(positions, pos) + + if p.atEnd() || p.peek() != ',' { + break + } + + p.consume() // skip ',' + } + + return positions, nil +} + +// consumePosition consumes a single position number with optional weight letter. +func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) { + start := p.pos + + for !p.atEnd() && p.peek() >= '0' && p.peek() <= '9' { + p.consume() + } + + if p.pos == start { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number") + } + + num, err := strconv.ParseUint(p.str[start:p.pos], 10, 16) + if err != nil { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", p.str[start:p.pos]) + } + + pos := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD} + + // Check for optional weight letter + if !p.atEnd() { + switch p.peek() { + case 'A', 'a': + pos.Weight = TSVectorWeightA + case 'B', 'b': + pos.Weight = TSVectorWeightB + case 'C', 'c': + pos.Weight = TSVectorWeightC + case 'D', 'd': + pos.Weight = TSVectorWeightD + default: + return pos, nil + } + p.consume() + } + + return pos, nil +} + +// parseTSVector parses a PostgreSQL tsvector text representation. +func parseTSVector(s string) (TSVector, error) { + result := TSVector{} + p := &tsvectorParser{str: strings.TrimSpace(s), pos: 0} + + for !p.atEnd() { + p.consumeSpaces() + if p.atEnd() { + break + } + + word, err := p.consumeLexeme() + if err != nil { + return TSVector{}, err + } + + entry := TSVectorLexeme{Word: word} + + // Check for optional positions after ':' + if !p.atEnd() && p.peek() == ':' { + p.consume() // skip ':' + + positions, err := p.consumePositions() + if err != nil { + return TSVector{}, err + } + entry.Positions = positions + } + + result.Lexemes = append(result.Lexemes, entry) + } + + result.Valid = true + + return result, nil +} diff --git a/vendor/github.com/jackc/pgx/v5/test.sh b/vendor/github.com/jackc/pgx/v5/test.sh new file mode 100644 index 00000000..8bab2d28 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/test.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash +set -euo pipefail + +# test.sh - Run pgx tests against specific database targets +# +# Usage: +# ./test.sh [target] [go test flags...] +# +# Targets: +# pg14 - PostgreSQL 14 (port 5414) +# pg15 - PostgreSQL 15 (port 5415) +# pg16 - PostgreSQL 16 (port 5416) +# pg17 - PostgreSQL 17 (port 5417) +# pg18 - PostgreSQL 18 (port 5432) [default] +# crdb - CockroachDB (port 26257) +# all - Run against all targets sequentially +# +# Examples: +# ./test.sh # Test against PG18 +# ./test.sh pg14 # Test against PG14 +# ./test.sh crdb # Test against CockroachDB +# ./test.sh all # Test against all targets +# ./test.sh pg16 -run TestConnect # Test specific test against PG16 +# ./test.sh pg18 -count=1 -v # Verbose, no cache, PG18 + +# Color output (disabled if not a terminal) +if [ -t 1 ]; then + GREEN='\033[0;32m' + RED='\033[0;31m' + BLUE='\033[0;34m' + NC='\033[0m' +else + GREEN='' + RED='' + BLUE='' + NC='' +fi + +log_info() { echo -e "${BLUE}==> $*${NC}"; } +log_ok() { echo -e "${GREEN}==> $*${NC}"; } +log_err() { echo -e "${RED}==> $*${NC}" >&2; } + +# Wait for a database to accept connections +wait_for_ready() { + local connstr="$1" + local label="$2" + local max_attempts=30 + local attempt=0 + + log_info "Waiting for $label to be ready..." + while ! psql "$connstr" -c "SELECT 1" > /dev/null 2>&1; do + attempt=$((attempt + 1)) + if [ "$attempt" -ge "$max_attempts" ]; then + log_err "$label did not become ready after $max_attempts attempts" + return 1 + fi + sleep 1 + done + log_ok "$label is ready" +} + +# Directory containing this script (used to locate testsetup/) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CERTS_DIR="$SCRIPT_DIR/testsetup/certs" + +# Copy client certificates to /tmp for TLS tests +setup_client_certs() { + if [ -d "$CERTS_DIR" ]; then + base64 -d "$CERTS_DIR/ca.pem.b64" > /tmp/ca.pem + base64 -d "$CERTS_DIR/pgx_sslcert.crt.b64" > /tmp/pgx_sslcert.crt + base64 -d "$CERTS_DIR/pgx_sslcert.key.b64" > /tmp/pgx_sslcert.key + fi +} + +# Initialize CockroachDB (create database if not exists) +init_crdb() { + local connstr="postgresql://root@localhost:26257/?sslmode=disable" + wait_for_ready "$connstr" "CockroachDB" + log_info "Ensuring pgx_test database exists on CockroachDB..." + psql "$connstr" -c "CREATE DATABASE IF NOT EXISTS pgx_test" 2>/dev/null || true +} + +# Run tests against a single target +run_tests() { + local target="$1" + shift + local extra_args=("$@") + + local label="" + local port="" + + case "$target" in + pg14) label="PostgreSQL 14"; port=5414 ;; + pg15) label="PostgreSQL 15"; port=5415 ;; + pg16) label="PostgreSQL 16"; port=5416 ;; + pg17) label="PostgreSQL 17"; port=5417 ;; + pg18) label="PostgreSQL 18"; port=5432 ;; + crdb) + label="CockroachDB (port 26257)" + init_crdb + log_info "Testing against $label" + if ! PGX_TEST_DATABASE="postgresql://root@localhost:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" \ + go test -count=1 "${extra_args[@]}" ./...; then + log_err "Tests FAILED against $label" + return 1 + fi + log_ok "Tests passed against $label" + return 0 + ;; + *) + log_err "Unknown target: $target" + log_err "Valid targets: pg14, pg15, pg16, pg17, pg18, crdb, all" + return 1 + ;; + esac + + setup_client_certs + + log_info "Testing against $label (port $port)" + if ! PGX_TEST_DATABASE="host=localhost port=$port user=postgres password=postgres dbname=pgx_test" \ + PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \ + PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ + PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \ + PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \ + PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \ + PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \ + PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \ + PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \ + PGX_SSL_PASSWORD=certpw \ + go test -count=1 "${extra_args[@]}" ./...; then + log_err "Tests FAILED against $label" + return 1 + fi + log_ok "Tests passed against $label" +} + +# Main +main() { + local target="${1:-pg18}" + + if [ "$target" = "all" ]; then + shift || true + local targets=(pg14 pg15 pg16 pg17 pg18 crdb) + local failed=() + + for t in "${targets[@]}"; do + echo "" + log_info "==========================================" + log_info "Target: $t" + log_info "==========================================" + if ! run_tests "$t" "$@"; then + failed+=("$t") + log_err "FAILED: $t" + fi + done + + echo "" + if [ ${#failed[@]} -gt 0 ]; then + log_err "Failed targets: ${failed[*]}" + return 1 + else + log_ok "All targets passed" + fi + else + shift || true + run_tests "$target" "$@" + fi +} + +main "$@" diff --git a/walletapp/config.go b/walletapp/config.go index a011ca94..8d23c4cc 100644 --- a/walletapp/config.go +++ b/walletapp/config.go @@ -2,14 +2,14 @@ package walletapp import ( "git.gocasts.ir/ebhomengo/niki/adapter/redis" - "git.gocasts.ir/ebhomengo/niki/domain/shoppingbasket/repository" + "git.gocasts.ir/ebhomengo/niki/domain/wallet/repository/postgres" "git.gocasts.ir/ebhomengo/niki/pkg/httpserver" logger "git.gocasts.ir/ebhomengo/niki/pkg/logger" ) type Config struct { Redis redis.Config `koanf:"redis" json:"redis"` - Repo repository.Config `koanf:"repo" json:"repo"` + Repo postgres.Config `koanf:"repo" json:"repo"` HTTPServer httpserver.Config `koanf:"http_server" json:"http_server"` Logger logger.Config `koanf:"logger" json:"logger"` }