from decimal import Decimal

from sqlalchemy.orm import Session

from app.core.exceptions import NotFoundError
from app.models.commission_rule import CommissionRule
from app.repositories.insurance_company_repository import InsuranceCompanyRepository


class CommissionService:
    """Manage per-agency commission rules and compute commissions from them.

    Every public method is scoped to an ``agency_id``: rules of another agency
    are never listed, fetched, modified, deleted, or applied.
    """

    # Columns that ``update_rule`` will set even when the payload value is
    # ``None`` (i.e. the caller may explicitly clear them). For every other
    # column, ``None`` in the payload means "leave unchanged".
    _NULLABLE_FIELDS = frozenset({"insurance_company_id", "vehicle_type", "coverage_type", "notes"})
    # Identity/tenant columns that must never be taken from a payload: allowing
    # them would let a caller re-assign a rule to another agency, and on create
    # an ``agency_id`` key would collide with the explicit keyword argument.
    _PROTECTED_FIELDS = frozenset({"id", "agency_id"})
    # Commissions are reported with two decimal places.
    _CENT = Decimal("0.01")

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used for all queries and commits; also handed to the
                insurance-company repository used to resolve company names.
        """
        self.db = db
        self.companies = InsuranceCompanyRepository(db)

    def list_rules(self, agency_id: int) -> list[dict]:
        """Return every rule of ``agency_id`` (active or not) as dicts.

        Rules are ordered by ascending ``priority`` and then ascending ``id``.
        Company names are resolved once per distinct company, not once per rule.
        """
        rows = (
            self.db.query(CommissionRule)
            .filter(CommissionRule.agency_id == agency_id)
            .order_by(CommissionRule.priority.asc(), CommissionRule.id.asc())
            .all()
        )
        company_names: dict[int, str | None] = {}
        return [self._to_dict(row, company_names) for row in rows]

    def create_rule(self, agency_id: int, payload: dict) -> dict:
        """Create a rule for ``agency_id`` from ``payload``, commit, and return it.

        ``id`` / ``agency_id`` keys in ``payload`` are ignored so the rule is
        always created under the agency passed in.
        """
        fields = {key: value for key, value in payload.items() if key not in self._PROTECTED_FIELDS}
        rule = CommissionRule(agency_id=agency_id, **fields)
        self.db.add(rule)
        self.db.commit()
        self.db.refresh(rule)
        return self._to_dict(rule)

    def update_rule(self, agency_id: int, rule_id: int, payload: dict) -> dict:
        """Apply ``payload`` to an existing rule of ``agency_id``, commit, and return it.

        ``None`` values are skipped except for the nullable matching/notes
        columns, which may be cleared explicitly. ``id`` / ``agency_id`` keys
        are ignored.

        Raises:
            NotFoundError: if the rule does not exist for this agency.
        """
        rule = self._get_rule(agency_id, rule_id)
        for key, value in payload.items():
            if key in self._PROTECTED_FIELDS:
                continue
            if value is not None or key in self._NULLABLE_FIELDS:
                setattr(rule, key, value)
        self.db.commit()
        self.db.refresh(rule)
        return self._to_dict(rule)

    def delete_rule(self, agency_id: int, rule_id: int) -> None:
        """Delete a rule of ``agency_id`` and commit.

        Raises:
            NotFoundError: if the rule does not exist for this agency.
        """
        rule = self._get_rule(agency_id, rule_id)
        self.db.delete(rule)
        self.db.commit()

    def calculate(
        self,
        agency_id: int,
        *,
        premium_amount: Decimal,
        insurance_company_id: int | None = None,
        vehicle_type: str | None = None,
        coverage_type: str | None = None,
    ) -> dict:
        """Compute the commission for a premium using the best matching active rule.

        A rule matches when each of its non-empty criteria (company, vehicle
        type, coverage type) equals the given value; empty criteria act as
        wildcards. Among matches, the most specific wins (company outweighs
        vehicle type, which outweighs coverage type); ties are broken by the
        lowest ``priority`` value and then the lowest ``id``.

        Returns:
            ``{"commission", "rule_id", "rule_type"}``. ``commission`` is
            quantized to cents using the active decimal context's rounding.
            ``rule_id`` and ``rule_type`` are ``None`` (with a 0.00 commission)
            when the premium is non-positive or no rule matches.
        """
        if premium_amount <= 0:
            return self._no_commission()

        rules = (
            self.db.query(CommissionRule)
            .filter(CommissionRule.agency_id == agency_id, CommissionRule.active.is_(True))
            .all()
        )
        matched = [r for r in rules if self._rule_matches(r, insurance_company_id, vehicle_type, coverage_type)]
        if not matched:
            return self._no_commission()

        # ``id`` makes every key unique, so ``min`` picks the same rule a full sort would.
        rule = min(matched, key=lambda r: (-self._specificity(r), r.priority, r.id))
        commission = self._apply_rule(rule, premium_amount)
        return {"commission": commission.quantize(self._CENT), "rule_id": rule.id, "rule_type": rule.rule_type}

    def auto_commission(
        self,
        agency_id: int,
        premium_amount: Decimal | None,
        insurance_company_id: int | None,
        vehicle_type: str | None,
        coverage_type: str | None,
        provided_commission: Decimal | None,
    ) -> Decimal:
        """Return an explicit commission if positive, otherwise a rule-based one.

        - A positive ``provided_commission`` is returned unchanged.
        - Without a positive ``premium_amount``, ``provided_commission`` is
          returned as-is; ``None`` or zero become ``Decimal("0.00")``.
          NOTE(review): a negative provided value is passed through here but
          ignored when a premium is present; confirm that asymmetry is intended.
        - Otherwise the commission comes from ``calculate``.
        """
        if provided_commission and provided_commission > 0:
            return provided_commission
        if not premium_amount or premium_amount <= 0:
            return provided_commission or Decimal("0.00")
        result = self.calculate(
            agency_id,
            premium_amount=premium_amount,
            insurance_company_id=insurance_company_id,
            vehicle_type=vehicle_type,
            coverage_type=coverage_type,
        )
        return result["commission"]

    def _get_rule(self, agency_id: int, rule_id: int) -> CommissionRule:
        """Fetch a rule by id, restricted to ``agency_id``.

        Raises:
            NotFoundError: if no such rule exists for this agency.
        """
        rule = (
            self.db.query(CommissionRule)
            .filter(CommissionRule.id == rule_id, CommissionRule.agency_id == agency_id)
            .first()
        )
        if not rule:
            raise NotFoundError("Commission rule not found")
        return rule

    @staticmethod
    def _no_commission() -> dict:
        """Result returned when no commission applies."""
        return {"commission": Decimal("0.00"), "rule_id": None, "rule_type": None}

    @staticmethod
    def _rule_matches(
        rule: CommissionRule,
        insurance_company_id: int | None,
        vehicle_type: str | None,
        coverage_type: str | None,
    ) -> bool:
        """Return True if every non-empty criterion of ``rule`` equals the inputs.

        Vehicle and coverage types are compared case-insensitively; a missing
        input never matches a rule that sets that criterion.
        """
        if rule.insurance_company_id and rule.insurance_company_id != insurance_company_id:
            return False
        if rule.vehicle_type and (vehicle_type or "").lower() != rule.vehicle_type.lower():
            return False
        if rule.coverage_type and (coverage_type or "").lower() != rule.coverage_type.lower():
            return False
        return True

    @staticmethod
    def _specificity(rule: CommissionRule) -> int:
        """Score how specific a rule is.

        Bit weights 4/2/1 make a company criterion outrank any combination of
        the other two, and a vehicle criterion outrank a coverage criterion.
        """
        score = 0
        if rule.insurance_company_id:
            score += 4
        if rule.vehicle_type:
            score += 2
        if rule.coverage_type:
            score += 1
        return score

    @staticmethod
    def _as_decimal(value: object) -> Decimal:
        """Convert ``value`` to ``Decimal``.

        Non-Decimal values go through ``str`` so a float keeps its printed
        value instead of its binary expansion (``Decimal(2.675)`` is
        2.67499999...).
        """
        if isinstance(value, Decimal):
            return value
        return Decimal(str(value))

    @staticmethod
    def _apply_rule(rule: CommissionRule, premium_amount: Decimal) -> Decimal:
        """Return the unrounded commission ``rule`` yields for ``premium_amount``.

        ``"percent"`` rules treat ``value`` as a percentage of the premium. Any
        other ``rule_type`` treats ``value`` as a fixed amount.
        """
        value = CommissionService._as_decimal(rule.value)
        if rule.rule_type == "percent":
            return CommissionService._as_decimal(premium_amount) * value / Decimal("100")
        return value

    def _company_name(self, company_id: int, cache: dict[int, str | None] | None) -> str | None:
        """Resolve a company's name, memoizing in ``cache`` when one is given."""
        if cache is not None and company_id in cache:
            return cache[company_id]
        company = self.companies.get_by_id(company_id)
        name = company.name if company else None
        if cache is not None:
            cache[company_id] = name
        return name

    def _to_dict(self, rule: CommissionRule, company_names: dict[int, str | None] | None = None) -> dict:
        """Serialize ``rule`` to a dict, including the insurance company's name.

        Args:
            rule: Rule to serialize.
            company_names: Optional per-call cache of company id -> name, used
                to avoid repeating the same company lookup across many rules.
        """
        company_name = None
        if rule.insurance_company_id:
            company_name = self._company_name(rule.insurance_company_id, company_names)
        return {
            "id": rule.id,
            "insurance_company_id": rule.insurance_company_id,
            "insurance_company_name": company_name,
            "vehicle_type": rule.vehicle_type,
            "coverage_type": rule.coverage_type,
            "rule_type": rule.rule_type,
            "value": rule.value,
            "priority": rule.priority,
            "active": rule.active,
            "notes": rule.notes,
        }
