from decimal import Decimal

from sqlalchemy.orm import Session

from app.models.policy_payment import PolicyPayment
from app.utils.helpers import update_payment_summary


class PaymentRepository:
    """Data-access helper for ``PolicyPayment`` rows bound to one session.

    The write methods (``create``, ``update``, ``delete``) flush the session
    but do not commit within this class.
    """

    def __init__(self, db: Session):
        """Store the SQLAlchemy session used by every query in this class."""
        self.db = db

    def list_by_policy(self, policy_id: int) -> list[PolicyPayment]:
        """Return all payments of a policy, ordered by payment date descending."""
        query = self.db.query(PolicyPayment)
        for_policy = query.filter(PolicyPayment.policy_id == policy_id)
        newest_first = for_policy.order_by(PolicyPayment.payment_date.desc())
        return newest_first.all()

    def get_by_id(self, payment_id: int, policy_id: int) -> PolicyPayment | None:
        """Return the payment matching both ids, or ``None`` if no row matches.

        The lookup is constrained by ``policy_id`` as well as ``payment_id``,
        so a payment is only returned when it belongs to the given policy.
        """
        query = self.db.query(PolicyPayment)
        criteria = (
            PolicyPayment.id == payment_id,
            PolicyPayment.policy_id == policy_id,
        )
        return query.filter(*criteria).first()

    def create(self, policy_id: int, created_by: int, data: dict) -> PolicyPayment:
        """Build a payment from ``data``, add it to the session and flush.

        Args:
            policy_id: Value stored on the new row's ``policy_id``.
            created_by: Value stored on the new row's ``created_by``.
            data: Extra keyword arguments forwarded to ``PolicyPayment``.

        Returns:
            The newly added ``PolicyPayment`` instance.
        """
        record = PolicyPayment(policy_id=policy_id, created_by=created_by, **data)
        session = self.db
        session.add(record)
        session.flush()
        return record

    def update(self, payment: PolicyPayment, data: dict) -> PolicyPayment:
        """Copy the non-``None`` entries of ``data`` onto ``payment`` and flush.

        Entries whose value is ``None`` are skipped, so this method cannot be
        used to reset an attribute to ``None``.

        Returns:
            The same ``payment`` object that was passed in.
        """
        for attr_name, new_value in data.items():
            if new_value is None:
                continue
            setattr(payment, attr_name, new_value)
        self.db.flush()
        return payment

    def delete(self, payment: PolicyPayment) -> None:
        """Mark ``payment`` for deletion in the session and flush."""
        session = self.db
        session.delete(payment)
        session.flush()

    def sum_by_policy(self, policy_id: int) -> Decimal:
        """Return the sum of ``amount`` over a policy's payments as a ``Decimal``.

        The SQL expression coalesces a NULL sum to ``0``; any falsy scalar
        result is likewise treated as ``0`` before conversion.
        """
        from sqlalchemy import func

        query = self.db.query(func.coalesce(func.sum(PolicyPayment.amount), 0))
        total = query.filter(PolicyPayment.policy_id == policy_id).scalar()
        if not total:
            total = 0
        # Convert via str so the Decimal is built from the value's text form.
        return Decimal(str(total))

    @staticmethod
    def recalculate_policy_payments(policy, total_paid: Decimal) -> None:
        """Refresh the payment summary attributes on ``policy`` in place.

        Calls ``update_payment_summary(policy.total_commission, total_paid)``,
        which is unpacked as a 3-item result: the first item is stored as
        ``pending_amount``, the third as ``payment_status``, and the middle
        item is ignored. ``total_paid`` is stored on the policy unchanged.
        """
        summary = update_payment_summary(policy.total_commission, total_paid)
        pending_amount, _, payment_status = summary
        policy.total_paid = total_paid
        policy.pending_amount = pending_amount
        policy.payment_status = payment_status
