forked from ebhomengo/niki
ADD db transaction handler & pagination & implements repository interface
This commit is contained in:
parent
3b22e99697
commit
bc24bcc686
|
|
@ -9,6 +9,7 @@ type Transaction struct {
|
|||
Currency Currency
|
||||
ActionType TransactionType
|
||||
Timestamp time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type TransactionType string
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ type Wallet struct {
|
|||
UserID uint64 // user unique ID
|
||||
Balance float64
|
||||
Currency Currency
|
||||
UpdatedAt time.Time
|
||||
Status WalletStatus // "active", "frozen", "closed"
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type WalletStatus string
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package param
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
)
|
||||
|
||||
|
|
@ -9,7 +11,9 @@ type CreateTransactionRequest struct {
|
|||
Amount float64 `json:"amount"`
|
||||
Currency entity.Currency `json:"currency"`
|
||||
ActionType entity.TransactionType `json:"action_type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
type InsertTransactionResponse struct {
|
||||
Balance float64 `json:"balance"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,14 +4,17 @@ import (
|
|||
"time"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
)
|
||||
|
||||
type TransactionRequest struct {
|
||||
UserID uint64 `json:"user_id"`
|
||||
UserID uint64 `json:"user_id"`
|
||||
Pagination postgres.RequestPagination `json:"pagination"`
|
||||
}
|
||||
|
||||
type TransactionResponse struct {
|
||||
Transaction []TransactionInfo `json:"transaction"`
|
||||
Transaction []TransactionInfo `json:"transactions"`
|
||||
Pagination postgres.ResponsePagination `json:"pagination"`
|
||||
}
|
||||
|
||||
type TransactionInfo struct {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
|
||||
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
||||
)
|
||||
|
||||
func (db *DB) InsertTransaction(ctx context.Context, transaction entity.Transaction) (balance float64, err error) {
|
||||
const op = richerror.Op("wallet.repo.InsertTransaction")
|
||||
|
||||
query := `INSERT INTO Transactions (user_id, amount ,currency ,action_type, timestamp) values ($1, $2, $3, $4, $5)`
|
||||
|
||||
stmt, stErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyWalletInsertTransaction, query)
|
||||
if stErr != nil {
|
||||
err = richerror.New(op).WithErr(stErr).WithKind(richerror.KindUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
txHolder, newCtx := postgres.GetDBTxHolderFromContextOrNew(ctx)
|
||||
tx, txErr := txHolder.BeginTx(newCtx, db.conn.Conn())
|
||||
if txErr != nil {
|
||||
err = richerror.New(op).WithErr(txErr).WithKind(richerror.KindUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
if err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
// log rbErr
|
||||
}
|
||||
} else if cErr := tx.Commit(); cErr != nil {
|
||||
// log cErr
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
params := []any{transaction.UserID, transaction.Currency, transaction.Amount, transaction.ActionType, transaction.Timestamp}
|
||||
|
||||
_, execErr := tx.StmtExecContext(newCtx, stmt, params...)
|
||||
if execErr != nil {
|
||||
err = richerror.New(op).WithErr(execErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantInsertTransaction)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
newBalance, blanceErr := db.UpsertBalance(newCtx, transaction.UserID, transaction.Amount)
|
||||
if blanceErr != nil {
|
||||
err = richerror.New(op).WithErr(blanceErr)
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
balance = newBalance
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) UpsertBalance(newCtx context.Context, userID uint64, amount float64) (balance float64, err error) {
|
||||
const op = richerror.Op("wallet.repo.UpdateBalance")
|
||||
|
||||
txHolder, _ := postgres.GetDBTxHolderFromContextOrNew(newCtx)
|
||||
tx, txErr := txHolder.Conn()
|
||||
|
||||
if txErr != nil {
|
||||
err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantUpsertBalance).WithKind(richerror.KindUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
if tx == nil {
|
||||
err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantUpsertBalance).WithKind(richerror.KindUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
upsertQuery := `INSERT INTO wallets (user_id, balance) VALUES ($1, $2)
|
||||
ON CONFLICT (user_id) DO UPDATE SET balance = wallets.balance + $2
|
||||
RETURNING balance`
|
||||
|
||||
upsertStmt, stErr := db.conn.PrepareStatement(newCtx, postgres.StatementKeyWalletUpsertBalance, upsertQuery)
|
||||
if stErr != nil {
|
||||
err = richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery).WithKind(richerror.KindUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
row := tx.StmtQueryRowContext(newCtx, upsertStmt, userID, amount)
|
||||
|
||||
sErr := row.Scan(&balance)
|
||||
if sErr != nil {
|
||||
err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantScanQueryResult).WithKind(richerror.KindUnexpected)
|
||||
return
|
||||
}
|
||||
|
||||
return balance, nil
|
||||
}
|
||||
|
|
@ -2,6 +2,9 @@ package postgres
|
|||
|
||||
import "git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
|
||||
type Config struct {
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
conn *postgres.DB
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
|
||||
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
||||
)
|
||||
|
||||
func (db *DB) GetTransactionListByUserID(ctx context.Context, userID uint64, dbPagination postgres.DBPagination) ([]entity.Transaction, int64, error) {
|
||||
const op = richerror.Op("Wallet.repo.GetTransactionListByUserID")
|
||||
|
||||
query := `SELECT * FROM Transactions WHERE user_id = $1 AND transaction_timestamp < $2 ORDER BY transaction_id DESC LIMIT $3`
|
||||
|
||||
stmt, StErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyAWalletGetTransactionHistory, query)
|
||||
|
||||
if StErr != nil {
|
||||
|
||||
return nil, 0, richerror.New(op).WithErr(StErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
|
||||
|
||||
}
|
||||
|
||||
totalRetrieveRecord := (dbPagination.MaxNextPages-1)*dbPagination.PageSize + 1
|
||||
|
||||
lastTimeStamp := dbPagination.LastTimeStamp
|
||||
|
||||
queryRows, err := stmt.QueryContext(ctx, userID, lastTimeStamp, totalRetrieveRecord)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, richerror.New(op).WithErr(StErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgFailedQuery)
|
||||
}
|
||||
|
||||
defer queryRows.Close()
|
||||
|
||||
var transactions []entity.Transaction
|
||||
|
||||
var lenCounter int64
|
||||
|
||||
for queryRows.Next() {
|
||||
lenCounter++
|
||||
|
||||
if lenCounter < dbPagination.PageSize+1 {
|
||||
transaction, err := scanTransaction(queryRows)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, richerror.New(op).WithErr(err).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
|
||||
}
|
||||
|
||||
transactions = append(transactions, transaction)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if qErr := queryRows.Err(); qErr != nil {
|
||||
|
||||
return nil, 0, richerror.New(op).WithErr(qErr).WithKind(richerror.KindUnexpected).WithMessage(errmsg.ErrorMsgCantScanQueryResult)
|
||||
}
|
||||
|
||||
return transactions, lenCounter, nil
|
||||
|
||||
}
|
||||
|
||||
func scanTransaction(scanner postgres.Scanner) (transaction entity.Transaction, err error) {
|
||||
err = scanner.Scan(&transaction.UserID, &transaction.Currency, &transaction.Amount, &transaction.ActionType, &transaction.Timestamp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
||||
)
|
||||
|
||||
func (db *DB) GetWalletByUserID(ctx context.Context, UserID uint64) (entity.Wallet, error) {
|
||||
const op = richerror.Op("Wallet.repo.GetWalletByUserID")
|
||||
query := `SELECT * FROM wallets WHERE user_id = $1`
|
||||
stmt, stErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyWalletGetUserWallet, query)
|
||||
if stErr != nil {
|
||||
return entity.Wallet{}, richerror.New(op).WithErr(stErr)
|
||||
}
|
||||
walletRow := db.conn.StmtQueryRowContext(ctx, stmt, UserID)
|
||||
|
||||
wallet, sErr := scanWallet(walletRow)
|
||||
|
||||
if sErr != nil {
|
||||
return entity.Wallet{}, richerror.New(op).WithErr(sErr)
|
||||
}
|
||||
|
||||
return wallet, nil
|
||||
|
||||
}
|
||||
|
||||
func scanWallet(scanner postgres.Scanner) (entity.Wallet, error) {
|
||||
|
||||
var wallet entity.Wallet
|
||||
|
||||
err := scanner.Scan(&wallet.ID, &wallet.Balance, &wallet.Currency, &wallet.Status, &wallet.UpdatedAt)
|
||||
if err != nil {
|
||||
return entity.Wallet{}, err
|
||||
}
|
||||
|
||||
return wallet, nil
|
||||
}
|
||||
|
|
@ -19,11 +19,12 @@ func (s Service) CreateTransaction(ctx context.Context, request param.CreateTran
|
|||
Amount: request.Amount,
|
||||
Currency: request.Currency,
|
||||
ActionType: request.ActionType,
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: request.Timestamp,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
err := s.repo.InsertTransaction(ctx, transaction)
|
||||
if err != nil {
|
||||
return param.InsertTransactionResponse{}, err
|
||||
return param.InsertTransactionResponse{}, richerror.New(op).WithErr(err)
|
||||
}
|
||||
|
||||
return param.InsertTransactionResponse{}, nil
|
||||
|
|
|
|||
|
|
@ -4,15 +4,18 @@ import (
|
|||
"context"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
GetTransactionListByUserID(ctx context.Context, UserID uint64) ([]entity.Transaction, error)
|
||||
GetTransactionListByUserID(ctx context.Context, UserID uint64, DBPagination postgres.DBPagination) ([]entity.Transaction, int64, error)
|
||||
GetWalletByUserID(ctx context.Context, UserID uint64) (entity.Wallet, error)
|
||||
InsertTransaction(ctx context.Context, transaction entity.Transaction) error
|
||||
}
|
||||
|
||||
// Config holds pagination tuning for the wallet service: how many rows make
// up one page and how many "next pages" may be advertised to the client.
type Config struct {
	PageSize     int64 `koanf:"page_size"`
	MaxNextPages int64 `koanf:"max_next_pages"`
}
|
||||
|
||||
type Service struct {
|
||||
|
|
|
|||
|
|
@ -2,22 +2,46 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/param"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
||||
)
|
||||
|
||||
func (s Service) GetUserTransactionHistory(ctx context.Context, request param.TransactionRequest) (param.TransactionResponse, error) {
|
||||
const op = richerror.Op("wallet.service.GetUserTransactionHistory")
|
||||
|
||||
transactionList, err := s.repo.GetTransactionListByUserID(ctx, request.UserID)
|
||||
|
||||
if err != nil {
|
||||
return param.TransactionResponse{}, err
|
||||
lastTimeStamp := request.Pagination.LastTimeStamp
|
||||
if lastTimeStamp.IsZero() {
|
||||
lastTimeStamp = time.Now()
|
||||
}
|
||||
|
||||
return param.TransactionResponse{Transaction: transactionEntityToTransactionInfo(transactionList)}, nil
|
||||
dbPagination := postgres.DBPagination{
|
||||
LastTimeStamp: lastTimeStamp,
|
||||
PageNumber: request.Pagination.PageNumber,
|
||||
MaxNextPages: s.cfg.MaxNextPages,
|
||||
PageSize: s.cfg.PageSize,
|
||||
}
|
||||
|
||||
transactionList, listLen, err := s.repo.GetTransactionListByUserID(ctx, request.UserID, dbPagination)
|
||||
|
||||
showableNextPagesNum := postgres.ComputeNextPages(listLen, s.cfg.PageSize, s.cfg.MaxNextPages)
|
||||
|
||||
if err != nil {
|
||||
return param.TransactionResponse{}, richerror.New(op).WithErr(err)
|
||||
}
|
||||
|
||||
paginationInfo := postgres.ResponsePagination{
|
||||
PageNumber: request.Pagination.PageNumber,
|
||||
PageSize: s.cfg.PageSize,
|
||||
ShowableNextPagesNum: showableNextPagesNum,
|
||||
}
|
||||
|
||||
return param.TransactionResponse{
|
||||
Transaction: transactionEntityToTransactionInfo(transactionList),
|
||||
Pagination: paginationInfo,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ func (s Service) GetUserWallet(ctx context.Context, request param.WalletRequest)
|
|||
wallet, err := s.repo.GetWalletByUserID(ctx, request.UserID)
|
||||
|
||||
if err != nil {
|
||||
return param.WalletResponse{}, err
|
||||
return param.WalletResponse{}, richerror.New(op).WithErr(err)
|
||||
}
|
||||
|
||||
return param.WalletResponse{
|
||||
|
|
|
|||
|
|
@ -7,31 +7,30 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
querier "git.gocasts.ir/ebhomengo/niki/pkg/query_transaction/sql"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
|
||||
// Config holds the PostgreSQL connection settings loaded via koanf.
// (Diff residue removed: the old field block including PathOfMigrations —
// which moved to the migrator package — has been dropped.)
type Config struct {
	Host            string `koanf:"host"`
	Port            int    `koanf:"port"`
	User            string `koanf:"user"`
	Password        string `koanf:"password"`
	DbName          string `koanf:"dbName"`
	SSLMode         string `koanf:"sslMode"`
	MaxIdleConn     int    `koanf:"maxIdleConns"`
	MaxOpenConn     int    `koanf:"maxOpenConns"`
	ConnMaxLifetime int    `koanf:"connMaxLifetime"`
}
|
||||
|
||||
type DB struct {
|
||||
config Config
|
||||
db *querier.SQLDB
|
||||
db *sql.DB
|
||||
mu sync.Mutex
|
||||
statements map[statementKey]*sql.Stmt
|
||||
}
|
||||
|
||||
func (db *DB) Conn() *querier.SQLDB {
|
||||
func (db *DB) Conn() *sql.DB {
|
||||
|
||||
return db.db
|
||||
}
|
||||
|
||||
|
|
@ -56,7 +55,7 @@ func New(config Config) *DB {
|
|||
|
||||
return &DB{
|
||||
config: config,
|
||||
db: &querier.SQLDB{DB: db},
|
||||
db: db,
|
||||
statements: make(map[statementKey]*sql.Stmt),
|
||||
}
|
||||
}
|
||||
|
|
@ -93,5 +92,22 @@ func (db *DB) CloseStatements() error {
|
|||
}
|
||||
|
||||
func (db *DB) Close() error {
|
||||
return db.db.DB.Close()
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
// StmtQueryContext runs the prepared statement as a multi-row query on the
// pool (not inside any transaction).
func (db *DB) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
	return stmt.QueryContext(ctx, args...)
}

// StmtQueryRowContext runs the prepared statement as a single-row query on
// the pool (not inside any transaction).
func (db *DB) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
	return stmt.QueryRowContext(ctx, args...)
}

// StmtExecContext executes the prepared statement on the pool (not inside
// any transaction) and returns the driver's result.
func (db *DB) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
	result, err := stmt.ExecContext(ctx, args...)
	if err != nil {
		return nil, err
	}

	return result, nil
}
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@ production:
|
|||
dialect: postgres
|
||||
datasource: "host=127.0.0.1 port=5432 user=wallet password=wallet2123 dbname=wallet_db sslmode=disable"
|
||||
dir: domain/wallet/repository/postgres/migrations
|
||||
table: wallet_migrationsns
|
||||
table: gorp_migrations
|
||||
|
|
@ -1 +1,113 @@
|
|||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
migrate "github.com/rubenv/sql-migrate"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
migrationDBName string `koanf:"migration_db_name"`
|
||||
pathOfMigrations string `koanf:"pathOfMigrations"`
|
||||
dbConfig postgres.Config
|
||||
}
|
||||
|
||||
type Migrator struct {
|
||||
cfg Config
|
||||
dialect string
|
||||
migrations *migrate.FileMigrationSource
|
||||
}
|
||||
|
||||
// New builds a Migrator over the file migrations found at
// cfg.pathOfMigrations.
func New(cfg Config) *Migrator {

	return &Migrator{
		cfg: cfg,
		// NOTE(review): "psx" is used both as the database/sql driver name
		// and as the sql-migrate dialect, but it matches neither the pgx
		// stdlib driver ("pgx") nor a sql-migrate dialect ("postgres") —
		// confirm something registers "psx", otherwise sql.Open and
		// migrate.Exec will fail at runtime.
		dialect:    "psx",
		migrations: &migrate.FileMigrationSource{Dir: cfg.pathOfMigrations},
	}

}
|
||||
|
||||
func (m *Migrator) getDsn() string {
|
||||
|
||||
return fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=disable",
|
||||
m.cfg.dbConfig.Host, m.cfg.dbConfig.User, m.cfg.dbConfig.Password, m.cfg.dbConfig.DbName)
|
||||
|
||||
}
|
||||
|
||||
func (m *Migrator) Up() {
|
||||
db, err := sql.Open(m.dialect, m.getDsn())
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("can't open postgres db: %v", err))
|
||||
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := migrate.Exec(db, m.dialect, m.migrations, migrate.Up)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("cant apply migrations : %v ", err))
|
||||
|
||||
}
|
||||
fmt.Printf("Applied %d migrations\n", n)
|
||||
|
||||
}
|
||||
|
||||
func (m *Migrator) Down() {
|
||||
migrate.SetTable(m.cfg.migrationDBName)
|
||||
|
||||
db, err := sql.Open(m.dialect, m.getDsn())
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("can't open postgres db: %v", err))
|
||||
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
n, err := migrate.Exec(db, m.dialect, m.migrations, migrate.Down)
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("cant rollback migrations : %v ", err))
|
||||
|
||||
}
|
||||
fmt.Printf("Applied %d migrations\n", n)
|
||||
|
||||
}
|
||||
|
||||
// Status prints the migrations that are still pending, as a small table on
// stdout, or a confirmation line when there is nothing to apply. Panics on
// connection or planning failure.
func (m *Migrator) Status() {
	migrate.SetTable(m.cfg.migrationDBName)

	db, err := sql.Open(m.dialect, m.getDsn())
	if err != nil {
		panic(fmt.Errorf("can't open postgres db: %v", err))
	}
	defer db.Close()

	// Plan (but do not run) all remaining Up migrations; limit 0 = no limit.
	migrations, _, err := migrate.PlanMigration(db, m.dialect, m.migrations, migrate.Up, 0)
	if err != nil {
		panic(fmt.Errorf("can't plan migrations: %v", err))
	}

	if len(migrations) == 0 {
		fmt.Println("✅ No pending migrations.")
		return
	}

	// Render the pending migrations as an aligned two-column table.
	w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
	fmt.Fprintln(w, "PENDING MIGRATIONS")
	fmt.Fprintln(w, "ID\tSTATUS")
	for _, migration := range migrations {
		fmt.Fprintf(w, "%s\tPending\n", migration.Id)
	}
	w.Flush()

}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
package postgres
|
||||
|
||||
import "time"
|
||||
|
||||
// RequestPagination is the client-supplied cursor: which page number it is on
// and the timestamp of the oldest row it has already seen (zero means "start
// from now").
type RequestPagination struct {
	PageNumber    int64     `json:"page_number"`
	LastTimeStamp time.Time `json:"last_time_stamp"`
}

// ResponsePagination echoes the page position back to the client along with
// how many further pages can be offered.
type ResponsePagination struct {
	PageNumber           int64 `json:"page_number"`
	PageSize             int64 `json:"page_size"`
	ShowableNextPagesNum int64 `json:"showable_next_pages_num"`
}

// DBPagination is the repository-facing pagination parameter set: the
// timestamp cursor plus the service-configured page size and look-ahead.
type DBPagination struct {
	LastTimeStamp time.Time
	PageNumber    int64
	MaxNextPages  int64
	PageSize      int64
}
|
||||
|
||||
// ComputeNextPages reports how many further pages (capped at maxNextPages)
// can be offered to the client when listLen rows were fetched with the given
// pageSize — i.e. min(maxNextPages, ceil(listLen/pageSize)), and 0 when no
// rows were fetched.
func ComputeNextPages(listLen int64, pageSize int64, maxNextPages int64) int64 {
	ratio := float64(listLen) / float64(pageSize)

	// Walk down from the cap and return the first page count the fetched
	// rows can actually fill into.
	for next := maxNextPages; next > 0; next-- {
		if ratio > float64(next-1) {
			return next
		}
	}

	return 0
}
|
||||
|
|
@ -3,7 +3,8 @@ package postgres
|
|||
// statementKey identifies a cached prepared statement in DB.statements.
type statementKey uint

// Wallet-domain statement keys.
// (Diff residue removed: only the post-change constants with the //wallet
// markers are kept.)
// NOTE(review): "StatementKeyAWalletGetTransactionHistory" has a stray "A";
// renaming would touch callers, so it is flagged rather than changed here.
const (
	StatementKeyAWalletGetTransactionHistory statementKey = iota + 1 // wallet
	StatementKeyWalletInsertTransaction                              // wallet
	StatementKeyWalletGetUserWallet                                  // wallet
	StatementKeyWalletUpsertBalance                                  // wallet
)
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
package postgres

// Scanner abstracts the Scan method shared by *sql.Row and *sql.Rows so
// row-scanning helpers can accept either.
type Scanner interface {
	Scan(dest ...any) error
}
||||
|
|
@ -0,0 +1,178 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const DBTxHolderContextKey contextKey = "txholder"
|
||||
|
||||
type conn interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error)
|
||||
StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row
|
||||
StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type TxConn struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
type DBTxHolder struct {
|
||||
conn conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (db *DBTxHolder) Conn() (conn, error) {
|
||||
if db.conn == nil {
|
||||
return nil, fmt.Errorf("Conn() called before BeginTx()")
|
||||
}
|
||||
return db.conn, nil
|
||||
}
|
||||
|
||||
func GetDBTxHolderFromContextOrNew(ctx context.Context) (*DBTxHolder, context.Context) {
|
||||
db, ok := ctx.Value(DBTxHolderContextKey).(*DBTxHolder)
|
||||
if !ok {
|
||||
db = &DBTxHolder{
|
||||
conn: nil,
|
||||
}
|
||||
ctx = context.WithValue(ctx, DBTxHolderContextKey, db)
|
||||
}
|
||||
|
||||
return db, ctx
|
||||
}
|
||||
|
||||
func (db *DBTxHolder) BeginTx(ctx context.Context, conn *sql.DB) (conn, error) {
|
||||
dbTransactionConn, err := db.Continue(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbTransactionConn, nil
|
||||
}
|
||||
|
||||
// Continue returns the holder's transaction connection, lazily starting a
// new sql.Tx the first time it is called.
//
// NOTE(review): the holder owns a mutex but locking here is commented out
// elsewhere in this type, so concurrent use of one DBTxHolder is not safe —
// confirm the holder is per-request before sharing it across goroutines.
func (db *DBTxHolder) Continue(ctx context.Context, conn *sql.DB) (conn, error) {
	var err error
	if db.conn == nil {
		// First use: open the underlying transaction.
		db.conn, err = db.txFactory(ctx, conn)
		if err != nil {
			return nil, err
		}

	}

	return db.conn, err
}
|
||||
|
||||
func (db *DBTxHolder) txFactory(ctx context.Context, rConn *sql.DB) (*TxConn, error) {
|
||||
|
||||
tx, bErr := rConn.BeginTx(ctx, nil)
|
||||
if bErr != nil {
|
||||
|
||||
return nil, bErr
|
||||
}
|
||||
|
||||
return &TxConn{Tx: tx}, nil
|
||||
}
|
||||
|
||||
// Commit commits the active transaction and clears the holder, so a later
// BeginTx starts a fresh transaction. Errors when no transaction is active.
// NOTE(review): db.mu locking is commented out — not goroutine-safe as-is.
func (db *DBTxHolder) Commit() error {
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	err := db.conn.Commit()
	// Drop the finished transaction regardless of the commit result.
	db.conn = nil
	return err
}
|
||||
|
||||
// Rollback aborts the active transaction and clears the holder, so a later
// BeginTx starts a fresh transaction. Errors when no transaction is active.
// NOTE(review): db.mu locking is commented out — not goroutine-safe as-is.
func (db *DBTxHolder) Rollback() error {
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	err := db.conn.Rollback()
	// Drop the finished transaction regardless of the rollback result.
	db.conn = nil

	return err
}
|
||||
|
||||
// SavePoint creates a named savepoint inside the active transaction. The key
// is quoted as a SQL identifier (pq.QuoteIdentifier) to prevent injection.
// Errors when no transaction is active.
func (db *DBTxHolder) SavePoint(ctx context.Context, key string) error {
	if db.conn == nil {
		return fmt.Errorf("no active transaction")
	}
	quoted := pq.QuoteIdentifier(key)
	_, err := db.conn.ExecContext(ctx, "SAVEPOINT "+quoted)
	return err
}
|
||||
|
||||
func (db *DBTxHolder) RollbackSavePoint(ctx context.Context, key string) error {
|
||||
//db.mu.Lock()
|
||||
//defer db.mu.Unlock()
|
||||
if db.conn == nil {
|
||||
return fmt.Errorf("no active transaction")
|
||||
}
|
||||
|
||||
quoted := pq.QuoteIdentifier(key)
|
||||
|
||||
_, err := db.conn.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+quoted)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (db *DBTxHolder) ReleaseSavePoint(ctx context.Context, key string) error {
|
||||
//db.mu.Lock()
|
||||
//defer db.mu.Unlock()
|
||||
if db.conn == nil {
|
||||
return fmt.Errorf("no active transaction")
|
||||
}
|
||||
quoted := pq.QuoteIdentifier(key)
|
||||
_, err := db.conn.ExecContext(ctx, "RELEASE SAVEPOINT "+quoted)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StmtQueryContext re-binds the pool-prepared stmt to this transaction
// (Tx.StmtContext) and runs it as a multi-row query.
func (tx *TxConn) StmtQueryContext(ctx context.Context, stmt *sql.Stmt, args ...any) (*sql.Rows, error) {
	txStmt := tx.StmtContext(ctx, stmt)
	return txStmt.QueryContext(ctx, args...)
}

// StmtQueryRowContext re-binds the pool-prepared stmt to this transaction
// and runs it as a single-row query.
func (tx *TxConn) StmtQueryRowContext(ctx context.Context, stmt *sql.Stmt, args ...any) *sql.Row {
	txStmt := tx.StmtContext(ctx, stmt)
	return txStmt.QueryRowContext(ctx, args...)
}

// StmtExecContext re-binds the pool-prepared stmt to this transaction and
// executes it, returning the driver's result.
func (tx *TxConn) StmtExecContext(ctx context.Context, stmt *sql.Stmt, args ...any) (sql.Result, error) {
	txStmt := tx.StmtContext(ctx, stmt)
	result, err := txStmt.ExecContext(ctx, args...)
	if err != nil {
		return nil, err
	}

	return result, nil
}
|
||||
|
|
@ -58,4 +58,9 @@ const (
|
|||
ErrorMsgInvalidRefreshToken = "invalid refresh token"
|
||||
ErrorMsgInvalidBenefactorStatus = "invalid benefactor status"
|
||||
ErrorMsgInvalidAction = "action invalid"
|
||||
ErrorMsgCantUpsertBalance = "cant update balance" // wallet
|
||||
ErrorMsgCantGetBalance = "cant update balance" // wallet
|
||||
ErrorMsgCantInsertTransaction = "cant insert transaction" // wallet
|
||||
ErrorMsgFailedQuery = "query failed" // wallet
|
||||
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
pgx is a PostgreSQL driver and toolkit for Go (`github.com/jackc/pgx/v5`). It provides both a native PostgreSQL interface and a `database/sql` compatible driver. Requires Go 1.25+ and supports PostgreSQL 14+ and CockroachDB.
|
||||
|
||||
## Build & Test Commands
|
||||
|
||||
```bash
|
||||
# Run all tests (requires PGX_TEST_DATABASE to be set)
|
||||
go test ./...
|
||||
|
||||
# Run a specific test
|
||||
go test -run TestFunctionName ./...
|
||||
|
||||
# Run tests for a specific package
|
||||
go test ./pgconn/...
|
||||
|
||||
# Run tests with race detector
|
||||
go test -race ./...
|
||||
|
||||
# DevContainer: run tests against specific PostgreSQL versions
|
||||
./test.sh pg18 # Default: PostgreSQL 18
|
||||
./test.sh pg16 -run TestConnect # Specific test against PG16
|
||||
./test.sh crdb # CockroachDB
|
||||
./test.sh all # All targets (pg14-18 + crdb)
|
||||
|
||||
# Format (always run after making changes)
|
||||
goimports -w .
|
||||
|
||||
# Lint
|
||||
golangci-lint run ./...
|
||||
```
|
||||
|
||||
## Test Database Setup
|
||||
|
||||
Tests require `PGX_TEST_DATABASE` environment variable. In the devcontainer, `test.sh` handles this. For local development:
|
||||
|
||||
```bash
|
||||
export PGX_TEST_DATABASE="host=localhost user=postgres password=postgres dbname=pgx_test"
|
||||
```
|
||||
|
||||
The test database needs extensions: `hstore`, `ltree`, and a `uint64` domain. See `testsetup/postgresql_setup.sql` for full setup. Many tests are skipped unless additional `PGX_TEST_*` env vars are set (for TLS, SCRAM, MD5, unix socket, PgBouncer testing).
|
||||
|
||||
## Architecture
|
||||
|
||||
The codebase is a layered architecture, bottom-up:
|
||||
|
||||
- **pgproto3/** — PostgreSQL wire protocol v3 encoder/decoder. Defines `FrontendMessage` and `BackendMessage` types for every protocol message.
|
||||
- **pgconn/** — Low-level connection layer (roughly libpq-equivalent). Handles authentication, TLS, query execution, COPY protocol, and notifications. `PgConn` is the core type.
|
||||
- **pgx** (root package) — High-level query interface built on `pgconn`. Provides `Conn`, `Rows`, `Tx`, `Batch`, `CopyFrom`, and generic helpers like `CollectRows`/`ForEachRow`. Includes automatic statement caching (LRU).
|
||||
- **pgtype/** — Type system mapping between Go and PostgreSQL types (70+ types). Key interfaces: `Codec`, `Type`, `TypeMap`. Custom types (enums, composites, domains) are registered through `TypeMap`.
|
||||
- **pgxpool/** — Concurrency-safe connection pool built on `puddle/v2`. `Pool` is the main type; wraps `pgx.Conn`.
|
||||
- **stdlib/** — `database/sql` compatibility adapter.
|
||||
|
||||
Supporting packages:
|
||||
- **internal/stmtcache/** — Prepared statement cache with LRU eviction
|
||||
- **internal/sanitize/** — SQL query sanitization
|
||||
- **tracelog/** — Logging adapter that implements tracer interfaces
|
||||
- **multitracer/** — Composes multiple tracers into one
|
||||
- **pgxtest/** — Test helpers for running tests across connection types
|
||||
|
||||
## Key Design Conventions
|
||||
|
||||
- **Semantic versioning** — strictly followed. Do not break the public API (no removing or renaming exported types, functions, methods, or fields; no changing function signatures).
|
||||
- **Minimal dependencies** — adding new dependencies is strongly discouraged (see CONTRIBUTING.md).
|
||||
- **Context-based** — all blocking operations take `context.Context`.
|
||||
- **Tracer interfaces** — observability via `QueryTracer`, `BatchTracer`, `CopyFromTracer`, `PrepareTracer` on `ConnConfig.Tracer`.
|
||||
- **Formatting** — always run `goimports -w .` after making changes to ensure code is properly formatted. CI checks formatting via `gofmt -l -s -w . && git diff --exit-code`. `gofumpt` with extra rules is also enforced via `golangci-lint`.
|
||||
- **Linters** — `govet` and `ineffassign` only (configured in `.golangci.yml`).
|
||||
- **CI matrix** — tests run against Go 1.25/1.26 × PostgreSQL 14-18 + CockroachDB, on Linux and Windows. Race detector enabled on Linux only.
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
if [ "$current_branch" == "HEAD" ]; then
|
||||
current_branch=$(git rev-parse HEAD)
|
||||
fi
|
||||
|
||||
restore_branch() {
|
||||
echo "Restoring original branch/commit: $current_branch"
|
||||
git checkout "$current_branch"
|
||||
}
|
||||
trap restore_branch EXIT
|
||||
|
||||
# Check if there are uncommitted changes
|
||||
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure that at least one commit argument is passed
|
||||
if [ "$#" -lt 1 ]; then
|
||||
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
commits=("$@")
|
||||
benchmarks_dir=benchmarks
|
||||
|
||||
if ! mkdir -p "${benchmarks_dir}"; then
|
||||
echo "Unable to create dir for benchmarks data"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Benchmark results
|
||||
bench_files=()
|
||||
|
||||
# Run benchmark for each listed commit
|
||||
for i in "${!commits[@]}"; do
|
||||
commit="${commits[i]}"
|
||||
git checkout "$commit" || {
|
||||
echo "Failed to checkout $commit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Sanitized commit message
|
||||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||
|
||||
# Benchmark data will go there
|
||||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||
|
||||
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||
echo "Benchmarking failed for commit $commit"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
bench_files+=("$bench_file")
|
||||
done
|
||||
|
||||
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||
benchstat "${bench_files[@]}"
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
// oauthAuth performs SASL OAUTHBEARER authentication (RFC 7628) using a
// token obtained from the connection config's OAuthTokenProvider.
func (c *PgConn) oauthAuth(ctx context.Context) error {
	if c.config.OAuthTokenProvider == nil {
		return errors.New("OAuth authentication required but no token provider configured")
	}

	// Fetch a token for this connection attempt.
	token, err := c.config.OAuthTokenProvider(ctx)
	if err != nil {
		return fmt.Errorf("failed to obtain OAuth token: %w", err)
	}

	// Initial client response: GS2 header plus the bearer token, with the
	// \x01 delimiters required by the OAUTHBEARER mechanism.
	// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
	initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01")

	saslInitialResponse := &pgproto3.SASLInitialResponse{
		AuthMechanism: "OAUTHBEARER",
		Data:          initialResponse,
	}
	c.frontend.Send(saslInitialResponse)
	err = c.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		return err
	}

	msg, err := c.receiveMessage()
	if err != nil {
		return err
	}

	switch m := msg.(type) {
	case *pgproto3.AuthenticationOk:
		// Server accepted the token.
		return nil
	case *pgproto3.AuthenticationSASLContinue:
		// Server sent error response in SASL continue
		// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
		// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
		errResponse := struct {
			Status              string `json:"status"`
			Scope               string `json:"scope"`
			OpenIDConfiguration string `json:"openid-configuration"`
		}{}
		err := json.Unmarshal(m.Data, &errResponse)
		if err != nil {
			return fmt.Errorf("invalid OAuth error response from server: %w", err)
		}

		// Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01.
		// However, since the connection will be closed anyway, we can skip this
		return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status)

	case *pgproto3.ErrorResponse:
		return ErrorResponseToPgError(m)

	default:
		return fmt.Errorf("unexpected message type during OAuth auth: %T", msg)
	}
}
|
||||
93
vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go
generated
vendored
Normal file
93
vendor/github.com/jackc/pgx/v5/pgproto3/negotiate_protocol_version.go
generated
vendored
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// NegotiateProtocolVersion is sent by the backend when it does not support
// the exact protocol version or all of the protocol options requested by the
// frontend's startup message.
type NegotiateProtocolVersion struct {
	NewestMinorProtocol uint32   // newest minor protocol version the server supports
	UnrecognizedOptions []string // startup options the server did not recognize
}

// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NegotiateProtocolVersion) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NegotiateProtocolVersion) Decode(src []byte) error {
	// Fixed header: 4-byte newest minor protocol version + 4-byte option count.
	if len(src) < 8 {
		return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)}
	}

	dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4])
	optionCount := int(binary.BigEndian.Uint32(src[4:8]))

	rp := 8 // read position within src

	// Use the remaining message size as an upper bound for capacity to prevent
	// malicious optionCount values from causing excessive memory allocation.
	capHint := optionCount
	if remaining := len(src) - rp; capHint > remaining {
		capHint = remaining
	}
	dst.UnrecognizedOptions = make([]string, 0, capHint)
	for i := 0; i < optionCount; i++ {
		if rp >= len(src) {
			return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
		}
		// Each option is a null-terminated string; scan for its terminator.
		end := rp
		for end < len(src) && src[end] != 0 {
			end++
		}
		if end >= len(src) {
			// Ran off the end of the message without finding the terminator.
			return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
		}
		dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end]))
		rp = end + 1
	}

	return nil
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'v')
|
||||
dst = pgio.AppendUint32(dst, src.NewestMinorProtocol)
|
||||
dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions)))
|
||||
for _, option := range src.UnrecognizedOptions {
|
||||
dst = append(dst, option...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
NewestMinorProtocol uint32
|
||||
UnrecognizedOptions []string
|
||||
}{
|
||||
Type: "NegotiateProtocolVersion",
|
||||
NewestMinorProtocol: src.NewestMinorProtocol,
|
||||
UnrecognizedOptions: src.UnrecognizedOptions,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
NewestMinorProtocol uint32
|
||||
UnrecognizedOptions []string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.NewestMinorProtocol = msg.NewestMinorProtocol
|
||||
dst.UnrecognizedOptions = msg.UnrecognizedOptions
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,507 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// TSVectorScanner is implemented by types that can be scanned into from a
// PostgreSQL tsvector value.
type TSVectorScanner interface {
	ScanTSVector(TSVector) error
}

// TSVectorValuer is implemented by types that can be converted into a
// PostgreSQL tsvector value.
type TSVectorValuer interface {
	TSVectorValue() (TSVector, error)
}

// TSVector represents a PostgreSQL tsvector value.
type TSVector struct {
	Lexemes []TSVectorLexeme
	Valid   bool // false represents SQL NULL
}

// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions.
type TSVectorLexeme struct {
	Word      string
	Positions []TSVectorPosition // may be empty when the lexeme carries no positions
}
|
||||
|
||||
// ScanTSVector implements the [TSVectorScanner] interface.
func (t *TSVector) ScanTSVector(v TSVector) error {
	*t = v
	return nil
}

// TSVectorValue implements the [TSVectorValuer] interface.
func (t TSVector) TSVectorValue() (TSVector, error) {
	return t, nil
}

// String returns the PostgreSQL text representation of the tsvector. An
// invalid (NULL) TSVector renders as the empty string.
func (t TSVector) String() string {
	// The text encode plan does not fail for a TSVector receiver, so the
	// error is intentionally discarded.
	buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil)
	return string(buf)
}
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (t *TSVector) Scan(src any) error {
|
||||
if src == nil {
|
||||
*t = TSVector{}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return scanPlanTextAnyToTSVectorScanner{}.scanString(src, t)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot scan %T", src)
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (t TSVector) Value() (driver.Value, error) {
|
||||
if !t.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf, err := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
// TSVectorWeight represents the weight label of a lexeme position in a tsvector.
type TSVectorWeight byte

// The valid tsvector position weights. D is PostgreSQL's default and is
// omitted from the text representation.
const (
	TSVectorWeightA = TSVectorWeight('A')
	TSVectorWeightB = TSVectorWeight('B')
	TSVectorWeightC = TSVectorWeight('C')
	TSVectorWeightD = TSVectorWeight('D')
)
|
||||
|
||||
// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL.
|
||||
func tsvectorWeightToBinary(w TSVectorWeight) uint16 {
|
||||
switch w {
|
||||
case TSVectorWeightA:
|
||||
return 3
|
||||
case TSVectorWeightB:
|
||||
return 2
|
||||
case TSVectorWeightC:
|
||||
return 1
|
||||
default:
|
||||
return 0 // D or unset
|
||||
}
|
||||
}
|
||||
|
||||
// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight.
|
||||
func tsvectorWeightFromBinary(b uint16) TSVectorWeight {
|
||||
switch b {
|
||||
case 3:
|
||||
return TSVectorWeightA
|
||||
case 2:
|
||||
return TSVectorWeightB
|
||||
case 1:
|
||||
return TSVectorWeightC
|
||||
default:
|
||||
return TSVectorWeightD
|
||||
}
|
||||
}
|
||||
|
||||
// TSVectorPosition represents a lexeme position and its optional weight within a tsvector.
type TSVectorPosition struct {
	Position uint16         // word position; only the low 14 bits are used on the wire
	Weight   TSVectorWeight // zero value or TSVectorWeightD means the default weight
}

// String renders the position in tsvector text form: the decimal position
// followed by the weight letter, with the default weight D omitted.
func (p TSVectorPosition) String() string {
	s := strconv.FormatUint(uint64(p.Position), 10)
	if p.Weight != 0 && p.Weight != TSVectorWeightD {
		s += string(p.Weight)
	}
	return s
}
|
||||
|
||||
type TSVectorCodec struct{}
|
||||
|
||||
func (TSVectorCodec) FormatSupported(format int16) bool {
|
||||
return format == TextFormatCode || format == BinaryFormatCode
|
||||
}
|
||||
|
||||
func (TSVectorCodec) PreferredFormat() int16 {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
|
||||
func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
if _, ok := value.(TSVectorValuer); !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return encodePlanTSVectorCodecBinary{}
|
||||
case TextFormatCode:
|
||||
return encodePlanTSVectorCodecText{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodePlanTSVectorCodecBinary struct{}
|
||||
|
||||
func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) {
|
||||
tsv, err := value.(TSVectorValuer).TSVectorValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !tsv.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes)))
|
||||
|
||||
for _, entry := range tsv.Lexemes {
|
||||
buf = append(buf, entry.Word...)
|
||||
buf = append(buf, 0x00)
|
||||
buf = pgio.AppendUint16(buf, uint16(len(entry.Positions)))
|
||||
|
||||
// Each position is a uint16: weight (2 bits) | position (14 bits)
|
||||
for _, pos := range entry.Positions {
|
||||
packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF
|
||||
buf = pgio.AppendUint16(buf, packed)
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// scanPlanBinaryTSVectorToTSVectorScanner decodes the PostgreSQL binary
// tsvector wire format into a TSVectorScanner.
type scanPlanBinaryTSVectorToTSVectorScanner struct{}

// Scan decodes src into dst. A nil src scans as SQL NULL.
func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error {
	scanner := (dst).(TSVectorScanner)

	if src == nil {
		return scanner.ScanTSVector(TSVector{})
	}

	rp := 0 // read position within src

	const (
		uint16Len = 2
		uint32Len = 4
	)

	// Leading 4 bytes: signed lexeme count.
	if len(src[rp:]) < uint32Len {
		return fmt.Errorf("tsvector incomplete %v", src)
	}
	entryCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
	rp += uint32Len

	var tsv TSVector
	if entryCount > 0 {
		tsv.Lexemes = make([]TSVectorLexeme, entryCount)
	}

	for i := range entryCount {
		// The word is null-terminated.
		nullIndex := bytes.IndexByte(src[rp:], 0x00)
		if nullIndex == -1 {
			return fmt.Errorf("invalid tsvector binary format: missing null terminator")
		}

		lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])}
		rp += nullIndex + 1 // skip past null terminator

		// Read position count.
		if len(src[rp:]) < uint16Len {
			return fmt.Errorf("invalid tsvector binary format: incomplete position count")
		}

		numPositions := int(binary.BigEndian.Uint16(src[rp:]))
		rp += uint16Len

		// Read each packed position: weight (2 bits) | position (14 bits)
		if len(src[rp:]) < numPositions*uint16Len {
			return fmt.Errorf("invalid tsvector binary format: incomplete positions")
		}

		if numPositions > 0 {
			lexeme.Positions = make([]TSVectorPosition, numPositions)
			for pos := range numPositions {
				packed := binary.BigEndian.Uint16(src[rp:])
				rp += uint16Len
				lexeme.Positions[pos] = TSVectorPosition{
					Position: packed & 0x3FFF,
					Weight:   tsvectorWeightFromBinary(packed >> 14),
				}
			}
		}

		tsv.Lexemes[i] = lexeme
	}
	tsv.Valid = true

	return scanner.ScanTSVector(tsv)
}
|
||||
|
||||
// tsvectorLexemeReplacer escapes backslashes and single quotes so a lexeme
// can be embedded in a single-quoted tsvector text literal.
var tsvectorLexemeReplacer = strings.NewReplacer(
	`\`, `\\`,
	`'`, `\'`,
)
|
||||
|
||||
type encodePlanTSVectorCodecText struct{}
|
||||
|
||||
func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) {
|
||||
tsv, err := value.(TSVectorValuer).TSVectorValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !tsv.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if buf == nil {
|
||||
buf = []byte{}
|
||||
}
|
||||
|
||||
for i, lex := range tsv.Lexemes {
|
||||
if i > 0 {
|
||||
buf = append(buf, ' ')
|
||||
}
|
||||
|
||||
buf = append(buf, '\'')
|
||||
buf = append(buf, tsvectorLexemeReplacer.Replace(lex.Word)...)
|
||||
buf = append(buf, '\'')
|
||||
|
||||
sep := byte(':')
|
||||
for _, p := range lex.Positions {
|
||||
buf = append(buf, sep)
|
||||
buf = append(buf, p.String()...)
|
||||
sep = ','
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
switch target.(type) {
|
||||
case TSVectorScanner:
|
||||
return scanPlanBinaryTSVectorToTSVectorScanner{}
|
||||
}
|
||||
case TextFormatCode:
|
||||
switch target.(type) {
|
||||
case TSVectorScanner:
|
||||
return scanPlanTextAnyToTSVectorScanner{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanTextAnyToTSVectorScanner struct{}
|
||||
|
||||
func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error {
|
||||
scanner := (dst).(TSVectorScanner)
|
||||
|
||||
if src == nil {
|
||||
return scanner.ScanTSVector(TSVector{})
|
||||
}
|
||||
|
||||
return s.scanString(string(src), scanner)
|
||||
}
|
||||
|
||||
func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error {
|
||||
tsv, err := parseTSVector(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return scanner.ScanTSVector(tsv)
|
||||
}
|
||||
|
||||
// DecodeDatabaseSQLValue decodes src into a driver.Value in text format.
func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	return codecDecodeToTextFormat(c, m, oid, format, src)
}

// DecodeValue decodes src into a TSVector. A nil src decodes to nil (SQL NULL).
func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var tsv TSVector
	err := codecScan(c, m, oid, format, src, &tsv)
	if err != nil {
		return nil, err
	}
	return tsv, nil
}
|
||||
|
||||
// tsvectorParser is a byte cursor over the text representation of a tsvector.
type tsvectorParser struct {
	str string // input being parsed
	pos int    // index of the next byte to consume
}

// atEnd reports whether the entire input has been consumed.
func (p *tsvectorParser) atEnd() bool {
	return p.pos >= len(p.str)
}

// peek returns the next byte without consuming it. Callers must check atEnd
// first; peeking past the end panics.
func (p *tsvectorParser) peek() byte {
	return p.str[p.pos]
}

// consume returns the next byte and advances the cursor. The second result
// is true when the input is exhausted.
func (p *tsvectorParser) consume() (byte, bool) {
	if p.pos >= len(p.str) {
		return 0, true
	}
	b := p.str[p.pos]
	p.pos++
	return b, false
}

// consumeSpaces advances past any run of space characters.
func (p *tsvectorParser) consumeSpaces() {
	for !p.atEnd() && p.peek() == ' ' {
		p.consume()
	}
}
|
||||
|
||||
// consumeLexeme consumes a single-quoted lexeme, handling single quotes and backslash escapes.
// It returns the unescaped lexeme text with the surrounding quotes removed.
func (p *tsvectorParser) consumeLexeme() (string, error) {
	ch, end := p.consume()
	if end || ch != '\'' {
		return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote")
	}

	var buf strings.Builder
	for {
		ch, end := p.consume()
		if end {
			return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme")
		}

		switch ch {
		case '\'':
			// Escaped quote ('') — write a literal single quote
			if !p.atEnd() && p.peek() == '\'' {
				p.consume()
				buf.WriteByte('\'')
			} else {
				// Closing quote — lexeme is complete
				return buf.String(), nil
			}
		case '\\':
			// Backslash escapes the next byte verbatim.
			next, end := p.consume()
			if end {
				return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash")
			}
			buf.WriteByte(next)
		default:
			buf.WriteByte(ch)
		}
	}
}
|
||||
|
||||
// consumePositions consumes a comma-separated list of position[weight] values.
|
||||
func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) {
|
||||
var positions []TSVectorPosition
|
||||
|
||||
for {
|
||||
pos, err := p.consumePosition()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions = append(positions, pos)
|
||||
|
||||
if p.atEnd() || p.peek() != ',' {
|
||||
break
|
||||
}
|
||||
|
||||
p.consume() // skip ','
|
||||
}
|
||||
|
||||
return positions, nil
|
||||
}
|
||||
|
||||
// consumePosition consumes a single position number with optional weight letter.
// The weight defaults to D when no letter follows the number.
func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) {
	start := p.pos

	// Consume the run of decimal digits forming the position number.
	for !p.atEnd() && p.peek() >= '0' && p.peek() <= '9' {
		p.consume()
	}

	if p.pos == start {
		return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number")
	}

	// Positions are 16-bit on the wire, hence the bitSize of 16.
	num, err := strconv.ParseUint(p.str[start:p.pos], 10, 16)
	if err != nil {
		return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", p.str[start:p.pos])
	}

	pos := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD}

	// Check for optional weight letter
	if !p.atEnd() {
		switch p.peek() {
		case 'A', 'a':
			pos.Weight = TSVectorWeightA
		case 'B', 'b':
			pos.Weight = TSVectorWeightB
		case 'C', 'c':
			pos.Weight = TSVectorWeightC
		case 'D', 'd':
			pos.Weight = TSVectorWeightD
		default:
			// Not a weight letter — leave it for the caller (e.g. ',' or ' ').
			return pos, nil
		}
		p.consume()
	}

	return pos, nil
}
|
||||
|
||||
// parseTSVector parses a PostgreSQL tsvector text representation.
// The expected form is a space-separated list of quoted lexemes, each
// optionally followed by ':' and a comma-separated position list.
func parseTSVector(s string) (TSVector, error) {
	result := TSVector{}
	p := &tsvectorParser{str: strings.TrimSpace(s), pos: 0}

	for !p.atEnd() {
		p.consumeSpaces()
		if p.atEnd() {
			break
		}

		word, err := p.consumeLexeme()
		if err != nil {
			return TSVector{}, err
		}

		entry := TSVectorLexeme{Word: word}

		// Check for optional positions after ':'
		if !p.atEnd() && p.peek() == ':' {
			p.consume() // skip ':'

			positions, err := p.consumePositions()
			if err != nil {
				return TSVector{}, err
			}
			entry.Positions = positions
		}

		result.Lexemes = append(result.Lexemes, entry)
	}

	result.Valid = true

	return result, nil
}
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env bash
set -euo pipefail

# test.sh - Run pgx tests against specific database targets
#
# Usage:
#   ./test.sh [target] [go test flags...]
#
# Targets:
#   pg14  - PostgreSQL 14 (port 5414)
#   pg15  - PostgreSQL 15 (port 5415)
#   pg16  - PostgreSQL 16 (port 5416)
#   pg17  - PostgreSQL 17 (port 5417)
#   pg18  - PostgreSQL 18 (port 5432) [default]
#   crdb  - CockroachDB (port 26257)
#   all   - Run against all targets sequentially
#
# Examples:
#   ./test.sh                        # Test against PG18
#   ./test.sh pg14                   # Test against PG14
#   ./test.sh crdb                   # Test against CockroachDB
#   ./test.sh all                    # Test against all targets
#   ./test.sh pg16 -run TestConnect  # Test specific test against PG16
#   ./test.sh pg18 -count=1 -v       # Verbose, no cache, PG18

# Color output (disabled if not a terminal)
if [ -t 1 ]; then
    GREEN='\033[0;32m'
    RED='\033[0;31m'
    BLUE='\033[0;34m'
    NC='\033[0m'
else
    GREEN=''
    RED=''
    BLUE=''
    NC=''
fi

log_info() { echo -e "${BLUE}==> $*${NC}"; }
log_ok()   { echo -e "${GREEN}==> $*${NC}"; }
log_err()  { echo -e "${RED}==> $*${NC}" >&2; }

# Wait for a database to accept connections
wait_for_ready() {
    local connstr="$1"
    local label="$2"
    local max_attempts=30
    local attempt=0

    log_info "Waiting for $label to be ready..."
    while ! psql "$connstr" -c "SELECT 1" > /dev/null 2>&1; do
        attempt=$((attempt + 1))
        if [ "$attempt" -ge "$max_attempts" ]; then
            log_err "$label did not become ready after $max_attempts attempts"
            return 1
        fi
        sleep 1
    done
    log_ok "$label is ready"
}

# Directory containing this script (used to locate testsetup/)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CERTS_DIR="$SCRIPT_DIR/testsetup/certs"

# Copy client certificates to /tmp for TLS tests
setup_client_certs() {
    if [ -d "$CERTS_DIR" ]; then
        base64 -d "$CERTS_DIR/ca.pem.b64" > /tmp/ca.pem
        base64 -d "$CERTS_DIR/pgx_sslcert.crt.b64" > /tmp/pgx_sslcert.crt
        base64 -d "$CERTS_DIR/pgx_sslcert.key.b64" > /tmp/pgx_sslcert.key
    fi
}

# Initialize CockroachDB (create database if not exists)
init_crdb() {
    local connstr="postgresql://root@localhost:26257/?sslmode=disable"
    wait_for_ready "$connstr" "CockroachDB"
    log_info "Ensuring pgx_test database exists on CockroachDB..."
    psql "$connstr" -c "CREATE DATABASE IF NOT EXISTS pgx_test" 2>/dev/null || true
}

# Run tests against a single target
run_tests() {
    local target="$1"
    shift
    local extra_args=("$@")

    local label=""
    local port=""

    # NOTE: extra_args may be empty. Under `set -u`, expanding an empty array
    # with "${extra_args[@]}" is a fatal "unbound variable" error on bash
    # before 4.4 (e.g. macOS /bin/bash 3.2), so the go test invocations below
    # use the ${arr[@]+"${arr[@]}"} idiom, which expands to nothing when the
    # array is empty on every bash version.
    case "$target" in
        pg14) label="PostgreSQL 14"; port=5414 ;;
        pg15) label="PostgreSQL 15"; port=5415 ;;
        pg16) label="PostgreSQL 16"; port=5416 ;;
        pg17) label="PostgreSQL 17"; port=5417 ;;
        pg18) label="PostgreSQL 18"; port=5432 ;;
        crdb)
            label="CockroachDB (port 26257)"
            init_crdb
            log_info "Testing against $label"
            if ! PGX_TEST_DATABASE="postgresql://root@localhost:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" \
                go test -count=1 ${extra_args[@]+"${extra_args[@]}"} ./...; then
                log_err "Tests FAILED against $label"
                return 1
            fi
            log_ok "Tests passed against $label"
            return 0
            ;;
        *)
            log_err "Unknown target: $target"
            log_err "Valid targets: pg14, pg15, pg16, pg17, pg18, crdb, all"
            return 1
            ;;
    esac

    setup_client_certs

    log_info "Testing against $label (port $port)"
    if ! PGX_TEST_DATABASE="host=localhost port=$port user=postgres password=postgres dbname=pgx_test" \
        PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \
        PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \
        PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \
        PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \
        PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \
        PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \
        PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \
        PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \
        PGX_SSL_PASSWORD=certpw \
        go test -count=1 ${extra_args[@]+"${extra_args[@]}"} ./...; then
        log_err "Tests FAILED against $label"
        return 1
    fi
    log_ok "Tests passed against $label"
}

# Main
main() {
    local target="${1:-pg18}"

    if [ "$target" = "all" ]; then
        shift || true
        local targets=(pg14 pg15 pg16 pg17 pg18 crdb)
        local failed=()

        for t in "${targets[@]}"; do
            echo ""
            log_info "=========================================="
            log_info "Target: $t"
            log_info "=========================================="
            # Keep going on failure so every target gets a run; report at the end.
            if ! run_tests "$t" "$@"; then
                failed+=("$t")
                log_err "FAILED: $t"
            fi
        done

        echo ""
        # Same empty-array-under-set-u caveat as extra_args above.
        if [ ${#failed[@]} -gt 0 ]; then
            log_err "Failed targets: ${failed[*]}"
            return 1
        else
            log_ok "All targets passed"
        fi
    else
        shift || true
        run_tests "$target" "$@"
    fi
}

main "$@"
|
||||
|
|
@ -2,14 +2,14 @@ package walletapp
|
|||
|
||||
import (
|
||||
"git.gocasts.ir/ebhomengo/niki/adapter/redis"
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/shoppingbasket/repository"
|
||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/repository/postgres"
|
||||
"git.gocasts.ir/ebhomengo/niki/pkg/httpserver"
|
||||
logger "git.gocasts.ir/ebhomengo/niki/pkg/logger"
|
||||
)
|
||||
|
||||
// Config aggregates the configuration of every dependency of the wallet
// application: the Redis adapter, the wallet Postgres repository, the HTTP
// server, and the logger. Field tags drive both koanf config loading and
// JSON serialization.
type Config struct {
	Redis      redis.Config      `koanf:"redis" json:"redis"`
	Repo       postgres.Config   `koanf:"repo" json:"repo"`
	HTTPServer httpserver.Config `koanf:"http_server" json:"http_server"`
	Logger     logger.Config     `koanf:"logger" json:"logger"`
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue