forked from ebhomengo/niki
90 lines
1.9 KiB
Go
90 lines
1.9 KiB
Go
package mysql
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
querier "git.gocasts.ir/ebhomengo/niki/pkg/query_transaction/sql"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Config carries the connection parameters for the MySQL database.
// The koanf struct tags give the configuration key for each field.
type Config struct {
	// Username is the account name placed in the connection string.
	Username string `koanf:"username"`
	// Password is the credential paired with Username.
	Password string `koanf:"password"`
	// Port is the port number of the database server.
	Port int `koanf:"port"`
	// Host is the address of the database server.
	Host string `koanf:"host"`
	// DBName is the name of the database (schema) to select.
	DBName string `koanf:"db_name"`
}
|
|
|
|
type DB struct {
|
|
config Config
|
|
db *querier.SqlDB
|
|
mu sync.Mutex
|
|
statements map[string]*sql.Stmt
|
|
}
|
|
|
|
func (db *DB) Conn() *querier.SqlDB {
|
|
return db.db
|
|
}
|
|
|
|
// Connection-pool settings applied by New.
//
// TODO: these constants are temporary; they exist only to silence the
// linter's magic-number error.
const (
	// dbMaxConnLifetime is the value passed to SetConnMaxLifetime.
	dbMaxConnLifetime = 3 * time.Minute
	// dbMaxOpenConns is the value passed to SetMaxOpenConns.
	dbMaxOpenConns = 10
	// dbMaxIdleConns is the value passed to SetMaxIdleConns.
	dbMaxIdleConns = 10
)
|
|
|
|
func New(config Config) *DB {
|
|
// parseTime=true changes the output type of DATE and DATETIME values to time.Time
|
|
// instead of []byte / string
|
|
// The date or datetime like 0000-00-00 00:00:00 is converted into zero value of time.Time
|
|
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@(%s:%d)/%s?parseTime=true",
|
|
config.Username, config.Password, config.Host, config.Port, config.DBName))
|
|
if err != nil {
|
|
panic(fmt.Errorf("can't open mysql db: %w", err))
|
|
}
|
|
|
|
// See "Important settings" section.
|
|
db.SetConnMaxLifetime(dbMaxConnLifetime)
|
|
db.SetMaxOpenConns(dbMaxOpenConns)
|
|
db.SetMaxIdleConns(dbMaxIdleConns)
|
|
|
|
return &DB{
|
|
config: config,
|
|
db: &querier.SqlDB{DB: db},
|
|
statements: make(map[string]*sql.Stmt),
|
|
}
|
|
}
|
|
|
|
func (p *DB) PrepareStatement(key string, query string) (*sql.Stmt, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if stmt, ok := p.statements[key]; ok {
|
|
return stmt, nil
|
|
}
|
|
|
|
stmt, err := p.db.Prepare(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.statements[key] = stmt
|
|
|
|
return stmt, nil
|
|
}
|
|
|
|
func (p *DB) CloseStatements() error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
for _, stmt := range p.statements {
|
|
err := stmt.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
p.statements = make(map[string]*sql.Stmt)
|
|
return nil
|
|
}
|