package mysqladdress

import (
	"context"
	"database/sql"
	"errors"

	"git.gocasts.ir/ebhomengo/niki/entity"
	errmsg "git.gocasts.ir/ebhomengo/niki/pkg/err_msg"
	richerror "git.gocasts.ir/ebhomengo/niki/pkg/rich_error"
	"git.gocasts.ir/ebhomengo/niki/repository/mysql"
)

// CreateBenefactorAddress inserts address into the addresses table and returns
// it with ID set from the new row's auto-generated key.
//
// The caller-supplied ProvinceID is ignored. The province is resolved from
// address.CityID through getProvinceIDByCityID and written back onto the
// returned value, so the stored province_id always comes from the cities
// table. The insert uses the statement prepared under
// mysql.StatementKeyAddressCreateForBenefactor.
//
// On any failure the zero entity.Address is returned with the error. Errors
// from the province lookup are passed through unchanged. Prepare, exec and
// last-insert-id failures are wrapped as KindUnexpected rich errors.
func (d *DB) CreateBenefactorAddress(ctx context.Context, address entity.Address) (entity.Address, error) {
	const (
		op          = "mysqladdress.createBenefactorAddress"
		insertQuery = `INSERT INTO addresses (postal_code, address, lat, lon, name, city_id, province_id, benefactor_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`
	)

	pid, lookupErr := d.getProvinceIDByCityID(ctx, address.CityID)
	if lookupErr != nil {
		return entity.Address{}, lookupErr
	}
	address.ProvinceID = pid

	//nolint
	stmt, prepErr := d.conn.PrepareStatement(ctx, mysql.StatementKeyAddressCreateForBenefactor, insertQuery)
	if prepErr != nil {
		return entity.Address{}, richerror.New(op).WithErr(prepErr).
			WithMessage(errmsg.ErrorMsgCantPrepareStatement).WithKind(richerror.KindUnexpected)
	}

	result, execErr := stmt.ExecContext(ctx,
		address.PostalCode,
		address.Address,
		address.Lat,
		address.Lon,
		address.Name,
		address.CityID,
		pid,
		address.BenefactorID,
	)
	if execErr != nil {
		return entity.Address{}, richerror.New(op).WithErr(execErr).
			WithMessage(errmsg.ErrorMsgCantInsertRecord).WithKind(richerror.KindUnexpected)
	}

	newID, idErr := result.LastInsertId()
	if idErr != nil {
		return entity.Address{}, richerror.New(op).WithErr(idErr).
			WithMessage(errmsg.ErrorMsgCantRetrieveLastInsertID).WithKind(richerror.KindUnexpected)
	}
	address.ID = uint(newID)

	return address, nil
}

// getProvinceIDByCityID returns cities.province_id for the row whose id equals
// cityID, using the statement prepared under
// mysql.StatementKeyCityGetProvinceIDByID.
//
// Errors:
//   - KindNotFound (ErrorMsgNotFound) when no city with cityID exists.
//   - KindUnexpected when the statement cannot be prepared or the row cannot
//     be scanned into a uint.
func (d *DB) getProvinceIDByCityID(ctx context.Context, cityID uint) (uint, error) {
	const op = "mysqladdress.getProvinceIDByCityID"

	query := `SELECT province_id FROM cities WHERE id = ?`
	//nolint
	stmt, err := d.conn.PrepareStatement(ctx, mysql.StatementKeyCityGetProvinceIDByID, query)
	if err != nil {
		return 0, richerror.New(op).WithErr(err).
			WithMessage(errmsg.ErrorMsgCantPrepareStatement).WithKind(richerror.KindUnexpected)
	}

	var provinceID uint
	scanErr := stmt.QueryRowContext(ctx, cityID).Scan(&provinceID)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		// An unknown city must not silently resolve to province 0. Callers
		// would otherwise persist rows referencing a non-existent province.
		return 0, richerror.New(op).WithErr(scanErr).
			WithMessage(errmsg.ErrorMsgNotFound).WithKind(richerror.KindNotFound)
	case scanErr != nil:
		return 0, richerror.New(op).WithErr(scanErr).
			WithMessage(errmsg.ErrorMsgCantScanQueryResult).WithKind(richerror.KindUnexpected)
	}

	return provinceID, nil
}