172 lines
3.8 KiB
Go
172 lines
3.8 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"git.gocasts.ir/msaskarzadeh/url-shortner.git/internal/model"
|
|
)
|
|
|
|
type URLRepository struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func NewURLRepository(dns string) (*URLRepository, error) {
|
|
db, err := sql.Open("postgres", dns)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open DB: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(25)
|
|
db.SetMaxIdleConns(10)
|
|
db.SetConnMaxLifetime(5 * time.Minute)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to ping DB: %w", err)
|
|
}
|
|
|
|
return &URLRepository{db: db}, nil
|
|
}
|
|
|
|
func (r *URLRepository) Create(ctx context.Context, url *model.URL) error {
|
|
query := `
|
|
INSERT INTO urls (short_code, original_url, user_id, expires_at)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING id, created_at, updated_at`
|
|
|
|
return r.db.QueryRowContext(ctx, query,
|
|
url.ShortCode,
|
|
url.OriginalURL,
|
|
url.UserID,
|
|
url.ExpiresAt,
|
|
).Scan(&url.ID, &url.CreatedAt, &url.UpdatedAt)
|
|
}
|
|
|
|
func (r *URLRepository) GetShortCode(ctx context.Context, code string) (*model.URL, error) {
|
|
query := `
|
|
SELECT id, short_code, original_url, user_id, clicks,
|
|
is_active, created_at, expires_at, updated_at
|
|
FROM urls
|
|
WHERE short_code = $1`
|
|
|
|
url := &model.URL{}
|
|
err := r.db.QueryRowContext(ctx, query, code).Scan(
|
|
&url.ID,
|
|
&url.ShortCode,
|
|
&url.OriginalURL,
|
|
&url.UserID,
|
|
&url.Clicks,
|
|
&url.IsActive,
|
|
&url.CreatedAt,
|
|
&url.ExpiresAt,
|
|
&url.UpdatedAt,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get URL: %w", err)
|
|
}
|
|
|
|
return url, nil
|
|
}
|
|
func (r *URLRepository) IncrementClicks(ctx context.Context, id int64) error {
|
|
query := `
|
|
UPDATE urls
|
|
SET clicks = clicks + 1,
|
|
updated_at = NOW()
|
|
WHERE id = $1`
|
|
_, err := r.db.ExecContext(ctx, query, id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to increment clicks: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *URLRepository) LogClicks(ctx context.Context, logEntry *model.ClickLog) error {
|
|
query := `
|
|
INSERT INTO click_logs (url_id, ip_address, user_agent, referer)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING id, clicked_at`
|
|
|
|
return r.db.QueryRowContext(
|
|
ctx, query,
|
|
logEntry.URLID,
|
|
logEntry.IPAddress,
|
|
logEntry.UserAgent,
|
|
logEntry.Referer,
|
|
).Scan(&logEntry.ID, &logEntry.ClickedAt)
|
|
}
|
|
|
|
func (r *URLRepository) GetClicks(ctx context.Context, urlID int64, limit int) ([]model.ClickLog, error) {
|
|
query := `
|
|
SELECT id, url_id, ip_address, user_agent, referer, clicked_at
|
|
FROM click_logs
|
|
WHERE url_id = $1
|
|
ORDER BY clicked_at DESC
|
|
LIMIT $2`
|
|
|
|
rows, err := r.db.QueryContext(ctx, query, urlID, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch recent clicks: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var logs []model.ClickLog
|
|
for rows.Next() {
|
|
var l model.ClickLog
|
|
if err := rows.Scan(
|
|
&l.ID, &l.URLID, &l.IPAddress,
|
|
&l.UserAgent, &l.Referer, &l.ClickedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logs = append(logs, l)
|
|
}
|
|
|
|
return logs, nil
|
|
}
|
|
|
|
func (r *URLRepository) DeactivateExpiredLinks(ctx context.Context) (int64, error) {
|
|
query := `
|
|
UPDATE urls
|
|
SET is_active = false,
|
|
updated_at = NOW()
|
|
WHERE expires_at IS NOT NULL
|
|
AND expires_at < NOW()
|
|
AND is_active = true`
|
|
|
|
result, err := r.db.ExecContext(ctx, query)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to deactivate expired links: %w", err)
|
|
}
|
|
|
|
affected, err := result.RowsAffected()
|
|
return affected, err
|
|
}
|
|
|
|
func (r *URLRepository) CheckShortExists(ctx context.Context, code string) (bool, error) {
|
|
query := `SELECT 1 FROM urls WHERE short_code = $1`
|
|
|
|
var exists int
|
|
err := r.db.QueryRowContext(ctx, query, code).Scan(&exists)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check code: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|