from decimal import Decimal

from sqlalchemy.orm import Session

from app.api.deps import ROLE_ADMIN
from app.core.exceptions import ForbiddenError, NotFoundError, ValidationError
from app.models.enums import PaymentMode, PolicyCoverageType, PolicyStatus
from app.repositories.payment_repository import PaymentRepository
from app.repositories.policy_repository import PolicyRepository
from app.schemas.payment import PaymentCreate, PaymentUpdate
from app.schemas.policy import PolicyCreate, PolicyResponse, PolicyUpdate
from app.schemas.user_context import CurrentUser
from app.services.audit_service import POLICY_AUDIT_FIELDS, AuditService
from app.utils.helpers import update_payment_summary
from app.utils.body_types import body_type_label
from app.utils.policy_status import compute_policy_status, days_until_expiry
from app.utils.vehicle_types import vehicle_type_label


class PolicyService:
    """Agency-scoped operations on insurance policies and their payments.

    Wraps ``PolicyRepository`` and ``PaymentRepository``, records audit entries
    for policy create/update/delete/renew via ``AuditService``, and commits the
    session at the end of every mutating method.
    """

    def __init__(self, db: Session):
        self.db = db
        self.policies = PolicyRepository(db)
        self.payments = PaymentRepository(db)

    def _can_delete(self, policy, current_user: CurrentUser | None) -> bool:
        """Return True if ``current_user`` may delete ``policy``.

        Returns False when there is no current user. Admins may delete any
        policy; everyone else only policies whose ``created_by`` equals their
        id, so policies without a recorded creator are admin-only.
        """
        if not current_user:
            return False
        if current_user.role == ROLE_ADMIN:
            return True
        if policy.created_by is None:
            return False
        return policy.created_by == current_user.id

    def _to_response(self, policy, current_user: CurrentUser | None = None) -> PolicyResponse:
        """Flatten a policy ORM object (plus related customer/vehicle/insurer) into a ``PolicyResponse``.

        Related-object fields fall back to ``None`` (or ``False`` for the
        WhatsApp opt-out flag) when the relationship is missing. ``can_delete``
        reflects ``_can_delete`` for the given user.
        """
        return PolicyResponse(
            id=policy.id,
            agency_id=policy.agency_id,
            customer_id=policy.customer_id,
            vehicle_id=policy.vehicle_id,
            insurance_company_id=policy.insurance_company_id,
            policy_number=policy.policy_number,
            policy_type=policy.policy_type,
            coverage_type=policy.coverage_type.value if policy.coverage_type else None,
            policy_start_date=policy.policy_start_date,
            policy_end_date=policy.policy_end_date,
            premium_amount=policy.premium_amount,
            idv=policy.idv,
            ncb=policy.ncb,
            status=policy.status.value,
            total_commission=policy.total_commission,
            total_paid=policy.total_paid,
            pending_amount=policy.pending_amount,
            payment_status=policy.payment_status.value,
            notes=policy.notes,
            customer_name=policy.customer.name if policy.customer else None,
            customer_mobile=policy.customer.mobile if policy.customer else None,
            customer_whatsapp_opt_out=policy.customer.whatsapp_opt_out if policy.customer else False,
            vehicle_registration=policy.vehicle.registration_number if policy.vehicle else None,
            vehicle_type=policy.vehicle.vehicle_type if policy.vehicle else None,
            vehicle_type_label=vehicle_type_label(policy.vehicle.vehicle_type if policy.vehicle else None),
            body_type=policy.vehicle.body_type if policy.vehicle else None,
            body_type_label=body_type_label(policy.vehicle.body_type if policy.vehicle else None),
            insurance_company_name=policy.insurance_company.name if policy.insurance_company else None,
            days_until_expiry=days_until_expiry(policy.policy_end_date),
            created_by=policy.created_by,
            can_delete=self._can_delete(policy, current_user),
        )

    def list_policies(self, agency_id: int, current_user: CurrentUser | None = None, **kwargs) -> tuple[list[PolicyResponse], int]:
        """Return one page of the agency's policies plus the total count.

        ``kwargs`` are forwarded unchanged to ``PolicyRepository.list_paginated``.
        """
        # NOTE(review): unlike get_policy, the stored status is returned as-is
        # (compute_policy_status is not applied here), so list and detail views
        # may disagree for policies past their end date — confirm whether the
        # repository already accounts for this.
        items, total = self.policies.list_paginated(agency_id, **kwargs)
        return [self._to_response(p, current_user) for p in items], total

    def get_policy(self, agency_id: int, policy_id: int, current_user: CurrentUser | None = None) -> PolicyResponse:
        """Fetch one policy in the agency with its status recomputed from the end date.

        Raises:
            NotFoundError: if the policy does not exist in this agency.
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        # NOTE(review): this assigns onto the ORM instance, so the recomputed
        # status becomes pending session state on a read path; a later commit in
        # the same session may persist it. Confirm that is intended.
        policy.status = compute_policy_status(policy.policy_end_date, policy.status)
        return self._to_response(policy, current_user)

    def create_policy(
        self, agency_id: int, payload: PolicyCreate, user_id: int | None = None, current_user: CurrentUser | None = None
    ) -> PolicyResponse:
        """Validate, enrich, persist and audit a new policy, then return it.

        Requires a policy number and end date, and the policy number must be
        unique within the agency. Initial status is derived from the end date
        starting from DRAFT. ``total_commission`` is resolved by
        ``CommissionService.auto_commission`` (the payload value is passed as
        its last argument). Pending amount / payment status are seeded via
        ``update_payment_summary(total_commission, 0)``.

        Raises:
            ValidationError: on missing required fields or a duplicate number.
        """
        self._validate_minimum_fields(payload.policy_number, payload.policy_end_date)
        self.ensure_unique_policy_number(agency_id, payload.policy_number)
        data = payload.model_dump()
        if data.get("coverage_type"):
            data["coverage_type"] = PolicyCoverageType(data["coverage_type"])
        data["status"] = compute_policy_status(data.get("policy_end_date"), PolicyStatus.DRAFT)

        # Local imports, kept as in the original (presumably to avoid import cycles).
        from app.services.commission_service import CommissionService
        from app.repositories.vehicle_repository import VehicleRepository

        # NOTE(review): the vehicle is looked up by id only, without agency_id.
        # If VehicleRepository.get_by_id is not agency-scoped, a payload could
        # reference another agency's vehicle — TODO confirm.
        vehicle = VehicleRepository(self.db).get_by_id(payload.vehicle_id) if payload.vehicle_id else None

        # auto_commission receives the coverage type as its raw string value.
        coverage_val = data.get("coverage_type")
        coverage_str = coverage_val.value if isinstance(coverage_val, PolicyCoverageType) else coverage_val
        data["total_commission"] = CommissionService(self.db).auto_commission(
            agency_id,
            data.get("premium_amount"),
            data.get("insurance_company_id"),
            vehicle.vehicle_type if vehicle else None,
            coverage_str,
            data.get("total_commission"),
        )
        pending, _, payment_status = update_payment_summary(data["total_commission"], Decimal("0.00"))
        data["pending_amount"] = pending
        data["payment_status"] = payment_status
        data["created_by"] = user_id
        policy = self.policies.create(agency_id, data)
        AuditService(self.db).log_create(
            agency_id=agency_id,
            user_id=user_id,
            entity_type="policy",
            entity_id=policy.id,
            fields=AuditService.snapshot_policy(policy),
            tracked_fields=POLICY_AUDIT_FIELDS,
        )
        self.db.commit()
        self.db.refresh(policy)
        return self.get_policy(agency_id, policy.id, current_user)

    def update_policy(
        self, agency_id: int, policy_id: int, payload: PolicyUpdate, user_id: int | None = None, current_user: CurrentUser | None = None
    ) -> PolicyResponse:
        """Apply a partial update to a policy, audit the change and return the result.

        Only payload fields whose value is not ``None`` are applied, so this
        method cannot clear a field to null. Changing ``total_commission``
        recomputes pending amount / payment status against the current
        ``total_paid``. A supplied policy number must be non-blank and unique
        within the agency (excluding this policy).

        Raises:
            NotFoundError: if the policy does not exist in this agency.
            ValidationError: on a blank or duplicate policy number.
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        before = AuditService.snapshot_policy(policy)
        data = {k: v for k, v in payload.model_dump().items() if v is not None}
        if "coverage_type" in data:
            data["coverage_type"] = PolicyCoverageType(data["coverage_type"])
        if "status" in data:
            data["status"] = PolicyStatus(data["status"])
        if "total_commission" in data:
            pending, _, payment_status = update_payment_summary(data["total_commission"], policy.total_paid)
            data["pending_amount"] = pending
            data["payment_status"] = payment_status
        if "policy_number" in data:
            if not str(data.get("policy_number") or "").strip():
                raise ValidationError(
                    "Policy number is required",
                    details=[{"field": "policy_number", "error": "required"}],
                )
            self.ensure_unique_policy_number(agency_id, data["policy_number"], exclude_policy_id=policy_id)
        policy = self.policies.update(policy, data)
        AuditService(self.db).log_update(
            agency_id=agency_id,
            user_id=user_id,
            entity_type="policy",
            entity_id=policy.id,
            before=before,
            after=AuditService.snapshot_policy(policy),
            tracked_fields=POLICY_AUDIT_FIELDS,
        )
        self.db.commit()
        return self.get_policy(agency_id, policy_id, current_user)

    def delete_policy(self, agency_id: int, policy_id: int, current_user: CurrentUser) -> None:
        """Soft-delete a policy after a permission check, recording an audit entry.

        Raises:
            NotFoundError: if the policy does not exist in this agency.
            ForbiddenError: if ``current_user`` may not delete it (see ``_can_delete``).
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        if not self._can_delete(policy, current_user):
            raise ForbiddenError("You can only delete policies you created")
        AuditService(self.db).log_delete(
            agency_id=agency_id,
            user_id=current_user.id,
            entity_type="policy",
            entity_id=policy.id,
        )
        self.policies.soft_delete(policy)
        self.db.commit()

    def _ensure_renewal_target(self, agency_id: int, policy_id: int, new_policy_id: int) -> None:
        """Reject a renewal link that points at the policy itself or at a policy outside the agency.

        Raises:
            ValidationError: on self-reference, or when ``new_policy_id`` is not
                found through the agency-scoped repository lookup.
        """
        if new_policy_id == policy_id:
            raise ValidationError(
                "A policy cannot be renewed into itself",
                details=[{"field": "new_policy_id", "error": "self_reference"}],
            )
        if not self.policies.get_by_id(new_policy_id, agency_id):
            raise ValidationError(
                "Renewal policy not found",
                details=[{"field": "new_policy_id", "error": "not_found"}],
            )

    def mark_renewed(
        self, agency_id: int, policy_id: int, new_policy_id: int | None = None, user_id: int | None = None, current_user: CurrentUser | None = None
    ) -> PolicyResponse:
        """Set a policy's status to RENEWED, optionally linking it to ``new_policy_id``.

        The linked policy must exist in the same agency and must not be the
        policy itself. Without this check, cross-agency or dangling ids could
        be stored.

        Raises:
            NotFoundError: if ``policy_id`` does not exist in this agency.
            ValidationError: if ``new_policy_id`` is invalid (see ``_ensure_renewal_target``).
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        # Validate before touching the policy so a rejected link leaves no partial state.
        if new_policy_id:
            self._ensure_renewal_target(agency_id, policy_id, new_policy_id)
        before = AuditService.snapshot_policy(policy)
        policy.status = PolicyStatus.RENEWED
        if new_policy_id:
            # NOTE(review): the *old* policy stores the *new* policy's id in a
            # field named renewed_from_policy_id; the name suggests the reverse
            # direction. Kept for compatibility — confirm intended semantics.
            policy.renewed_from_policy_id = new_policy_id
        AuditService(self.db).log_update(
            agency_id=agency_id,
            user_id=user_id,
            entity_type="policy",
            entity_id=policy.id,
            before=before,
            after=AuditService.snapshot_policy(policy),
            tracked_fields=POLICY_AUDIT_FIELDS,
        )
        self.db.commit()
        return self.get_policy(agency_id, policy_id, current_user)

    def add_payment(self, agency_id: int, policy_id: int, user_id: int, payload: PaymentCreate):
        """Record a payment against a policy and recompute the policy's payment totals.

        Raises:
            NotFoundError: if the policy does not exist in this agency.
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        payment = self.payments.create(
            policy_id,
            user_id,
            {**payload.model_dump(), "payment_mode": PaymentMode(payload.payment_mode)},
        )
        # Totals are re-summed from the payment rows rather than incremented.
        total_paid = self.payments.sum_by_policy(policy_id)
        self.payments.recalculate_policy_payments(policy, total_paid)
        self.db.commit()
        self.db.refresh(payment)
        return payment

    def list_payments(self, agency_id: int, policy_id: int):
        """Return the payments of a policy in this agency.

        Raises:
            NotFoundError: if the policy does not exist in this agency.
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        return self.payments.list_by_policy(policy_id)

    def update_payment(self, agency_id: int, policy_id: int, payment_id: int, payload: PaymentUpdate):
        """Partially update a payment (only fields explicitly set) and recompute policy totals.

        Raises:
            NotFoundError: if the policy or the payment (within that policy) is missing.
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        payment = self.payments.get_by_id(payment_id, policy_id)
        if not payment:
            raise NotFoundError("Payment not found")
        data = payload.model_dump(exclude_unset=True)
        if "payment_mode" in data and data["payment_mode"]:
            data["payment_mode"] = PaymentMode(data["payment_mode"])
        payment = self.payments.update(payment, data)
        total_paid = self.payments.sum_by_policy(policy_id)
        self.payments.recalculate_policy_payments(policy, total_paid)
        self.db.commit()
        self.db.refresh(payment)
        return payment

    def delete_payment(self, agency_id: int, policy_id: int, payment_id: int) -> None:
        """Delete a payment from a policy and recompute the policy's payment totals.

        Raises:
            NotFoundError: if the policy or the payment (within that policy) is missing.
        """
        policy = self.policies.get_by_id(policy_id, agency_id)
        if not policy:
            raise NotFoundError("Policy not found")
        payment = self.payments.get_by_id(payment_id, policy_id)
        if not payment:
            raise NotFoundError("Payment not found")
        self.payments.delete(payment)
        total_paid = self.payments.sum_by_policy(policy_id)
        self.payments.recalculate_policy_payments(policy, total_paid)
        self.db.commit()

    @staticmethod
    def _validate_minimum_fields(policy_number, policy_end_date):
        """Require an end date and a non-blank policy number (end date is checked first).

        Raises:
            ValidationError: naming the first missing field.
        """
        if not policy_end_date:
            raise ValidationError(
                "Policy end date is required",
                details=[{"field": "policy_end_date", "error": "required"}],
            )
        if not policy_number or not str(policy_number).strip():
            raise ValidationError(
                "Policy number is required",
                details=[{"field": "policy_number", "error": "required"}],
            )

    def ensure_unique_policy_number(
        self,
        agency_id: int,
        policy_number: str | None,
        *,
        exclude_policy_id: int | None = None,
    ) -> None:
        """Raise if another policy in the agency already uses ``policy_number``.

        Blank or missing numbers are skipped; required-ness is enforced
        elsewhere. ``exclude_policy_id`` lets an update ignore its own row.

        Raises:
            ValidationError: with ``existing_policy_id`` in the details on a duplicate.
        """
        if not policy_number or not str(policy_number).strip():
            return
        # NOTE(review): emptiness is judged on the stripped value, but the
        # lookup uses the unstripped string, so " ABC" and "ABC" may not collide
        # unless the repository or schema normalizes — confirm.
        existing = self.policies.find_by_policy_number(
            agency_id,
            str(policy_number),
            exclude_policy_id=exclude_policy_id,
        )
        if existing:
            raise ValidationError(
                "A policy with this policy number already exists",
                details=[{"field": "policy_number", "error": "duplicate", "existing_policy_id": existing.id}],
            )
