forked from ebhomengo/niki
wallet domain finished
This commit is contained in:
parent
27356e028f
commit
14a941493e
|
|
@ -1,12 +1,17 @@
|
||||||
package entity
|
package entity
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
)
|
||||||
|
|
||||||
type Transaction struct {
|
type Transaction struct {
|
||||||
ID uint64
|
ID uint64
|
||||||
UserID uint64
|
UserID uint64
|
||||||
Amount float64
|
Amount decimal.Decimal
|
||||||
Currency Currency
|
Currency types.Currency
|
||||||
ActionType TransactionType
|
ActionType TransactionType
|
||||||
Timestamp time.Time
|
Timestamp time.Time
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
|
|
@ -22,10 +27,3 @@ const (
|
||||||
TransactionTypeRefund TransactionType = "refund"
|
TransactionTypeRefund TransactionType = "refund"
|
||||||
TransactionTypeDonate TransactionType = "donate"
|
TransactionTypeDonate TransactionType = "donate"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Currency string
|
|
||||||
|
|
||||||
const (
|
|
||||||
IRR Currency = "IRR"
|
|
||||||
USD Currency = "USD"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,20 @@
|
||||||
package entity
|
package entity
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
)
|
||||||
|
|
||||||
type Wallet struct {
|
type Wallet struct {
|
||||||
ID uint64
|
ID uint64
|
||||||
UserID uint64 // user unique ID
|
UserID uint64 // user unique ID
|
||||||
Balance float64
|
Balance decimal.Decimal
|
||||||
Currency Currency
|
Currency types.Currency
|
||||||
Status WalletStatus // "active", "frozen", "closed"
|
Status WalletStatus // "active", "frozen", "closed"
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w Wallet) UnimplementedAllowUseDBGenericFunc() {}
|
func (w Wallet) UnimplementedAllowUseDBGenericFunc() {}
|
||||||
|
|
@ -19,7 +25,4 @@ const (
|
||||||
Frozen WalletStatus = "frozen" // when need to check , approve ,validate , solve sth (but deposit is possible)
|
Frozen WalletStatus = "frozen" // when need to check , approve ,validate , solve sth (but deposit is possible)
|
||||||
Active WalletStatus = "active" // when everything is ok
|
Active WalletStatus = "active" // when everything is ok
|
||||||
|
|
||||||
// ??
|
|
||||||
// Closed WalletStatus = "closed" // when need to check , approve ,validate , solve sth (exp : security problem)
|
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,19 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CreateTransactionRequest struct {
|
type CreateTransactionRequest struct {
|
||||||
UserID uint64 `json:"user_id"`
|
UserID uint64 `json:"user_id"`
|
||||||
Amount float64 `json:"amount"`
|
Amount string `json:"amount"`
|
||||||
Currency entity.Currency `json:"currency"`
|
Currency types.Currency `json:"currency"`
|
||||||
ActionType entity.TransactionType `json:"action_type"`
|
ActionType entity.TransactionType `json:"action_type"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InsertTransactionResponse struct {
|
type InsertTransactionResponse struct {
|
||||||
Balance float64 `json:"balance"`
|
Balance decimal.Decimal `json:"balance"`
|
||||||
|
Currency types.Currency `json:"currency"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import (
|
||||||
|
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TransactionRequest struct {
|
type TransactionRequest struct {
|
||||||
|
|
@ -20,8 +22,8 @@ type TransactionResponse struct {
|
||||||
type TransactionInfo struct {
|
type TransactionInfo struct {
|
||||||
ID uint64 `json:"id"`
|
ID uint64 `json:"id"`
|
||||||
UserID uint64 `json:"user_id"`
|
UserID uint64 `json:"user_id"`
|
||||||
Amount float64 `json:"amount"`
|
Amount decimal.Decimal `json:"amount,string"`
|
||||||
Currency entity.Currency `json:"currency"`
|
Currency types.Currency `json:"currency"`
|
||||||
ActionType entity.TransactionType `json:"action_type"`
|
ActionType entity.TransactionType `json:"action_type"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WalletRequest struct {
|
type WalletRequest struct {
|
||||||
|
|
@ -13,9 +15,9 @@ type WalletRequest struct {
|
||||||
type WalletResponse struct {
|
type WalletResponse struct {
|
||||||
Wallet WalletInfo `json:"wallet"`
|
Wallet WalletInfo `json:"wallet"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WalletInfo struct {
|
type WalletInfo struct {
|
||||||
Balance float64 `json:"balance"`
|
Balance decimal.Decimal `json:"balance,string"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Status entity.WalletStatus `json:"status"`
|
Status entity.WalletStatus `json:"status"`
|
||||||
|
Currency types.Currency `json:"currency"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
conn *postgres.DB
|
conn *postgres.DB
|
||||||
|
Config Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(conn *postgres.DB) *DB {
|
func New(conn *postgres.DB) *DB {
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ func scanWallet(scanner postgres.Scanner) (entity.Wallet, error) {
|
||||||
|
|
||||||
var wallet entity.Wallet
|
var wallet entity.Wallet
|
||||||
|
|
||||||
err := scanner.Scan(&wallet.ID, &wallet.UserID, &wallet.Balance, &wallet.Currency, &wallet.Status, &wallet.UpdatedAt)
|
err := scanner.Scan(&wallet.ID, &wallet.UserID, &wallet.Balance, &wallet.Currency, &wallet.Status, &wallet.UpdatedAt, &wallet.CreatedAt, &wallet.CreatedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return entity.Wallet{}, err
|
return entity.Wallet{}, err
|
||||||
}
|
}
|
||||||
|
|
@ -2,16 +2,21 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||||
errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
|
errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
|
||||||
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *DB) InsertTransaction(ctx context.Context, transaction entity.Transaction) (balance float64, err error) {
|
func (db *DB) InsertTransaction(ctx context.Context, transaction entity.Transaction, currencyRate decimal.Decimal) (balance decimal.Decimal, currency types.Currency, err error) {
|
||||||
const op = richerror.Op("wallet.repo.InsertTransaction")
|
const op = richerror.Op("wallet.repo.InsertTransaction")
|
||||||
|
|
||||||
|
// TODO : USE TX INSTANT QUERY
|
||||||
|
|
||||||
query := `INSERT INTO Transactions (user_id, amount ,currency ,action_type, timestamp) values ($1, $2, $3, $4, $5)`
|
query := `INSERT INTO Transactions (user_id, amount ,currency ,action_type, timestamp) values ($1, $2, $3, $4, $5)`
|
||||||
|
|
||||||
stmt, stErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyWalletInsertTransaction, query)
|
stmt, stErr := db.conn.PrepareStatement(ctx, postgres.StatementKeyWalletInsertTransaction, query)
|
||||||
|
|
@ -50,18 +55,19 @@ func (db *DB) InsertTransaction(ctx context.Context, transaction entity.Transact
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newBalance, blanceErr := db.UpsertBalance(newCtx, transaction.UserID, transaction.Amount)
|
newBalance, walletCurrency, balanceErr := db.UpsertBalance(newCtx, transaction, currencyRate)
|
||||||
if blanceErr != nil {
|
if balanceErr != nil {
|
||||||
err = richerror.New(op).WithErr(blanceErr)
|
err = richerror.New(op).WithErr(balanceErr)
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
balance = newBalance
|
balance = newBalance
|
||||||
|
currency = walletCurrency
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) UpsertBalance(newCtx context.Context, userID uint64, amount float64) (balance float64, err error) {
|
func (db *DB) UpsertBalance(newCtx context.Context, transaction entity.Transaction, currencyRate decimal.Decimal) (balance decimal.Decimal, currency types.Currency, err error) {
|
||||||
const op = richerror.Op("wallet.repo.UpdateBalance")
|
const op = richerror.Op("wallet.repo.UpdateBalance")
|
||||||
|
|
||||||
txHolder, _ := postgres.GetDBTxHolderFromContextOrNew(newCtx)
|
txHolder, _ := postgres.GetDBTxHolderFromContextOrNew(newCtx)
|
||||||
|
|
@ -77,23 +83,23 @@ func (db *DB) UpsertBalance(newCtx context.Context, userID uint64, amount float6
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
upsertQuery := `INSERT INTO wallets (user_id, balance) VALUES ($1, $2)
|
upsertQuery := `INSERT INTO wallets (user_id, balance , updated_at) VALUES ($1,$2,$3 )
|
||||||
ON CONFLICT (user_id) DO UPDATE SET balance = wallets.balance + $2
|
ON CONFLICT (user_id) DO UPDATE SET balance = wallets.balance + $2
|
||||||
RETURNING balance`
|
RETURNING balance , currency`
|
||||||
|
|
||||||
upsertStmt, stErr := db.conn.PrepareStatement(newCtx, postgres.StatementKeyWalletUpsertBalance, upsertQuery)
|
upsertStmt, stErr := db.conn.PrepareStatement(newCtx, postgres.StatementKeyWalletUpsertBalance, upsertQuery)
|
||||||
if stErr != nil {
|
if stErr != nil {
|
||||||
err = richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery).WithKind(richerror.KindUnexpected)
|
err = richerror.New(op).WithMessage(errmsg.ErrorMsgFailedQuery).WithKind(richerror.KindUnexpected)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
amount := transaction.Amount.Mul(currencyRate)
|
||||||
|
row := tx.StmtQueryRowContext(newCtx, upsertStmt, transaction.UserID, amount, time.Now())
|
||||||
|
|
||||||
row := tx.StmtQueryRowContext(newCtx, upsertStmt, userID, amount)
|
sErr := row.Scan(&balance, ¤cy)
|
||||||
|
|
||||||
sErr := row.Scan(&balance)
|
|
||||||
if sErr != nil {
|
if sErr != nil {
|
||||||
err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantScanQueryResult).WithKind(richerror.KindUnexpected)
|
err = richerror.New(op).WithMessage(errmsg.ErrorMsgCantScanQueryResult).WithKind(richerror.KindUnexpected)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return balance, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
-- +migrate Up
|
||||||
|
|
||||||
|
CREATE TABLE "transaction" (
|
||||||
|
"id" BIGSERIAL PRIMARY KEY,
|
||||||
|
"user_id" BIGINT NOT NULL,
|
||||||
|
"amount" NUMERIC(20, 2) NOT NULL,
|
||||||
|
"currency" VARCHAR(100) NOT NULL,
|
||||||
|
"action_type" VARCHAR(100) NOT NULL,
|
||||||
|
"timestamp" TIMESTAMP NOT NULL,
|
||||||
|
"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
-- +migrate Down
|
||||||
|
DROP TABLE IF EXISTS "transaction";
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
-- +migrate Up
|
||||||
|
|
||||||
|
CREATE TABLE "wallets" (
|
||||||
|
"id" BIGSERIAL PRIMARY KEY,
|
||||||
|
"user_id" BIGINT NOT NULL,
|
||||||
|
"balance" NUMERIC(20, 2) NOT NULL,
|
||||||
|
"currency" VARCHAR(100) NOT NULL DEFAULT 'IRR',
|
||||||
|
"status" VARCHAR(100) NOT NULL DEFAULT 'active',
|
||||||
|
"updated_at" TIMESTAMP,
|
||||||
|
"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
-- +migrate Down
|
||||||
|
DROP TABLE IF EXISTS "wallets";
|
||||||
|
|
@ -7,26 +7,51 @@ import (
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/param"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/param"
|
||||||
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
|
||||||
|
//"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s Service) CreateTransaction(ctx context.Context, request param.CreateTransactionRequest) (param.InsertTransactionResponse, error) {
|
func (s Service) CreateTransaction(ctx context.Context, request param.CreateTransactionRequest) (param.InsertTransactionResponse, error) {
|
||||||
|
|
||||||
const op = richerror.Op("wallet.service.CreateTransaction")
|
const op = richerror.Op("wallet.service.CreateTransaction")
|
||||||
|
|
||||||
|
currencyRate := s.convertCurrency(request.Currency)
|
||||||
|
|
||||||
|
convertedAmount, _ := decimal.NewFromString(request.Amount)
|
||||||
|
|
||||||
transaction := entity.Transaction{
|
transaction := entity.Transaction{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
UserID: request.UserID,
|
UserID: request.UserID,
|
||||||
Amount: request.Amount,
|
Amount: convertedAmount,
|
||||||
Currency: request.Currency,
|
Currency: request.Currency,
|
||||||
ActionType: request.ActionType,
|
ActionType: request.ActionType,
|
||||||
Timestamp: request.Timestamp,
|
Timestamp: request.Timestamp,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
balance, err := s.repo.InsertTransaction(ctx, transaction)
|
|
||||||
if err != nil {
|
balance, walletCurrency, inErr := s.repo.InsertTransaction(ctx, transaction, currencyRate)
|
||||||
return param.InsertTransactionResponse{}, richerror.New(op).WithErr(err)
|
if inErr != nil {
|
||||||
|
return param.InsertTransactionResponse{}, richerror.New(op).WithErr(inErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return param.InsertTransactionResponse{Balance: balance}, nil
|
return param.InsertTransactionResponse{Balance: balance, Currency: walletCurrency}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Service) convertCurrency(currency types.Currency) decimal.Decimal {
|
||||||
|
|
||||||
|
if currency != types.IRR {
|
||||||
|
currencyRate, cErr := s.currencyRateProvider.GetCurrencyPriceRateInIRR(currency)
|
||||||
|
if cErr != nil {
|
||||||
|
// log // fallback or change provider
|
||||||
|
return decimal.Zero // if 0 => transaction commited with currency but wallet doesn't update or add 0 to wallet
|
||||||
|
}
|
||||||
|
|
||||||
|
return currencyRate
|
||||||
|
|
||||||
|
}
|
||||||
|
return decimal.NewFromInt(1)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ func (s Service) GetUserWallet(ctx context.Context, request param.WalletRequest)
|
||||||
return param.WalletResponse{
|
return param.WalletResponse{
|
||||||
Wallet: param.WalletInfo{
|
Wallet: param.WalletInfo{
|
||||||
Balance: wallet.Balance,
|
Balance: wallet.Balance,
|
||||||
|
Currency: wallet.Currency,
|
||||||
UpdatedAt: wallet.UpdatedAt,
|
UpdatedAt: wallet.UpdatedAt,
|
||||||
Status: wallet.Status,
|
Status: wallet.Status,
|
||||||
},
|
},
|
||||||
|
|
@ -5,12 +5,18 @@ import (
|
||||||
|
|
||||||
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
"git.gocasts.ir/ebhomengo/niki/domain/wallet/entity"
|
||||||
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
"git.gocasts.ir/ebhomengo/niki/pkg/database/postgres"
|
||||||
|
"git.gocasts.ir/ebhomengo/niki/pkg/types"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type CurrencyRateProvider interface {
|
||||||
|
GetCurrencyPriceRateInIRR(currency types.Currency) (decimal.Decimal, error)
|
||||||
|
}
|
||||||
|
|
||||||
type Repository interface {
|
type Repository interface {
|
||||||
GetTransactionListByUserID(ctx context.Context, UserID uint64, DBPagination postgres.DBPagination) ([]entity.Transaction, int64, error)
|
GetTransactionListByUserID(ctx context.Context, UserID uint64, DBPagination postgres.DBPagination) ([]entity.Transaction, int64, error)
|
||||||
GetWalletByUserID(ctx context.Context, UserID uint64) (entity.Wallet, error)
|
GetWalletByUserID(ctx context.Context, UserID uint64) (entity.Wallet, error)
|
||||||
InsertTransaction(ctx context.Context, transaction entity.Transaction) (float64, error)
|
InsertTransaction(ctx context.Context, transaction entity.Transaction, currencyRate decimal.Decimal) (decimal.Decimal, types.Currency, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
|
@ -18,8 +24,9 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
repo Repository
|
repo Repository
|
||||||
cfg Config
|
cfg Config
|
||||||
|
currencyRateProvider CurrencyRateProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(repo Repository, cfg Config) Service {
|
func New(repo Repository, cfg Config) Service {
|
||||||
|
|
|
||||||
1
go.mod
1
go.mod
|
|
@ -21,6 +21,7 @@ require (
|
||||||
github.com/ory/dockertest/v3 v3.12.0
|
github.com/ory/dockertest/v3 v3.12.0
|
||||||
github.com/redis/go-redis/v9 v9.18.0
|
github.com/redis/go-redis/v9 v9.18.0
|
||||||
github.com/rubenv/sql-migrate v1.8.1
|
github.com/rubenv/sql-migrate v1.8.1
|
||||||
|
github.com/shopspring/decimal v1.4.0
|
||||||
github.com/spf13/cobra v1.10.2
|
github.com/spf13/cobra v1.10.2
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/swaggo/echo-swagger v1.5.2
|
github.com/swaggo/echo-swagger v1.5.2
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -366,6 +366,8 @@ github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb
|
||||||
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
|
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
|
||||||
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
|
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
|
||||||
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
|
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
|
||||||
|
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||||
|
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
type Currency string
|
||||||
|
|
||||||
|
const (
|
||||||
|
IRR Currency = "IRR"
|
||||||
|
USD Currency = "USD"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StringCastToCurrency(s string) (Currency, error) {
|
||||||
|
switch s {
|
||||||
|
case "IRR":
|
||||||
|
return IRR, nil
|
||||||
|
case "USD":
|
||||||
|
return USD, nil
|
||||||
|
default:
|
||||||
|
return IRR, errors.New("not a valid currency")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
# CLAUDE.md
|
|
||||||
|
|
||||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
||||||
|
|
||||||
## Project Overview
|
|
||||||
|
|
||||||
pgx is a PostgreSQL driver and toolkit for Go (`github.com/jackc/pgx/v5`). It provides both a native PostgreSQL interface and a `database/sql` compatible driver. Requires Go 1.25+ and supports PostgreSQL 14+ and CockroachDB.
|
|
||||||
|
|
||||||
## Build & Test Commands
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all tests (requires PGX_TEST_DATABASE to be set)
|
|
||||||
go test ./...
|
|
||||||
|
|
||||||
# Run a specific test
|
|
||||||
go test -run TestFunctionName ./...
|
|
||||||
|
|
||||||
# Run tests for a specific package
|
|
||||||
go test ./pgconn/...
|
|
||||||
|
|
||||||
# Run tests with race detector
|
|
||||||
go test -race ./...
|
|
||||||
|
|
||||||
# DevContainer: run tests against specific PostgreSQL versions
|
|
||||||
./test.sh pg18 # Default: PostgreSQL 18
|
|
||||||
./test.sh pg16 -run TestConnect # Specific test against PG16
|
|
||||||
./test.sh crdb # CockroachDB
|
|
||||||
./test.sh all # All targets (pg14-18 + crdb)
|
|
||||||
|
|
||||||
# Format (always run after making changes)
|
|
||||||
goimports -w .
|
|
||||||
|
|
||||||
# Lint
|
|
||||||
golangci-lint run ./...
|
|
||||||
```
|
|
||||||
|
|
||||||
## Test Database Setup
|
|
||||||
|
|
||||||
Tests require `PGX_TEST_DATABASE` environment variable. In the devcontainer, `test.sh` handles this. For local development:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export PGX_TEST_DATABASE="host=localhost user=postgres password=postgres dbname=pgx_test"
|
|
||||||
```
|
|
||||||
|
|
||||||
The test database needs extensions: `hstore`, `ltree`, and a `uint64` domain. See `testsetup/postgresql_setup.sql` for full setup. Many tests are skipped unless additional `PGX_TEST_*` env vars are set (for TLS, SCRAM, MD5, unix socket, PgBouncer testing).
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
The codebase is a layered architecture, bottom-up:
|
|
||||||
|
|
||||||
- **pgproto3/** — PostgreSQL wire protocol v3 encoder/decoder. Defines `FrontendMessage` and `BackendMessage` types for every protocol message.
|
|
||||||
- **pgconn/** — Low-level connection layer (roughly libpq-equivalent). Handles authentication, TLS, query execution, COPY protocol, and notifications. `PgConn` is the core type.
|
|
||||||
- **pgx** (root package) — High-level query interface built on `pgconn`. Provides `Conn`, `Rows`, `Tx`, `Batch`, `CopyFrom`, and generic helpers like `CollectRows`/`ForEachRow`. Includes automatic statement caching (LRU).
|
|
||||||
- **pgtype/** — Type system mapping between Go and PostgreSQL types (70+ types). Key interfaces: `Codec`, `Type`, `TypeMap`. Custom types (enums, composites, domains) are registered through `TypeMap`.
|
|
||||||
- **pgxpool/** — Concurrency-safe connection pool built on `puddle/v2`. `Pool` is the main type; wraps `pgx.Conn`.
|
|
||||||
- **stdlib/** — `database/sql` compatibility adapter.
|
|
||||||
|
|
||||||
Supporting packages:
|
|
||||||
- **internal/stmtcache/** — Prepared statement cache with LRU eviction
|
|
||||||
- **internal/sanitize/** — SQL query sanitization
|
|
||||||
- **tracelog/** — Logging adapter that implements tracer interfaces
|
|
||||||
- **multitracer/** — Composes multiple tracers into one
|
|
||||||
- **pgxtest/** — Test helpers for running tests across connection types
|
|
||||||
|
|
||||||
## Key Design Conventions
|
|
||||||
|
|
||||||
- **Semantic versioning** — strictly followed. Do not break the public API (no removing or renaming exported types, functions, methods, or fields; no changing function signatures).
|
|
||||||
- **Minimal dependencies** — adding new dependencies is strongly discouraged (see CONTRIBUTING.md).
|
|
||||||
- **Context-based** — all blocking operations take `context.Context`.
|
|
||||||
- **Tracer interfaces** — observability via `QueryTracer`, `BatchTracer`, `CopyFromTracer`, `PrepareTracer` on `ConnConfig.Tracer`.
|
|
||||||
- **Formatting** — always run `goimports -w .` after making changes to ensure code is properly formatted. CI checks formatting via `gofmt -l -s -w . && git diff --exit-code`. `gofumpt` with extra rules is also enforced via `golangci-lint`.
|
|
||||||
- **Linters** — `govet` and `ineffassign` only (configured in `.golangci.yml`).
|
|
||||||
- **CI matrix** — tests run against Go 1.25/1.26 × PostgreSQL 14-18 + CockroachDB, on Linux and Windows. Race detector enabled on Linux only.
|
|
||||||
|
|
@ -1,60 +0,0 @@
|
||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
|
||||||
if [ "$current_branch" == "HEAD" ]; then
|
|
||||||
current_branch=$(git rev-parse HEAD)
|
|
||||||
fi
|
|
||||||
|
|
||||||
restore_branch() {
|
|
||||||
echo "Restoring original branch/commit: $current_branch"
|
|
||||||
git checkout "$current_branch"
|
|
||||||
}
|
|
||||||
trap restore_branch EXIT
|
|
||||||
|
|
||||||
# Check if there are uncommitted changes
|
|
||||||
if ! git diff --quiet || ! git diff --cached --quiet; then
|
|
||||||
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Ensure that at least one commit argument is passed
|
|
||||||
if [ "$#" -lt 1 ]; then
|
|
||||||
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
commits=("$@")
|
|
||||||
benchmarks_dir=benchmarks
|
|
||||||
|
|
||||||
if ! mkdir -p "${benchmarks_dir}"; then
|
|
||||||
echo "Unable to create dir for benchmarks data"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Benchmark results
|
|
||||||
bench_files=()
|
|
||||||
|
|
||||||
# Run benchmark for each listed commit
|
|
||||||
for i in "${!commits[@]}"; do
|
|
||||||
commit="${commits[i]}"
|
|
||||||
git checkout "$commit" || {
|
|
||||||
echo "Failed to checkout $commit"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
# Sanitized commit message
|
|
||||||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
|
||||||
|
|
||||||
# Benchmark data will go there
|
|
||||||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
|
||||||
|
|
||||||
if ! go test -bench=. -count=10 >"$bench_file"; then
|
|
||||||
echo "Benchmarking failed for commit $commit"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
bench_files+=("$bench_file")
|
|
||||||
done
|
|
||||||
|
|
||||||
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
|
||||||
benchstat "${bench_files[@]}"
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
package pgconn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (c *PgConn) oauthAuth(ctx context.Context) error {
|
|
||||||
if c.config.OAuthTokenProvider == nil {
|
|
||||||
return errors.New("OAuth authentication required but no token provider configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := c.config.OAuthTokenProvider(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to obtain OAuth token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
|
|
||||||
initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01")
|
|
||||||
|
|
||||||
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
|
||||||
AuthMechanism: "OAUTHBEARER",
|
|
||||||
Data: initialResponse,
|
|
||||||
}
|
|
||||||
c.frontend.Send(saslInitialResponse)
|
|
||||||
err = c.flushWithPotentialWriteReadDeadlock()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := c.receiveMessage()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch m := msg.(type) {
|
|
||||||
case *pgproto3.AuthenticationOk:
|
|
||||||
return nil
|
|
||||||
case *pgproto3.AuthenticationSASLContinue:
|
|
||||||
// Server sent error response in SASL continue
|
|
||||||
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
|
|
||||||
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
|
|
||||||
errResponse := struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
Scope string `json:"scope"`
|
|
||||||
OpenIDConfiguration string `json:"openid-configuration"`
|
|
||||||
}{}
|
|
||||||
err := json.Unmarshal(m.Data, &errResponse)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid OAuth error response from server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01.
|
|
||||||
// However, since the connection will be closed anyway, we can skip this
|
|
||||||
return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status)
|
|
||||||
|
|
||||||
case *pgproto3.ErrorResponse:
|
|
||||||
return ErrorResponseToPgError(m)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unexpected message type during OAuth auth: %T", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type NegotiateProtocolVersion struct {
|
|
||||||
NewestMinorProtocol uint32
|
|
||||||
UnrecognizedOptions []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
|
||||||
func (*NegotiateProtocolVersion) Backend() {}
|
|
||||||
|
|
||||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
|
||||||
// type identifier and 4 byte message length.
|
|
||||||
func (dst *NegotiateProtocolVersion) Decode(src []byte) error {
|
|
||||||
if len(src) < 8 {
|
|
||||||
return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4])
|
|
||||||
optionCount := int(binary.BigEndian.Uint32(src[4:8]))
|
|
||||||
|
|
||||||
rp := 8
|
|
||||||
|
|
||||||
// Use the remaining message size as an upper bound for capacity to prevent
|
|
||||||
// malicious optionCount values from causing excessive memory allocation.
|
|
||||||
capHint := optionCount
|
|
||||||
if remaining := len(src) - rp; capHint > remaining {
|
|
||||||
capHint = remaining
|
|
||||||
}
|
|
||||||
dst.UnrecognizedOptions = make([]string, 0, capHint)
|
|
||||||
for i := 0; i < optionCount; i++ {
|
|
||||||
if rp >= len(src) {
|
|
||||||
return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
|
|
||||||
}
|
|
||||||
end := rp
|
|
||||||
for end < len(src) && src[end] != 0 {
|
|
||||||
end++
|
|
||||||
}
|
|
||||||
if end >= len(src) {
|
|
||||||
return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
|
|
||||||
}
|
|
||||||
dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end]))
|
|
||||||
rp = end + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
|
||||||
func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) {
|
|
||||||
dst, sp := beginMessage(dst, 'v')
|
|
||||||
dst = pgio.AppendUint32(dst, src.NewestMinorProtocol)
|
|
||||||
dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions)))
|
|
||||||
for _, option := range src.UnrecognizedOptions {
|
|
||||||
dst = append(dst, option...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
}
|
|
||||||
return finishMessage(dst, sp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalJSON implements encoding/json.Marshaler.
|
|
||||||
func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
NewestMinorProtocol uint32
|
|
||||||
UnrecognizedOptions []string
|
|
||||||
}{
|
|
||||||
Type: "NegotiateProtocolVersion",
|
|
||||||
NewestMinorProtocol: src.NewestMinorProtocol,
|
|
||||||
UnrecognizedOptions: src.UnrecognizedOptions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
|
||||||
func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error {
|
|
||||||
var msg struct {
|
|
||||||
NewestMinorProtocol uint32
|
|
||||||
UnrecognizedOptions []string
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.NewestMinorProtocol = msg.NewestMinorProtocol
|
|
||||||
dst.UnrecognizedOptions = msg.UnrecognizedOptions
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,507 +0,0 @@
|
||||||
package pgtype
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TSVectorScanner interface {
|
|
||||||
ScanTSVector(TSVector) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TSVectorValuer interface {
|
|
||||||
TSVectorValue() (TSVector, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TSVector represents a PostgreSQL tsvector value.
|
|
||||||
type TSVector struct {
|
|
||||||
Lexemes []TSVectorLexeme
|
|
||||||
Valid bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions.
|
|
||||||
type TSVectorLexeme struct {
|
|
||||||
Word string
|
|
||||||
Positions []TSVectorPosition
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanTSVector implements the [TSVectorScanner] interface.
|
|
||||||
func (t *TSVector) ScanTSVector(v TSVector) error {
|
|
||||||
*t = v
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TSVectorValue implements the [TSVectorValuer] interface.
|
|
||||||
func (t TSVector) TSVectorValue() (TSVector, error) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t TSVector) String() string {
|
|
||||||
buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil)
|
|
||||||
return string(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the [database/sql.Scanner] interface.
|
|
||||||
func (t *TSVector) Scan(src any) error {
|
|
||||||
if src == nil {
|
|
||||||
*t = TSVector{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch src := src.(type) {
|
|
||||||
case string:
|
|
||||||
return scanPlanTextAnyToTSVectorScanner{}.scanString(src, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("cannot scan %T", src)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the [database/sql/driver.Valuer] interface.
|
|
||||||
func (t TSVector) Value() (driver.Value, error) {
|
|
||||||
if !t.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
buf, err := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(buf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TSVectorWeight represents the weight label of a lexeme position in a tsvector.
|
|
||||||
type TSVectorWeight byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
TSVectorWeightA = TSVectorWeight('A')
|
|
||||||
TSVectorWeightB = TSVectorWeight('B')
|
|
||||||
TSVectorWeightC = TSVectorWeight('C')
|
|
||||||
TSVectorWeightD = TSVectorWeight('D')
|
|
||||||
)
|
|
||||||
|
|
||||||
// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL.
|
|
||||||
func tsvectorWeightToBinary(w TSVectorWeight) uint16 {
|
|
||||||
switch w {
|
|
||||||
case TSVectorWeightA:
|
|
||||||
return 3
|
|
||||||
case TSVectorWeightB:
|
|
||||||
return 2
|
|
||||||
case TSVectorWeightC:
|
|
||||||
return 1
|
|
||||||
default:
|
|
||||||
return 0 // D or unset
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight.
|
|
||||||
func tsvectorWeightFromBinary(b uint16) TSVectorWeight {
|
|
||||||
switch b {
|
|
||||||
case 3:
|
|
||||||
return TSVectorWeightA
|
|
||||||
case 2:
|
|
||||||
return TSVectorWeightB
|
|
||||||
case 1:
|
|
||||||
return TSVectorWeightC
|
|
||||||
default:
|
|
||||||
return TSVectorWeightD
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TSVectorPosition represents a lexeme position and its optional weight within a tsvector.
|
|
||||||
type TSVectorPosition struct {
|
|
||||||
Position uint16
|
|
||||||
Weight TSVectorWeight
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p TSVectorPosition) String() string {
|
|
||||||
s := strconv.FormatUint(uint64(p.Position), 10)
|
|
||||||
if p.Weight != 0 && p.Weight != TSVectorWeightD {
|
|
||||||
s += string(p.Weight)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
type TSVectorCodec struct{}
|
|
||||||
|
|
||||||
func (TSVectorCodec) FormatSupported(format int16) bool {
|
|
||||||
return format == TextFormatCode || format == BinaryFormatCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (TSVectorCodec) PreferredFormat() int16 {
|
|
||||||
return BinaryFormatCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
|
||||||
if _, ok := value.(TSVectorValuer); !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch format {
|
|
||||||
case BinaryFormatCode:
|
|
||||||
return encodePlanTSVectorCodecBinary{}
|
|
||||||
case TextFormatCode:
|
|
||||||
return encodePlanTSVectorCodecText{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type encodePlanTSVectorCodecBinary struct{}
|
|
||||||
|
|
||||||
func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) {
|
|
||||||
tsv, err := value.(TSVectorValuer).TSVectorValue()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tsv.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes)))
|
|
||||||
|
|
||||||
for _, entry := range tsv.Lexemes {
|
|
||||||
buf = append(buf, entry.Word...)
|
|
||||||
buf = append(buf, 0x00)
|
|
||||||
buf = pgio.AppendUint16(buf, uint16(len(entry.Positions)))
|
|
||||||
|
|
||||||
// Each position is a uint16: weight (2 bits) | position (14 bits)
|
|
||||||
for _, pos := range entry.Positions {
|
|
||||||
packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF
|
|
||||||
buf = pgio.AppendUint16(buf, packed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type scanPlanBinaryTSVectorToTSVectorScanner struct{}
|
|
||||||
|
|
||||||
func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error {
|
|
||||||
scanner := (dst).(TSVectorScanner)
|
|
||||||
|
|
||||||
if src == nil {
|
|
||||||
return scanner.ScanTSVector(TSVector{})
|
|
||||||
}
|
|
||||||
|
|
||||||
rp := 0
|
|
||||||
|
|
||||||
const (
|
|
||||||
uint16Len = 2
|
|
||||||
uint32Len = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(src[rp:]) < uint32Len {
|
|
||||||
return fmt.Errorf("tsvector incomplete %v", src)
|
|
||||||
}
|
|
||||||
entryCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
||||||
rp += uint32Len
|
|
||||||
|
|
||||||
var tsv TSVector
|
|
||||||
if entryCount > 0 {
|
|
||||||
tsv.Lexemes = make([]TSVectorLexeme, entryCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range entryCount {
|
|
||||||
nullIndex := bytes.IndexByte(src[rp:], 0x00)
|
|
||||||
if nullIndex == -1 {
|
|
||||||
return fmt.Errorf("invalid tsvector binary format: missing null terminator")
|
|
||||||
}
|
|
||||||
|
|
||||||
lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])}
|
|
||||||
rp += nullIndex + 1 // skip past null terminator
|
|
||||||
|
|
||||||
// Read position count.
|
|
||||||
if len(src[rp:]) < uint16Len {
|
|
||||||
return fmt.Errorf("invalid tsvector binary format: incomplete position count")
|
|
||||||
}
|
|
||||||
|
|
||||||
numPositions := int(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += uint16Len
|
|
||||||
|
|
||||||
// Read each packed position: weight (2 bits) | position (14 bits)
|
|
||||||
if len(src[rp:]) < numPositions*uint16Len {
|
|
||||||
return fmt.Errorf("invalid tsvector binary format: incomplete positions")
|
|
||||||
}
|
|
||||||
|
|
||||||
if numPositions > 0 {
|
|
||||||
lexeme.Positions = make([]TSVectorPosition, numPositions)
|
|
||||||
for pos := range numPositions {
|
|
||||||
packed := binary.BigEndian.Uint16(src[rp:])
|
|
||||||
rp += uint16Len
|
|
||||||
lexeme.Positions[pos] = TSVectorPosition{
|
|
||||||
Position: packed & 0x3FFF,
|
|
||||||
Weight: tsvectorWeightFromBinary(packed >> 14),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tsv.Lexemes[i] = lexeme
|
|
||||||
}
|
|
||||||
tsv.Valid = true
|
|
||||||
|
|
||||||
return scanner.ScanTSVector(tsv)
|
|
||||||
}
|
|
||||||
|
|
||||||
var tsvectorLexemeReplacer = strings.NewReplacer(
|
|
||||||
`\`, `\\`,
|
|
||||||
`'`, `\'`,
|
|
||||||
)
|
|
||||||
|
|
||||||
type encodePlanTSVectorCodecText struct{}
|
|
||||||
|
|
||||||
func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) {
|
|
||||||
tsv, err := value.(TSVectorValuer).TSVectorValue()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tsv.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if buf == nil {
|
|
||||||
buf = []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, lex := range tsv.Lexemes {
|
|
||||||
if i > 0 {
|
|
||||||
buf = append(buf, ' ')
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = append(buf, '\'')
|
|
||||||
buf = append(buf, tsvectorLexemeReplacer.Replace(lex.Word)...)
|
|
||||||
buf = append(buf, '\'')
|
|
||||||
|
|
||||||
sep := byte(':')
|
|
||||||
for _, p := range lex.Positions {
|
|
||||||
buf = append(buf, sep)
|
|
||||||
buf = append(buf, p.String()...)
|
|
||||||
sep = ','
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
|
||||||
switch format {
|
|
||||||
case BinaryFormatCode:
|
|
||||||
switch target.(type) {
|
|
||||||
case TSVectorScanner:
|
|
||||||
return scanPlanBinaryTSVectorToTSVectorScanner{}
|
|
||||||
}
|
|
||||||
case TextFormatCode:
|
|
||||||
switch target.(type) {
|
|
||||||
case TSVectorScanner:
|
|
||||||
return scanPlanTextAnyToTSVectorScanner{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type scanPlanTextAnyToTSVectorScanner struct{}
|
|
||||||
|
|
||||||
func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error {
|
|
||||||
scanner := (dst).(TSVectorScanner)
|
|
||||||
|
|
||||||
if src == nil {
|
|
||||||
return scanner.ScanTSVector(TSVector{})
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.scanString(string(src), scanner)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error {
|
|
||||||
tsv, err := parseTSVector(src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return scanner.ScanTSVector(tsv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
|
||||||
return codecDecodeToTextFormat(c, m, oid, format, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
|
||||||
if src == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var tsv TSVector
|
|
||||||
err := codecScan(c, m, oid, format, src, &tsv)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tsv, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type tsvectorParser struct {
|
|
||||||
str string
|
|
||||||
pos int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *tsvectorParser) atEnd() bool {
|
|
||||||
return p.pos >= len(p.str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *tsvectorParser) peek() byte {
|
|
||||||
return p.str[p.pos]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *tsvectorParser) consume() (byte, bool) {
|
|
||||||
if p.pos >= len(p.str) {
|
|
||||||
return 0, true
|
|
||||||
}
|
|
||||||
b := p.str[p.pos]
|
|
||||||
p.pos++
|
|
||||||
return b, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *tsvectorParser) consumeSpaces() {
|
|
||||||
for !p.atEnd() && p.peek() == ' ' {
|
|
||||||
p.consume()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// consumeLexeme consumes a single-quoted lexeme, handling single quotes and backslash escapes.
|
|
||||||
func (p *tsvectorParser) consumeLexeme() (string, error) {
|
|
||||||
ch, end := p.consume()
|
|
||||||
if end || ch != '\'' {
|
|
||||||
return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote")
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf strings.Builder
|
|
||||||
for {
|
|
||||||
ch, end := p.consume()
|
|
||||||
if end {
|
|
||||||
return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch ch {
|
|
||||||
case '\'':
|
|
||||||
// Escaped quote ('') — write a literal single quote
|
|
||||||
if !p.atEnd() && p.peek() == '\'' {
|
|
||||||
p.consume()
|
|
||||||
buf.WriteByte('\'')
|
|
||||||
} else {
|
|
||||||
// Closing quote — lexeme is complete
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
case '\\':
|
|
||||||
next, end := p.consume()
|
|
||||||
if end {
|
|
||||||
return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash")
|
|
||||||
}
|
|
||||||
buf.WriteByte(next)
|
|
||||||
default:
|
|
||||||
buf.WriteByte(ch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// consumePositions consumes a comma-separated list of position[weight] values.
|
|
||||||
func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) {
|
|
||||||
var positions []TSVectorPosition
|
|
||||||
|
|
||||||
for {
|
|
||||||
pos, err := p.consumePosition()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
positions = append(positions, pos)
|
|
||||||
|
|
||||||
if p.atEnd() || p.peek() != ',' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
p.consume() // skip ','
|
|
||||||
}
|
|
||||||
|
|
||||||
return positions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// consumePosition consumes a single position number with optional weight letter.
|
|
||||||
func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) {
|
|
||||||
start := p.pos
|
|
||||||
|
|
||||||
for !p.atEnd() && p.peek() >= '0' && p.peek() <= '9' {
|
|
||||||
p.consume()
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.pos == start {
|
|
||||||
return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number")
|
|
||||||
}
|
|
||||||
|
|
||||||
num, err := strconv.ParseUint(p.str[start:p.pos], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", p.str[start:p.pos])
|
|
||||||
}
|
|
||||||
|
|
||||||
pos := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD}
|
|
||||||
|
|
||||||
// Check for optional weight letter
|
|
||||||
if !p.atEnd() {
|
|
||||||
switch p.peek() {
|
|
||||||
case 'A', 'a':
|
|
||||||
pos.Weight = TSVectorWeightA
|
|
||||||
case 'B', 'b':
|
|
||||||
pos.Weight = TSVectorWeightB
|
|
||||||
case 'C', 'c':
|
|
||||||
pos.Weight = TSVectorWeightC
|
|
||||||
case 'D', 'd':
|
|
||||||
pos.Weight = TSVectorWeightD
|
|
||||||
default:
|
|
||||||
return pos, nil
|
|
||||||
}
|
|
||||||
p.consume()
|
|
||||||
}
|
|
||||||
|
|
||||||
return pos, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTSVector parses a PostgreSQL tsvector text representation.
|
|
||||||
func parseTSVector(s string) (TSVector, error) {
|
|
||||||
result := TSVector{}
|
|
||||||
p := &tsvectorParser{str: strings.TrimSpace(s), pos: 0}
|
|
||||||
|
|
||||||
for !p.atEnd() {
|
|
||||||
p.consumeSpaces()
|
|
||||||
if p.atEnd() {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
word, err := p.consumeLexeme()
|
|
||||||
if err != nil {
|
|
||||||
return TSVector{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := TSVectorLexeme{Word: word}
|
|
||||||
|
|
||||||
// Check for optional positions after ':'
|
|
||||||
if !p.atEnd() && p.peek() == ':' {
|
|
||||||
p.consume() // skip ':'
|
|
||||||
|
|
||||||
positions, err := p.consumePositions()
|
|
||||||
if err != nil {
|
|
||||||
return TSVector{}, err
|
|
||||||
}
|
|
||||||
entry.Positions = positions
|
|
||||||
}
|
|
||||||
|
|
||||||
result.Lexemes = append(result.Lexemes, entry)
|
|
||||||
}
|
|
||||||
|
|
||||||
result.Valid = true
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,170 +0,0 @@
|
||||||
#!/usr/bin/env bash
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
# test.sh - Run pgx tests against specific database targets
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# ./test.sh [target] [go test flags...]
|
|
||||||
#
|
|
||||||
# Targets:
|
|
||||||
# pg14 - PostgreSQL 14 (port 5414)
|
|
||||||
# pg15 - PostgreSQL 15 (port 5415)
|
|
||||||
# pg16 - PostgreSQL 16 (port 5416)
|
|
||||||
# pg17 - PostgreSQL 17 (port 5417)
|
|
||||||
# pg18 - PostgreSQL 18 (port 5432) [default]
|
|
||||||
# crdb - CockroachDB (port 26257)
|
|
||||||
# all - Run against all targets sequentially
|
|
||||||
#
|
|
||||||
# Examples:
|
|
||||||
# ./test.sh # Test against PG18
|
|
||||||
# ./test.sh pg14 # Test against PG14
|
|
||||||
# ./test.sh crdb # Test against CockroachDB
|
|
||||||
# ./test.sh all # Test against all targets
|
|
||||||
# ./test.sh pg16 -run TestConnect # Test specific test against PG16
|
|
||||||
# ./test.sh pg18 -count=1 -v # Verbose, no cache, PG18
|
|
||||||
|
|
||||||
# Color output (disabled if not a terminal)
|
|
||||||
if [ -t 1 ]; then
|
|
||||||
GREEN='\033[0;32m'
|
|
||||||
RED='\033[0;31m'
|
|
||||||
BLUE='\033[0;34m'
|
|
||||||
NC='\033[0m'
|
|
||||||
else
|
|
||||||
GREEN=''
|
|
||||||
RED=''
|
|
||||||
BLUE=''
|
|
||||||
NC=''
|
|
||||||
fi
|
|
||||||
|
|
||||||
log_info() { echo -e "${BLUE}==> $*${NC}"; }
|
|
||||||
log_ok() { echo -e "${GREEN}==> $*${NC}"; }
|
|
||||||
log_err() { echo -e "${RED}==> $*${NC}" >&2; }
|
|
||||||
|
|
||||||
# Wait for a database to accept connections
|
|
||||||
wait_for_ready() {
|
|
||||||
local connstr="$1"
|
|
||||||
local label="$2"
|
|
||||||
local max_attempts=30
|
|
||||||
local attempt=0
|
|
||||||
|
|
||||||
log_info "Waiting for $label to be ready..."
|
|
||||||
while ! psql "$connstr" -c "SELECT 1" > /dev/null 2>&1; do
|
|
||||||
attempt=$((attempt + 1))
|
|
||||||
if [ "$attempt" -ge "$max_attempts" ]; then
|
|
||||||
log_err "$label did not become ready after $max_attempts attempts"
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
sleep 1
|
|
||||||
done
|
|
||||||
log_ok "$label is ready"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Directory containing this script (used to locate testsetup/)
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
|
||||||
CERTS_DIR="$SCRIPT_DIR/testsetup/certs"
|
|
||||||
|
|
||||||
# Copy client certificates to /tmp for TLS tests
|
|
||||||
setup_client_certs() {
|
|
||||||
if [ -d "$CERTS_DIR" ]; then
|
|
||||||
base64 -d "$CERTS_DIR/ca.pem.b64" > /tmp/ca.pem
|
|
||||||
base64 -d "$CERTS_DIR/pgx_sslcert.crt.b64" > /tmp/pgx_sslcert.crt
|
|
||||||
base64 -d "$CERTS_DIR/pgx_sslcert.key.b64" > /tmp/pgx_sslcert.key
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize CockroachDB (create database if not exists)
|
|
||||||
init_crdb() {
|
|
||||||
local connstr="postgresql://root@localhost:26257/?sslmode=disable"
|
|
||||||
wait_for_ready "$connstr" "CockroachDB"
|
|
||||||
log_info "Ensuring pgx_test database exists on CockroachDB..."
|
|
||||||
psql "$connstr" -c "CREATE DATABASE IF NOT EXISTS pgx_test" 2>/dev/null || true
|
|
||||||
}
|
|
||||||
|
|
||||||
# Run tests against a single target
|
|
||||||
run_tests() {
|
|
||||||
local target="$1"
|
|
||||||
shift
|
|
||||||
local extra_args=("$@")
|
|
||||||
|
|
||||||
local label=""
|
|
||||||
local port=""
|
|
||||||
|
|
||||||
case "$target" in
|
|
||||||
pg14) label="PostgreSQL 14"; port=5414 ;;
|
|
||||||
pg15) label="PostgreSQL 15"; port=5415 ;;
|
|
||||||
pg16) label="PostgreSQL 16"; port=5416 ;;
|
|
||||||
pg17) label="PostgreSQL 17"; port=5417 ;;
|
|
||||||
pg18) label="PostgreSQL 18"; port=5432 ;;
|
|
||||||
crdb)
|
|
||||||
label="CockroachDB (port 26257)"
|
|
||||||
init_crdb
|
|
||||||
log_info "Testing against $label"
|
|
||||||
if ! PGX_TEST_DATABASE="postgresql://root@localhost:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" \
|
|
||||||
go test -count=1 "${extra_args[@]}" ./...; then
|
|
||||||
log_err "Tests FAILED against $label"
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
log_ok "Tests passed against $label"
|
|
||||||
return 0
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
log_err "Unknown target: $target"
|
|
||||||
log_err "Valid targets: pg14, pg15, pg16, pg17, pg18, crdb, all"
|
|
||||||
return 1
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
setup_client_certs
|
|
||||||
|
|
||||||
log_info "Testing against $label (port $port)"
|
|
||||||
if ! PGX_TEST_DATABASE="host=localhost port=$port user=postgres password=postgres dbname=pgx_test" \
|
|
||||||
PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \
|
|
||||||
PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \
|
|
||||||
PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \
|
|
||||||
PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \
|
|
||||||
PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \
|
|
||||||
PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \
|
|
||||||
PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \
|
|
||||||
PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \
|
|
||||||
PGX_SSL_PASSWORD=certpw \
|
|
||||||
go test -count=1 "${extra_args[@]}" ./...; then
|
|
||||||
log_err "Tests FAILED against $label"
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
log_ok "Tests passed against $label"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Main
|
|
||||||
main() {
|
|
||||||
local target="${1:-pg18}"
|
|
||||||
|
|
||||||
if [ "$target" = "all" ]; then
|
|
||||||
shift || true
|
|
||||||
local targets=(pg14 pg15 pg16 pg17 pg18 crdb)
|
|
||||||
local failed=()
|
|
||||||
|
|
||||||
for t in "${targets[@]}"; do
|
|
||||||
echo ""
|
|
||||||
log_info "=========================================="
|
|
||||||
log_info "Target: $t"
|
|
||||||
log_info "=========================================="
|
|
||||||
if ! run_tests "$t" "$@"; then
|
|
||||||
failed+=("$t")
|
|
||||||
log_err "FAILED: $t"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
if [ ${#failed[@]} -gt 0 ]; then
|
|
||||||
log_err "Failed targets: ${failed[*]}"
|
|
||||||
return 1
|
|
||||||
else
|
|
||||||
log_ok "All targets passed"
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
shift || true
|
|
||||||
run_tests "$target" "$@"
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
main "$@"
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
.git
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# IntelliJ
|
||||||
|
.idea/
|
||||||
|
*.iml
|
||||||
|
|
||||||
|
# VS code
|
||||||
|
*.code-workspace
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
## Decimal v1.4.0
|
||||||
|
#### BREAKING
|
||||||
|
- Drop support for Go version older than 1.10 [#361](https://github.com/shopspring/decimal/pull/361)
|
||||||
|
|
||||||
|
#### FEATURES
|
||||||
|
- Add implementation of natural logarithm [#339](https://github.com/shopspring/decimal/pull/339) [#357](https://github.com/shopspring/decimal/pull/357)
|
||||||
|
- Add improved implementation of power operation [#358](https://github.com/shopspring/decimal/pull/358)
|
||||||
|
- Add Compare method which forwards calls to Cmp [#346](https://github.com/shopspring/decimal/pull/346)
|
||||||
|
- Add NewFromBigRat constructor [#288](https://github.com/shopspring/decimal/pull/288)
|
||||||
|
- Add NewFromUint64 constructor [#352](https://github.com/shopspring/decimal/pull/352)
|
||||||
|
|
||||||
|
#### ENHANCEMENTS
|
||||||
|
- Migrate to Github Actions [#245](https://github.com/shopspring/decimal/pull/245) [#340](https://github.com/shopspring/decimal/pull/340)
|
||||||
|
- Fix examples for RoundDown, RoundFloor, RoundUp, and RoundCeil [#285](https://github.com/shopspring/decimal/pull/285) [#328](https://github.com/shopspring/decimal/pull/328) [#341](https://github.com/shopspring/decimal/pull/341)
|
||||||
|
- Use Godoc standard to mark deprecated Equals and StringScaled methods [#342](https://github.com/shopspring/decimal/pull/342)
|
||||||
|
- Removed unnecessary min function for RescalePair method [#265](https://github.com/shopspring/decimal/pull/265)
|
||||||
|
- Avoid reallocation of initial slice in MarshalBinary (GobEncode) [#355](https://github.com/shopspring/decimal/pull/355)
|
||||||
|
- Optimize NumDigits method [#301](https://github.com/shopspring/decimal/pull/301) [#356](https://github.com/shopspring/decimal/pull/356)
|
||||||
|
- Optimize BigInt method [#359](https://github.com/shopspring/decimal/pull/359)
|
||||||
|
- Support scanning uint64 [#131](https://github.com/shopspring/decimal/pull/131) [#364](https://github.com/shopspring/decimal/pull/364)
|
||||||
|
- Add docs section with alternative libraries [#363](https://github.com/shopspring/decimal/pull/363)
|
||||||
|
|
||||||
|
#### BUGFIXES
|
||||||
|
- Fix incorrect calculation of decimal modulo [#258](https://github.com/shopspring/decimal/pull/258) [#317](https://github.com/shopspring/decimal/pull/317)
|
||||||
|
- Allocate new(big.Int) in Copy method to deeply clone it [#278](https://github.com/shopspring/decimal/pull/278)
|
||||||
|
- Fix overflow edge case in QuoRem method [#322](https://github.com/shopspring/decimal/pull/322)
|
||||||
|
|
||||||
|
## Decimal v1.3.1
|
||||||
|
|
||||||
|
#### ENHANCEMENTS
|
||||||
|
- Reduce memory allocation in case of initialization from big.Int [#252](https://github.com/shopspring/decimal/pull/252)
|
||||||
|
|
||||||
|
#### BUGFIXES
|
||||||
|
- Fix binary marshalling of decimal zero value [#253](https://github.com/shopspring/decimal/pull/253)
|
||||||
|
|
||||||
|
## Decimal v1.3.0
|
||||||
|
|
||||||
|
#### FEATURES
|
||||||
|
- Add NewFromFormattedString initializer [#184](https://github.com/shopspring/decimal/pull/184)
|
||||||
|
- Add NewNullDecimal initializer [#234](https://github.com/shopspring/decimal/pull/234)
|
||||||
|
- Add implementation of natural exponent function (Taylor, Hull-Abraham) [#229](https://github.com/shopspring/decimal/pull/229)
|
||||||
|
- Add RoundUp, RoundDown, RoundCeil, RoundFloor methods [#196](https://github.com/shopspring/decimal/pull/196) [#202](https://github.com/shopspring/decimal/pull/202) [#220](https://github.com/shopspring/decimal/pull/220)
|
||||||
|
- Add XML support for NullDecimal [#192](https://github.com/shopspring/decimal/pull/192)
|
||||||
|
- Add IsInteger method [#179](https://github.com/shopspring/decimal/pull/179)
|
||||||
|
- Add Copy helper method [#123](https://github.com/shopspring/decimal/pull/123)
|
||||||
|
- Add InexactFloat64 helper method [#205](https://github.com/shopspring/decimal/pull/205)
|
||||||
|
- Add CoefficientInt64 helper method [#244](https://github.com/shopspring/decimal/pull/244)
|
||||||
|
|
||||||
|
#### ENHANCEMENTS
|
||||||
|
- Performance optimization of NewFromString init method [#198](https://github.com/shopspring/decimal/pull/198)
|
||||||
|
- Performance optimization of Abs and Round methods [#240](https://github.com/shopspring/decimal/pull/240)
|
||||||
|
- Additional tests (CI) for ppc64le architecture [#188](https://github.com/shopspring/decimal/pull/188)
|
||||||
|
|
||||||
|
#### BUGFIXES
|
||||||
|
- Fix rounding in FormatFloat fallback path (roundShortest method, fix taken from Go main repository) [#161](https://github.com/shopspring/decimal/pull/161)
|
||||||
|
- Add slice range checks to UnmarshalBinary method [#232](https://github.com/shopspring/decimal/pull/232)
|
||||||
|
|
||||||
|
## Decimal v1.2.0
|
||||||
|
|
||||||
|
#### BREAKING
|
||||||
|
- Drop support for Go version older than 1.7 [#172](https://github.com/shopspring/decimal/pull/172)
|
||||||
|
|
||||||
|
#### FEATURES
|
||||||
|
- Add NewFromInt and NewFromInt32 initializers [#72](https://github.com/shopspring/decimal/pull/72)
|
||||||
|
- Add support for Go modules [#157](https://github.com/shopspring/decimal/pull/157)
|
||||||
|
- Add BigInt, BigFloat helper methods [#171](https://github.com/shopspring/decimal/pull/171)
|
||||||
|
|
||||||
|
#### ENHANCEMENTS
|
||||||
|
- Memory usage optimization [#160](https://github.com/shopspring/decimal/pull/160)
|
||||||
|
- Updated travis CI golang versions [#156](https://github.com/shopspring/decimal/pull/156)
|
||||||
|
- Update documentation [#173](https://github.com/shopspring/decimal/pull/173)
|
||||||
|
- Improve code quality [#174](https://github.com/shopspring/decimal/pull/174)
|
||||||
|
|
||||||
|
#### BUGFIXES
|
||||||
|
- Revert remove insignificant digits [#159](https://github.com/shopspring/decimal/pull/159)
|
||||||
|
- Remove 15 interval for RoundCash [#166](https://github.com/shopspring/decimal/pull/166)
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015 Spring, Inc.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
||||||
|
|
||||||
|
- Based on https://github.com/oguzbilgic/fpd, which has the following license:
|
||||||
|
"""
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2013 Oguz Bilgic
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
|
this software and associated documentation files (the "Software"), to deal in
|
||||||
|
the Software without restriction, including without limitation the rights to
|
||||||
|
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||||
|
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||||
|
subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||||
|
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||||
|
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||||
|
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
"""
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
# decimal
|
||||||
|
|
||||||
|
[](https://github.com/shopspring/decimal/actions/workflows/ci.yml)
|
||||||
|
[](https://godoc.org/github.com/shopspring/decimal)
|
||||||
|
[](https://goreportcard.com/report/github.com/shopspring/decimal)
|
||||||
|
|
||||||
|
Arbitrary-precision fixed-point decimal numbers in go.
|
||||||
|
|
||||||
|
_Note:_ Decimal library can "only" represent numbers with a maximum of 2^31 digits after the decimal point.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* The zero-value is 0, and is safe to use without initialization
|
||||||
|
* Addition, subtraction, multiplication with no loss of precision
|
||||||
|
* Division with specified precision
|
||||||
|
* Database/sql serialization/deserialization
|
||||||
|
* JSON and XML serialization/deserialization
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Run `go get github.com/shopspring/decimal`
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
Decimal library requires Go version `>=1.10`
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
http://godoc.org/github.com/shopspring/decimal
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
price, err := decimal.NewFromString("136.02")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
quantity := decimal.NewFromInt(3)
|
||||||
|
|
||||||
|
fee, _ := decimal.NewFromString(".035")
|
||||||
|
taxRate, _ := decimal.NewFromString(".08875")
|
||||||
|
|
||||||
|
subtotal := price.Mul(quantity)
|
||||||
|
|
||||||
|
preTax := subtotal.Mul(fee.Add(decimal.NewFromFloat(1)))
|
||||||
|
|
||||||
|
total := preTax.Mul(taxRate.Add(decimal.NewFromFloat(1)))
|
||||||
|
|
||||||
|
fmt.Println("Subtotal:", subtotal) // Subtotal: 408.06
|
||||||
|
fmt.Println("Pre-tax:", preTax) // Pre-tax: 422.3421
|
||||||
|
fmt.Println("Taxes:", total.Sub(preTax)) // Taxes: 37.482861375
|
||||||
|
fmt.Println("Total:", total) // Total: 459.824961375
|
||||||
|
fmt.Println("Tax rate:", total.Sub(preTax).Div(preTax)) // Tax rate: 0.08875
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Alternative libraries
|
||||||
|
|
||||||
|
When working with decimal numbers, you might face problems this library is not perfectly suited for.
|
||||||
|
Fortunately, thanks to the wonderful community we have a dozen other libraries that you can choose from.
|
||||||
|
Explore other alternatives to find the one that best fits your needs :)
|
||||||
|
|
||||||
|
* [cockroachdb/apd](https://github.com/cockroachdb/apd) - arbitrary precision, mutable and rich API similar to `big.Int`, more performant than this library
|
||||||
|
* [alpacahq/alpacadecimal](https://github.com/alpacahq/alpacadecimal) - high performance, low precision (12 digits), fully compatible API with this library
|
||||||
|
* [govalues/decimal](https://github.com/govalues/decimal) - high performance, zero-allocation, low precision (19 digits)
|
||||||
|
* [greatcloak/decimal](https://github.com/greatcloak/decimal) - fork focusing on billing and e-commerce web application related use cases, includes out-of-the-box BSON marshaling support
|
||||||
|
|
||||||
|
## FAQ
|
||||||
|
|
||||||
|
#### Why don't you just use float64?
|
||||||
|
|
||||||
|
Because float64 (or any binary floating point type, actually) can't represent
|
||||||
|
numbers such as `0.1` exactly.
|
||||||
|
|
||||||
|
Consider this code: http://play.golang.org/p/TQBd4yJe6B You might expect that
|
||||||
|
it prints out `10`, but it actually prints `9.999999999999831`. Over time,
|
||||||
|
these small errors can really add up!
|
||||||
|
|
||||||
|
#### Why don't you just use big.Rat?
|
||||||
|
|
||||||
|
big.Rat is fine for representing rational numbers, but Decimal is better for
|
||||||
|
representing money. Why? Here's a (contrived) example:
|
||||||
|
|
||||||
|
Let's say you use big.Rat, and you have two numbers, x and y, both
|
||||||
|
representing 1/3, and you have `z = 1 - x - y = 1/3`. If you print each one
|
||||||
|
out, the string output has to stop somewhere (let's say it stops at 3 decimal
|
||||||
|
digits, for simplicity), so you'll get 0.333, 0.333, and 0.333. But where did
|
||||||
|
the other 0.001 go?
|
||||||
|
|
||||||
|
Here's the above example as code: http://play.golang.org/p/lCZZs0w9KE
|
||||||
|
|
||||||
|
With Decimal, the strings being printed out represent the number exactly. So,
|
||||||
|
if you have `x = y = 1/3` (with precision 3), they will actually be equal to
|
||||||
|
0.333, and when you do `z = 1 - x - y`, `z` will be equal to .334. No money is
|
||||||
|
unaccounted for!
|
||||||
|
|
||||||
|
You still have to be careful. If you want to split a number `N` 3 ways, you
|
||||||
|
can't just send `N/3` to three different people. You have to pick one to send
|
||||||
|
`N - (2/3*N)` to. That person will receive the fraction of a penny remainder.
|
||||||
|
|
||||||
|
But, it is much easier to be careful with Decimal than with big.Rat.
|
||||||
|
|
||||||
|
#### Why isn't the API similar to big.Int's?
|
||||||
|
|
||||||
|
big.Int's API is built to reduce the number of memory allocations for maximal
|
||||||
|
performance. This makes sense for its use-case, but the trade-off is that the
|
||||||
|
API is awkward and easy to misuse.
|
||||||
|
|
||||||
|
For example, to add two big.Ints, you do: `z := new(big.Int).Add(x, y)`. A
|
||||||
|
developer unfamiliar with this API might try to do `z := a.Add(a, b)`. This
|
||||||
|
modifies `a` and sets `z` as an alias for `a`, which they might not expect. It
|
||||||
|
also modifies any other aliases to `a`.
|
||||||
|
|
||||||
|
Here's an example of the subtle bugs you can introduce with big.Int's API:
|
||||||
|
https://play.golang.org/p/x2R_78pa8r
|
||||||
|
|
||||||
|
In contrast, it's difficult to make such mistakes with decimal. Decimals
|
||||||
|
behave like other go numbers types: even though `a = b` will not deep copy
|
||||||
|
`b` into `a`, it is impossible to modify a Decimal, since all Decimal methods
|
||||||
|
return new Decimals and do not modify the originals. The downside is that
|
||||||
|
this causes extra allocations, so Decimal is less performant. My assumption
|
||||||
|
is that if you're using Decimals, you probably care more about correctness
|
||||||
|
than performance.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
This is a heavily modified fork of [fpd.Decimal](https://github.com/oguzbilgic/fpd), which was also released under the MIT License.
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
package decimal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
strLn10 = "2.302585092994045684017991454684364207601101488628772976033327900967572609677352480235997205089598298341967784042286248633409525465082806756666287369098781689482907208325554680843799894826233198528393505308965377732628846163366222287698219886746543667474404243274365155048934314939391479619404400222105101714174800368808401264708068556774321622835522011480466371565912137345074785694768346361679210180644507064800027750268491674655058685693567342067058113642922455440575892572420824131469568901675894025677631135691929203337658714166023010570308963457207544037084746994016826928280848118428931484852494864487192780967627127577539702766860595249671667418348570442250719796500471495105049221477656763693866297697952211071826454973477266242570942932258279850258550978526538320760672631716430950599508780752371033310119785754733154142180842754386359177811705430982748238504564801909561029929182431823752535770975053956518769751037497088869218020518933950723853920514463419726528728696511086257149219884997874887377134568620916705849807828059751193854445009978131146915934666241071846692310107598438319191292230792503747298650929009880391941702654416816335727555703151596113564846546190897042819763365836983716328982174407366009162177850541779276367731145041782137660111010731042397832521894898817597921798666394319523936855916447118246753245630912528778330963604262982153040874560927760726641354787576616262926568298704957954913954918049209069438580790032763017941503117866862092408537949861264933479354871737451675809537088281067452440105892444976479686075120275724181874989395971643105518848195288330746699317814634930000321200327765654130472621883970596794457943468343218395304414844803701305753674262153675579814770458031413637793236291560128185336498466942261465206459942072917119370602444929358037007718981097362533224548366988505528285966192805098447175198503666680874970496982273220244823343097169111136813588418696549323714996941979687803008850408979618598756579894836445212043698216415292987811742973332588607915912510967187510929248475023930572665446276200923068791518135803477701295593646298412366497023355174586195564772461857717369368404676577047874319780573853271810933883496338813069945569399346101090745616033312247949360455361849123333063704751724871276379140924398331810164737823379692265637682071706935846394531616949411701841938119405416449466111274712819705817783293841742231409930022911502362192186723337268385688273533371925103412930705632544426611429765388301822384091026198582888433587455960453004548370789052578473166283701953392231047527564998119228742789713715713228319641003422124210082180679525276689858180956119208391760721080919923461516952599099473782780648128058792731993893453415320185969711021407542282796298237068941764740642225757212455392526179373652434440560595336591539160312524480149313234572453879524389036839236450507881731359711238145323701508413491122324390927681724749607955799151363982881058285740538000653371655553014196332241918087621018204919492651483892"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ln10 = newConstApproximation(strLn10)
|
||||||
|
)
|
||||||
|
|
||||||
|
type constApproximation struct {
|
||||||
|
exact Decimal
|
||||||
|
approximations []Decimal
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConstApproximation(value string) constApproximation {
|
||||||
|
parts := strings.Split(value, ".")
|
||||||
|
coeff, fractional := parts[0], parts[1]
|
||||||
|
|
||||||
|
coeffLen := len(coeff)
|
||||||
|
maxPrecision := len(fractional)
|
||||||
|
|
||||||
|
var approximations []Decimal
|
||||||
|
for p := 1; p < maxPrecision; p *= 2 {
|
||||||
|
r := RequireFromString(value[:coeffLen+p])
|
||||||
|
approximations = append(approximations, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return constApproximation{
|
||||||
|
RequireFromString(value),
|
||||||
|
approximations,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the smallest approximation available that's at least as precise
|
||||||
|
// as the passed precision (places after decimal point), i.e. Floor[ log2(precision) ] + 1
|
||||||
|
func (c constApproximation) withPrecision(precision int32) Decimal {
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
if precision >= 1 {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
for precision >= 16 {
|
||||||
|
precision /= 16
|
||||||
|
i += 4
|
||||||
|
}
|
||||||
|
|
||||||
|
for precision >= 2 {
|
||||||
|
precision /= 2
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if i >= len(c.approximations) {
|
||||||
|
return c.exact
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.approximations[i]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,415 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Multiprecision decimal numbers.
|
||||||
|
// For floating-point formatting only; not general purpose.
|
||||||
|
// Only operations are assign and (binary) left/right shift.
|
||||||
|
// Can do binary floating point in multiprecision decimal precisely
|
||||||
|
// because 2 divides 10; cannot do decimal floating point
|
||||||
|
// in multiprecision binary precisely.
|
||||||
|
|
||||||
|
package decimal
|
||||||
|
|
||||||
|
type decimal struct {
|
||||||
|
d [800]byte // digits, big-endian representation
|
||||||
|
nd int // number of digits used
|
||||||
|
dp int // decimal point
|
||||||
|
neg bool // negative flag
|
||||||
|
trunc bool // discarded nonzero digits beyond d[:nd]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *decimal) String() string {
|
||||||
|
n := 10 + a.nd
|
||||||
|
if a.dp > 0 {
|
||||||
|
n += a.dp
|
||||||
|
}
|
||||||
|
if a.dp < 0 {
|
||||||
|
n += -a.dp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, n)
|
||||||
|
w := 0
|
||||||
|
switch {
|
||||||
|
case a.nd == 0:
|
||||||
|
return "0"
|
||||||
|
|
||||||
|
case a.dp <= 0:
|
||||||
|
// zeros fill space between decimal point and digits
|
||||||
|
buf[w] = '0'
|
||||||
|
w++
|
||||||
|
buf[w] = '.'
|
||||||
|
w++
|
||||||
|
w += digitZero(buf[w : w+-a.dp])
|
||||||
|
w += copy(buf[w:], a.d[0:a.nd])
|
||||||
|
|
||||||
|
case a.dp < a.nd:
|
||||||
|
// decimal point in middle of digits
|
||||||
|
w += copy(buf[w:], a.d[0:a.dp])
|
||||||
|
buf[w] = '.'
|
||||||
|
w++
|
||||||
|
w += copy(buf[w:], a.d[a.dp:a.nd])
|
||||||
|
|
||||||
|
default:
|
||||||
|
// zeros fill space between digits and decimal point
|
||||||
|
w += copy(buf[w:], a.d[0:a.nd])
|
||||||
|
w += digitZero(buf[w : w+a.dp-a.nd])
|
||||||
|
}
|
||||||
|
return string(buf[0:w])
|
||||||
|
}
|
||||||
|
|
||||||
|
func digitZero(dst []byte) int {
|
||||||
|
for i := range dst {
|
||||||
|
dst[i] = '0'
|
||||||
|
}
|
||||||
|
return len(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// trim trailing zeros from number.
|
||||||
|
// (They are meaningless; the decimal point is tracked
|
||||||
|
// independent of the number of digits.)
|
||||||
|
func trim(a *decimal) {
|
||||||
|
for a.nd > 0 && a.d[a.nd-1] == '0' {
|
||||||
|
a.nd--
|
||||||
|
}
|
||||||
|
if a.nd == 0 {
|
||||||
|
a.dp = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign v to a.
|
||||||
|
func (a *decimal) Assign(v uint64) {
|
||||||
|
var buf [24]byte
|
||||||
|
|
||||||
|
// Write reversed decimal in buf.
|
||||||
|
n := 0
|
||||||
|
for v > 0 {
|
||||||
|
v1 := v / 10
|
||||||
|
v -= 10 * v1
|
||||||
|
buf[n] = byte(v + '0')
|
||||||
|
n++
|
||||||
|
v = v1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse again to produce forward decimal in a.d.
|
||||||
|
a.nd = 0
|
||||||
|
for n--; n >= 0; n-- {
|
||||||
|
a.d[a.nd] = buf[n]
|
||||||
|
a.nd++
|
||||||
|
}
|
||||||
|
a.dp = a.nd
|
||||||
|
trim(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maximum shift that we can do in one pass without overflow.
|
||||||
|
// A uint has 32 or 64 bits, and we have to be able to accommodate 9<<k.
|
||||||
|
const uintSize = 32 << (^uint(0) >> 63)
|
||||||
|
const maxShift = uintSize - 4
|
||||||
|
|
||||||
|
// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow.
|
||||||
|
func rightShift(a *decimal, k uint) {
|
||||||
|
r := 0 // read pointer
|
||||||
|
w := 0 // write pointer
|
||||||
|
|
||||||
|
// Pick up enough leading digits to cover first shift.
|
||||||
|
var n uint
|
||||||
|
for ; n>>k == 0; r++ {
|
||||||
|
if r >= a.nd {
|
||||||
|
if n == 0 {
|
||||||
|
// a == 0; shouldn't get here, but handle anyway.
|
||||||
|
a.nd = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for n>>k == 0 {
|
||||||
|
n = n * 10
|
||||||
|
r++
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
c := uint(a.d[r])
|
||||||
|
n = n*10 + c - '0'
|
||||||
|
}
|
||||||
|
a.dp -= r - 1
|
||||||
|
|
||||||
|
var mask uint = (1 << k) - 1
|
||||||
|
|
||||||
|
// Pick up a digit, put down a digit.
|
||||||
|
for ; r < a.nd; r++ {
|
||||||
|
c := uint(a.d[r])
|
||||||
|
dig := n >> k
|
||||||
|
n &= mask
|
||||||
|
a.d[w] = byte(dig + '0')
|
||||||
|
w++
|
||||||
|
n = n*10 + c - '0'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put down extra digits.
|
||||||
|
for n > 0 {
|
||||||
|
dig := n >> k
|
||||||
|
n &= mask
|
||||||
|
if w < len(a.d) {
|
||||||
|
a.d[w] = byte(dig + '0')
|
||||||
|
w++
|
||||||
|
} else if dig > 0 {
|
||||||
|
a.trunc = true
|
||||||
|
}
|
||||||
|
n = n * 10
|
||||||
|
}
|
||||||
|
|
||||||
|
a.nd = w
|
||||||
|
trim(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cheat sheet for left shift: table indexed by shift count giving
|
||||||
|
// number of new digits that will be introduced by that shift.
|
||||||
|
//
|
||||||
|
// For example, leftcheats[4] = {2, "625"}. That means that
|
||||||
|
// if we are shifting by 4 (multiplying by 16), it will add 2 digits
|
||||||
|
// when the string prefix is "625" through "999", and one fewer digit
|
||||||
|
// if the string prefix is "000" through "624".
|
||||||
|
//
|
||||||
|
// Credit for this trick goes to Ken.
|
||||||
|
|
||||||
|
type leftCheat struct {
|
||||||
|
delta int // number of new digits
|
||||||
|
cutoff string // minus one digit if original < a.
|
||||||
|
}
|
||||||
|
|
||||||
|
var leftcheats = []leftCheat{
|
||||||
|
// Leading digits of 1/2^i = 5^i.
|
||||||
|
// 5^23 is not an exact 64-bit floating point number,
|
||||||
|
// so have to use bc for the math.
|
||||||
|
// Go up to 60 to be large enough for 32bit and 64bit platforms.
|
||||||
|
/*
|
||||||
|
seq 60 | sed 's/^/5^/' | bc |
|
||||||
|
awk 'BEGIN{ print "\t{ 0, \"\" }," }
|
||||||
|
{
|
||||||
|
log2 = log(2)/log(10)
|
||||||
|
printf("\t{ %d, \"%s\" },\t// * %d\n",
|
||||||
|
int(log2*NR+1), $0, 2**NR)
|
||||||
|
}'
|
||||||
|
*/
|
||||||
|
{0, ""},
|
||||||
|
{1, "5"}, // * 2
|
||||||
|
{1, "25"}, // * 4
|
||||||
|
{1, "125"}, // * 8
|
||||||
|
{2, "625"}, // * 16
|
||||||
|
{2, "3125"}, // * 32
|
||||||
|
{2, "15625"}, // * 64
|
||||||
|
{3, "78125"}, // * 128
|
||||||
|
{3, "390625"}, // * 256
|
||||||
|
{3, "1953125"}, // * 512
|
||||||
|
{4, "9765625"}, // * 1024
|
||||||
|
{4, "48828125"}, // * 2048
|
||||||
|
{4, "244140625"}, // * 4096
|
||||||
|
{4, "1220703125"}, // * 8192
|
||||||
|
{5, "6103515625"}, // * 16384
|
||||||
|
{5, "30517578125"}, // * 32768
|
||||||
|
{5, "152587890625"}, // * 65536
|
||||||
|
{6, "762939453125"}, // * 131072
|
||||||
|
{6, "3814697265625"}, // * 262144
|
||||||
|
{6, "19073486328125"}, // * 524288
|
||||||
|
{7, "95367431640625"}, // * 1048576
|
||||||
|
{7, "476837158203125"}, // * 2097152
|
||||||
|
{7, "2384185791015625"}, // * 4194304
|
||||||
|
{7, "11920928955078125"}, // * 8388608
|
||||||
|
{8, "59604644775390625"}, // * 16777216
|
||||||
|
{8, "298023223876953125"}, // * 33554432
|
||||||
|
{8, "1490116119384765625"}, // * 67108864
|
||||||
|
{9, "7450580596923828125"}, // * 134217728
|
||||||
|
{9, "37252902984619140625"}, // * 268435456
|
||||||
|
{9, "186264514923095703125"}, // * 536870912
|
||||||
|
{10, "931322574615478515625"}, // * 1073741824
|
||||||
|
{10, "4656612873077392578125"}, // * 2147483648
|
||||||
|
{10, "23283064365386962890625"}, // * 4294967296
|
||||||
|
{10, "116415321826934814453125"}, // * 8589934592
|
||||||
|
{11, "582076609134674072265625"}, // * 17179869184
|
||||||
|
{11, "2910383045673370361328125"}, // * 34359738368
|
||||||
|
{11, "14551915228366851806640625"}, // * 68719476736
|
||||||
|
{12, "72759576141834259033203125"}, // * 137438953472
|
||||||
|
{12, "363797880709171295166015625"}, // * 274877906944
|
||||||
|
{12, "1818989403545856475830078125"}, // * 549755813888
|
||||||
|
{13, "9094947017729282379150390625"}, // * 1099511627776
|
||||||
|
{13, "45474735088646411895751953125"}, // * 2199023255552
|
||||||
|
{13, "227373675443232059478759765625"}, // * 4398046511104
|
||||||
|
{13, "1136868377216160297393798828125"}, // * 8796093022208
|
||||||
|
{14, "5684341886080801486968994140625"}, // * 17592186044416
|
||||||
|
{14, "28421709430404007434844970703125"}, // * 35184372088832
|
||||||
|
{14, "142108547152020037174224853515625"}, // * 70368744177664
|
||||||
|
{15, "710542735760100185871124267578125"}, // * 140737488355328
|
||||||
|
{15, "3552713678800500929355621337890625"}, // * 281474976710656
|
||||||
|
{15, "17763568394002504646778106689453125"}, // * 562949953421312
|
||||||
|
{16, "88817841970012523233890533447265625"}, // * 1125899906842624
|
||||||
|
{16, "444089209850062616169452667236328125"}, // * 2251799813685248
|
||||||
|
{16, "2220446049250313080847263336181640625"}, // * 4503599627370496
|
||||||
|
{16, "11102230246251565404236316680908203125"}, // * 9007199254740992
|
||||||
|
{17, "55511151231257827021181583404541015625"}, // * 18014398509481984
|
||||||
|
{17, "277555756156289135105907917022705078125"}, // * 36028797018963968
|
||||||
|
{17, "1387778780781445675529539585113525390625"}, // * 72057594037927936
|
||||||
|
{18, "6938893903907228377647697925567626953125"}, // * 144115188075855872
|
||||||
|
{18, "34694469519536141888238489627838134765625"}, // * 288230376151711744
|
||||||
|
{18, "173472347597680709441192448139190673828125"}, // * 576460752303423488
|
||||||
|
{19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is the leading prefix of b lexicographically less than s?
|
||||||
|
func prefixIsLessThan(b []byte, s string) bool {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if i >= len(b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if b[i] != s[i] {
|
||||||
|
return b[i] < s[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow.
|
||||||
|
func leftShift(a *decimal, k uint) {
|
||||||
|
delta := leftcheats[k].delta
|
||||||
|
if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) {
|
||||||
|
delta--
|
||||||
|
}
|
||||||
|
|
||||||
|
r := a.nd // read index
|
||||||
|
w := a.nd + delta // write index
|
||||||
|
|
||||||
|
// Pick up a digit, put down a digit.
|
||||||
|
var n uint
|
||||||
|
for r--; r >= 0; r-- {
|
||||||
|
n += (uint(a.d[r]) - '0') << k
|
||||||
|
quo := n / 10
|
||||||
|
rem := n - 10*quo
|
||||||
|
w--
|
||||||
|
if w < len(a.d) {
|
||||||
|
a.d[w] = byte(rem + '0')
|
||||||
|
} else if rem != 0 {
|
||||||
|
a.trunc = true
|
||||||
|
}
|
||||||
|
n = quo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put down extra digits.
|
||||||
|
for n > 0 {
|
||||||
|
quo := n / 10
|
||||||
|
rem := n - 10*quo
|
||||||
|
w--
|
||||||
|
if w < len(a.d) {
|
||||||
|
a.d[w] = byte(rem + '0')
|
||||||
|
} else if rem != 0 {
|
||||||
|
a.trunc = true
|
||||||
|
}
|
||||||
|
n = quo
|
||||||
|
}
|
||||||
|
|
||||||
|
a.nd += delta
|
||||||
|
if a.nd >= len(a.d) {
|
||||||
|
a.nd = len(a.d)
|
||||||
|
}
|
||||||
|
a.dp += delta
|
||||||
|
trim(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binary shift left (k > 0) or right (k < 0).
|
||||||
|
func (a *decimal) Shift(k int) {
|
||||||
|
switch {
|
||||||
|
case a.nd == 0:
|
||||||
|
// nothing to do: a == 0
|
||||||
|
case k > 0:
|
||||||
|
for k > maxShift {
|
||||||
|
leftShift(a, maxShift)
|
||||||
|
k -= maxShift
|
||||||
|
}
|
||||||
|
leftShift(a, uint(k))
|
||||||
|
case k < 0:
|
||||||
|
for k < -maxShift {
|
||||||
|
rightShift(a, maxShift)
|
||||||
|
k += maxShift
|
||||||
|
}
|
||||||
|
rightShift(a, uint(-k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we chop a at nd digits, should we round up?
|
||||||
|
func shouldRoundUp(a *decimal, nd int) bool {
|
||||||
|
if nd < 0 || nd >= a.nd {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even
|
||||||
|
// if we truncated, a little higher than what's recorded - always round up
|
||||||
|
if a.trunc {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return nd > 0 && (a.d[nd-1]-'0')%2 != 0
|
||||||
|
}
|
||||||
|
// not halfway - digit tells all
|
||||||
|
return a.d[nd] >= '5'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Round a to nd digits (or fewer).
|
||||||
|
// If nd is zero, it means we're rounding
|
||||||
|
// just to the left of the digits, as in
|
||||||
|
// 0.09 -> 0.1.
|
||||||
|
func (a *decimal) Round(nd int) {
|
||||||
|
if nd < 0 || nd >= a.nd {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if shouldRoundUp(a, nd) {
|
||||||
|
a.RoundUp(nd)
|
||||||
|
} else {
|
||||||
|
a.RoundDown(nd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Round a down to nd digits (or fewer).
|
||||||
|
func (a *decimal) RoundDown(nd int) {
|
||||||
|
if nd < 0 || nd >= a.nd {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.nd = nd
|
||||||
|
trim(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Round a up to nd digits (or fewer).
|
||||||
|
func (a *decimal) RoundUp(nd int) {
|
||||||
|
if nd < 0 || nd >= a.nd {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// round up
|
||||||
|
for i := nd - 1; i >= 0; i-- {
|
||||||
|
c := a.d[i]
|
||||||
|
if c < '9' { // can stop after this digit
|
||||||
|
a.d[i]++
|
||||||
|
a.nd = i + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number is all 9s.
|
||||||
|
// Change to single 1 with adjusted decimal point.
|
||||||
|
a.d[0] = '1'
|
||||||
|
a.nd = 1
|
||||||
|
a.dp++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract integer part, rounded appropriately.
|
||||||
|
// No guarantees about overflow.
|
||||||
|
func (a *decimal) RoundedInteger() uint64 {
|
||||||
|
if a.dp > 20 {
|
||||||
|
return 0xFFFFFFFFFFFFFFFF
|
||||||
|
}
|
||||||
|
var i int
|
||||||
|
n := uint64(0)
|
||||||
|
for i = 0; i < a.dp && i < a.nd; i++ {
|
||||||
|
n = n*10 + uint64(a.d[i]-'0')
|
||||||
|
}
|
||||||
|
for ; i < a.dp; i++ {
|
||||||
|
n *= 10
|
||||||
|
}
|
||||||
|
if shouldRoundUp(a, a.dp) {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,160 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Multiprecision decimal numbers.
|
||||||
|
// For floating-point formatting only; not general purpose.
|
||||||
|
// Only operations are assign and (binary) left/right shift.
|
||||||
|
// Can do binary floating point in multiprecision decimal precisely
|
||||||
|
// because 2 divides 10; cannot do decimal floating point
|
||||||
|
// in multiprecision binary precisely.
|
||||||
|
|
||||||
|
package decimal
|
||||||
|
|
||||||
|
type floatInfo struct {
|
||||||
|
mantbits uint
|
||||||
|
expbits uint
|
||||||
|
bias int
|
||||||
|
}
|
||||||
|
|
||||||
|
var float32info = floatInfo{23, 8, -127}
|
||||||
|
var float64info = floatInfo{52, 11, -1023}
|
||||||
|
|
||||||
|
// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits
|
||||||
|
// that will let the original floating point value be precisely reconstructed.
|
||||||
|
func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
|
||||||
|
// If mantissa is zero, the number is zero; stop now.
|
||||||
|
if mant == 0 {
|
||||||
|
d.nd = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute upper and lower such that any decimal number
|
||||||
|
// between upper and lower (possibly inclusive)
|
||||||
|
// will round to the original floating point number.
|
||||||
|
|
||||||
|
// We may see at once that the number is already shortest.
|
||||||
|
//
|
||||||
|
// Suppose d is not denormal, so that 2^exp <= d < 10^dp.
|
||||||
|
// The closest shorter number is at least 10^(dp-nd) away.
|
||||||
|
// The lower/upper bounds computed below are at distance
|
||||||
|
// at most 2^(exp-mantbits).
|
||||||
|
//
|
||||||
|
// So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
|
||||||
|
// or equivalently log2(10)*(dp-nd) > exp-mantbits.
|
||||||
|
// It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
|
||||||
|
minexp := flt.bias + 1 // minimum possible exponent
|
||||||
|
if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
|
||||||
|
// The number is already shortest.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// d = mant << (exp - mantbits)
|
||||||
|
// Next highest floating point number is mant+1 << exp-mantbits.
|
||||||
|
// Our upper bound is halfway between, mant*2+1 << exp-mantbits-1.
|
||||||
|
upper := new(decimal)
|
||||||
|
upper.Assign(mant*2 + 1)
|
||||||
|
upper.Shift(exp - int(flt.mantbits) - 1)
|
||||||
|
|
||||||
|
// d = mant << (exp - mantbits)
|
||||||
|
// Next lowest floating point number is mant-1 << exp-mantbits,
|
||||||
|
// unless mant-1 drops the significant bit and exp is not the minimum exp,
|
||||||
|
// in which case the next lowest is mant*2-1 << exp-mantbits-1.
|
||||||
|
// Either way, call it mantlo << explo-mantbits.
|
||||||
|
// Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1.
|
||||||
|
var mantlo uint64
|
||||||
|
var explo int
|
||||||
|
if mant > 1<<flt.mantbits || exp == minexp {
|
||||||
|
mantlo = mant - 1
|
||||||
|
explo = exp
|
||||||
|
} else {
|
||||||
|
mantlo = mant*2 - 1
|
||||||
|
explo = exp - 1
|
||||||
|
}
|
||||||
|
lower := new(decimal)
|
||||||
|
lower.Assign(mantlo*2 + 1)
|
||||||
|
lower.Shift(explo - int(flt.mantbits) - 1)
|
||||||
|
|
||||||
|
// The upper and lower bounds are possible outputs only if
|
||||||
|
// the original mantissa is even, so that IEEE round-to-even
|
||||||
|
// would round to the original mantissa and not the neighbors.
|
||||||
|
inclusive := mant%2 == 0
|
||||||
|
|
||||||
|
// As we walk the digits we want to know whether rounding up would fall
|
||||||
|
// within the upper bound. This is tracked by upperdelta:
|
||||||
|
//
|
||||||
|
// If upperdelta == 0, the digits of d and upper are the same so far.
|
||||||
|
//
|
||||||
|
// If upperdelta == 1, we saw a difference of 1 between d and upper on a
|
||||||
|
// previous digit and subsequently only 9s for d and 0s for upper.
|
||||||
|
// (Thus rounding up may fall outside the bound, if it is exclusive.)
|
||||||
|
//
|
||||||
|
// If upperdelta == 2, then the difference is greater than 1
|
||||||
|
// and we know that rounding up falls within the bound.
|
||||||
|
var upperdelta uint8
|
||||||
|
|
||||||
|
// Now we can figure out the minimum number of digits required.
|
||||||
|
// Walk along until d has distinguished itself from upper and lower.
|
||||||
|
for ui := 0; ; ui++ {
|
||||||
|
// lower, d, and upper may have the decimal points at different
|
||||||
|
// places. In this case upper is the longest, so we iterate from
|
||||||
|
// ui==0 and start li and mi at (possibly) -1.
|
||||||
|
mi := ui - upper.dp + d.dp
|
||||||
|
if mi >= d.nd {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
li := ui - upper.dp + lower.dp
|
||||||
|
l := byte('0') // lower digit
|
||||||
|
if li >= 0 && li < lower.nd {
|
||||||
|
l = lower.d[li]
|
||||||
|
}
|
||||||
|
m := byte('0') // middle digit
|
||||||
|
if mi >= 0 {
|
||||||
|
m = d.d[mi]
|
||||||
|
}
|
||||||
|
u := byte('0') // upper digit
|
||||||
|
if ui < upper.nd {
|
||||||
|
u = upper.d[ui]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Okay to round down (truncate) if lower has a different digit
|
||||||
|
// or if lower is inclusive and is exactly the result of rounding
|
||||||
|
// down (i.e., and we have reached the final digit of lower).
|
||||||
|
okdown := l != m || inclusive && li+1 == lower.nd
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case upperdelta == 0 && m+1 < u:
|
||||||
|
// Example:
|
||||||
|
// m = 12345xxx
|
||||||
|
// u = 12347xxx
|
||||||
|
upperdelta = 2
|
||||||
|
case upperdelta == 0 && m != u:
|
||||||
|
// Example:
|
||||||
|
// m = 12345xxx
|
||||||
|
// u = 12346xxx
|
||||||
|
upperdelta = 1
|
||||||
|
case upperdelta == 1 && (m != '9' || u != '0'):
|
||||||
|
// Example:
|
||||||
|
// m = 1234598x
|
||||||
|
// u = 1234600x
|
||||||
|
upperdelta = 2
|
||||||
|
}
|
||||||
|
// Okay to round up if upper has a different digit and either upper
|
||||||
|
// is inclusive or upper is bigger than the result of rounding up.
|
||||||
|
okup := upperdelta > 0 && (inclusive || upperdelta > 1 || ui+1 < upper.nd)
|
||||||
|
|
||||||
|
// If it's okay to do either, then round to the nearest one.
|
||||||
|
// If it's okay to do only one, do it.
|
||||||
|
switch {
|
||||||
|
case okdown && okup:
|
||||||
|
d.Round(mi + 1)
|
||||||
|
return
|
||||||
|
case okdown:
|
||||||
|
d.RoundDown(mi + 1)
|
||||||
|
return
|
||||||
|
case okup:
|
||||||
|
d.RoundUp(mi + 1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -286,6 +286,9 @@ github.com/redis/go-redis/v9/push
|
||||||
## explicit; go 1.24.0
|
## explicit; go 1.24.0
|
||||||
github.com/rubenv/sql-migrate
|
github.com/rubenv/sql-migrate
|
||||||
github.com/rubenv/sql-migrate/sqlparse
|
github.com/rubenv/sql-migrate/sqlparse
|
||||||
|
# github.com/shopspring/decimal v1.4.0
|
||||||
|
## explicit; go 1.10
|
||||||
|
github.com/shopspring/decimal
|
||||||
# github.com/sirupsen/logrus v1.9.3
|
# github.com/sirupsen/logrus v1.9.3
|
||||||
## explicit; go 1.13
|
## explicit; go 1.13
|
||||||
github.com/sirupsen/logrus
|
github.com/sirupsen/logrus
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue