from collections import defaultdict
from datetime import date
from decimal import Decimal

from sqlalchemy.orm import Session

from app.core.exceptions import ValidationError
from app.repositories.report_repository import ReportRepository
from app.schemas.reports import ReportColumn, ReportFilters, ReportResponse, ReportTypeInfo

# Registry of every report ReportService can generate, keyed by report id.
# Each tuple below is (id, title, description). The id is used both as the
# dict key and as ReportTypeInfo.id, so the two cannot drift apart.
# Dict insertion order follows the tuple order.
REPORT_TYPES: dict[str, ReportTypeInfo] = {
    report_id: ReportTypeInfo(id=report_id, title=title, description=description)
    for report_id, title, description in (
        ("total-policies", "Total Policies", "Policy count grouped by status"),
        (
            "company-wise",
            "Company-wise Policies",
            "Policy count and total premium per insurance company",
        ),
        ("expired", "Expired Policies", "Policies past expiry date"),
        (
            "renewed",
            "Renewed Policies",
            "Policies marked as renewed with old and new policy numbers",
        ),
        (
            "monthly-renewals",
            "Monthly Renewals",
            "Renewal count and commission earned per month",
        ),
        (
            "expiring-soon",
            "Policies Expiring Soon",
            "Policies expiring within the selected window",
        ),
        (
            "payment-summary",
            "Payment Summary",
            "Commission, paid, and pending amounts per policy",
        ),
    )
}


class ReportService:
    """Build tabular report responses for an agency.

    Every key of ``REPORT_TYPES`` has a matching private generator method that
    fetches data through ``ReportRepository`` and shapes it into a
    ``ReportResponse`` (columns, rows, summary and the filters that were applied).

    Money values are coerced to ``Decimal`` and summary totals are accumulated
    in ``Decimal``. They are converted to ``float`` only once per output value,
    so totals do not pick up binary floating-point drift
    (e.g. ``0.1 + 0.2 == 0.30000000000000004`` with floats).
    """

    def __init__(self, db: Session):
        self.db = db
        self.reports = ReportRepository(db)

    def list_types(self) -> list[ReportTypeInfo]:
        """Return metadata for every supported report type, in registry order."""
        return list(REPORT_TYPES.values())

    def generate(self, agency_id: int, report_type: str, filters: ReportFilters) -> ReportResponse:
        """Generate the report identified by ``report_type``.

        Args:
            agency_id: Agency whose data is reported on; forwarded to the repository.
            report_type: One of the keys of ``REPORT_TYPES``.
            filters: User-selected filters; forwarded to the repository.

        Returns:
            The populated ``ReportResponse``.

        Raises:
            ValidationError: If ``report_type`` is not a key of ``REPORT_TYPES``.
        """
        if report_type not in REPORT_TYPES:
            raise ValidationError(f"Unknown report type: {report_type}")

        # Must be kept in sync with REPORT_TYPES. A key registered there but
        # missing here surfaces as a KeyError.
        generators = {
            "total-policies": self._total_policies,
            "company-wise": self._company_wise,
            "expired": self._expired,
            "renewed": self._renewed,
            "monthly-renewals": self._monthly_renewals,
            "expiring-soon": self._expiring_soon,
            "payment-summary": self._payment_summary,
        }
        return generators[report_type](agency_id, filters)

    @staticmethod
    def _to_decimal(value) -> Decimal:
        """Coerce a nullable numeric value to ``Decimal`` (``None`` -> 0).

        Non-Decimal values go through ``str`` so that a float such as ``0.1``
        becomes ``Decimal("0.1")`` rather than its exact binary expansion. It
        also avoids the ``TypeError`` raised by ``Decimal + float``.
        """
        if value is None:
            return Decimal("0")
        if isinstance(value, Decimal):
            return value
        return Decimal(str(value))

    def _filters_applied(self, filters: ReportFilters) -> dict:
        """Echo back the filters that were set (truthy), with dates as ISO strings."""
        data = {}
        if filters.date_from:
            data["date_from"] = filters.date_from.isoformat()
        if filters.date_to:
            data["date_to"] = filters.date_to.isoformat()
        if filters.company_id:
            data["company_id"] = filters.company_id
        if filters.payment_status:
            data["payment_status"] = filters.payment_status
        if filters.agent_id:
            data["agent_id"] = filters.agent_id
        return data

    def _total_policies(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """Policy counts per status, plus the overall count."""
        rows_data = self.reports.total_policies_by_status(agency_id, filters)
        rows = [{"status": status.value, "count": count} for status, count in rows_data]
        total = sum(row["count"] for row in rows)
        return ReportResponse(
            report_type="total-policies",
            title=REPORT_TYPES["total-policies"].title,
            columns=[
                ReportColumn(key="status", label="Status"),
                ReportColumn(key="count", label="Count"),
            ],
            rows=rows,
            summary={"total_policies": total},
            filters_applied=self._filters_applied(filters),
        )

    def _company_wise(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """Policy count and premium per insurance company.

        The premium total is summed in ``Decimal`` to avoid float drift.
        """
        rows_data = self.reports.company_wise_policies(agency_id, filters)
        rows = []
        total_policies = 0
        total_premium = Decimal("0")
        for name, count, premium in rows_data:
            premium_amount = self._to_decimal(premium)
            total_policies += count
            total_premium += premium_amount
            rows.append(
                {
                    "company": name or "Unknown",
                    "policy_count": count,
                    "total_premium": float(premium_amount),
                }
            )
        return ReportResponse(
            report_type="company-wise",
            title=REPORT_TYPES["company-wise"].title,
            columns=[
                ReportColumn(key="company", label="Company"),
                ReportColumn(key="policy_count", label="Policy Count"),
                ReportColumn(key="total_premium", label="Total Premium (₹)"),
            ],
            rows=rows,
            summary={
                "total_policies": total_policies,
                "total_premium": float(total_premium),
            },
            filters_applied=self._filters_applied(filters),
        )

    def _expired(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """List the policies the repository returns as expired."""
        policies = self.reports.expired_policies(agency_id, filters)
        rows = [self._policy_expiry_row(p) for p in policies]
        return ReportResponse(
            report_type="expired",
            title=REPORT_TYPES["expired"].title,
            columns=[
                ReportColumn(key="customer", label="Customer"),
                ReportColumn(key="vehicle", label="Vehicle"),
                ReportColumn(key="company", label="Company"),
                ReportColumn(key="expiry_date", label="Expiry Date"),
                ReportColumn(key="policy_number", label="Policy No"),
            ],
            rows=rows,
            summary={"total": len(rows)},
            filters_applied=self._filters_applied(filters),
        )

    def _renewed(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """List (old policy, new policy) renewal pairs.

        The new policy may be missing; its number is then shown as "—".
        """
        pairs = self.reports.renewed_policies(agency_id, filters)
        rows = []
        for old_policy, new_policy in pairs:
            # NOTE(review): the old policy's updated_at is used as the renewal
            # date. If that timestamp changes on any later edit of the old
            # policy, this date will drift. Confirm against the model.
            renewal_date = old_policy.updated_at.date() if old_policy.updated_at else None
            rows.append(
                {
                    "customer": old_policy.customer.name if old_policy.customer else "",
                    "old_policy": old_policy.policy_number or f"#{old_policy.id}",
                    "new_policy": (new_policy.policy_number if new_policy else None) or "—",
                    "renewal_date": renewal_date.isoformat() if renewal_date else "",
                }
            )
        return ReportResponse(
            report_type="renewed",
            title=REPORT_TYPES["renewed"].title,
            columns=[
                ReportColumn(key="customer", label="Customer"),
                ReportColumn(key="old_policy", label="Old Policy"),
                ReportColumn(key="new_policy", label="New Policy"),
                ReportColumn(key="renewal_date", label="Renewal Date"),
            ],
            rows=rows,
            summary={"total": len(rows)},
            filters_applied=self._filters_applied(filters),
        )

    def _monthly_renewals(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """Renewal count and new-policy commission per YYYY-MM, newest month first.

        Pairs whose old policy has no ``updated_at`` are skipped.
        """
        pairs = self.reports.renewed_policies(agency_id, filters)
        by_month: dict[str, dict] = defaultdict(lambda: {"count": 0, "commission": Decimal("0")})
        for old_policy, new_policy in pairs:
            if not old_policy.updated_at:
                continue
            bucket = by_month[old_policy.updated_at.strftime("%Y-%m")]
            bucket["count"] += 1
            if new_policy and new_policy.total_commission:
                # Coerce so a float column cannot raise TypeError on Decimal += float.
                bucket["commission"] += self._to_decimal(new_policy.total_commission)

        ordered = sorted(by_month.items(), reverse=True)
        rows = [
            {
                "month": month,
                "renewal_count": data["count"],
                "commission_earned": float(data["commission"]),
            }
            for month, data in ordered
        ]
        # Sum the exact Decimal buckets, not the per-row floats.
        total_commission = sum((data["commission"] for _, data in ordered), Decimal("0"))
        return ReportResponse(
            report_type="monthly-renewals",
            title=REPORT_TYPES["monthly-renewals"].title,
            columns=[
                ReportColumn(key="month", label="Month"),
                ReportColumn(key="renewal_count", label="Renewal Count"),
                ReportColumn(key="commission_earned", label="Commission Earned (₹)"),
            ],
            rows=rows,
            summary={
                "total_renewals": sum(r["renewal_count"] for r in rows),
                "total_commission": float(total_commission),
            },
            filters_applied=self._filters_applied(filters),
        )

    def _expiring_soon(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """Policies the repository returns as expiring soon.

        Each row carries days remaining, mobile number and pending payment.
        """
        policies = self.reports.expiring_soon_policies(agency_id, filters)
        # NOTE(review): date.today() is the server's local date; confirm it
        # matches the timezone policy_end_date is stored in.
        today = date.today()
        rows = []
        for p in policies:
            row = self._policy_expiry_row(p)
            row["days_left"] = (p.policy_end_date - today).days if p.policy_end_date else None
            row["mobile"] = p.customer.mobile if p.customer else ""
            row["pending_payment"] = float(p.pending_amount or 0)
            rows.append(row)
        return ReportResponse(
            report_type="expiring-soon",
            title=REPORT_TYPES["expiring-soon"].title,
            columns=[
                ReportColumn(key="customer", label="Customer"),
                ReportColumn(key="mobile", label="Mobile"),
                ReportColumn(key="vehicle", label="Vehicle"),
                ReportColumn(key="company", label="Company"),
                ReportColumn(key="expiry_date", label="Expiry Date"),
                ReportColumn(key="days_left", label="Days Left"),
                ReportColumn(key="pending_payment", label="Pending (₹)"),
            ],
            rows=rows,
            summary={"total": len(rows), "expiry_days": filters.expiry_days},
            filters_applied={**self._filters_applied(filters), "expiry_days": filters.expiry_days},
        )

    def _payment_summary(self, agency_id: int, filters: ReportFilters) -> ReportResponse:
        """Commission, paid and pending amounts per policy.

        Totals are summed in ``Decimal`` to avoid float drift.
        """
        policies = self.reports.payment_summary(agency_id, filters)
        rows = []
        total_commission = Decimal("0")
        total_paid = Decimal("0")
        total_pending = Decimal("0")
        for p in policies:
            commission = self._to_decimal(p.total_commission)
            paid = self._to_decimal(p.total_paid)
            pending = self._to_decimal(p.pending_amount)
            total_commission += commission
            total_paid += paid
            total_pending += pending
            rows.append(
                {
                    "policy_number": p.policy_number or f"#{p.id}",
                    "customer": p.customer.name if p.customer else "",
                    "commission": float(commission),
                    "paid": float(paid),
                    "pending": float(pending),
                    "payment_status": p.payment_status.value,
                }
            )
        return ReportResponse(
            report_type="payment-summary",
            title=REPORT_TYPES["payment-summary"].title,
            columns=[
                ReportColumn(key="policy_number", label="Policy"),
                ReportColumn(key="customer", label="Customer"),
                ReportColumn(key="commission", label="Commission (₹)"),
                ReportColumn(key="paid", label="Paid (₹)"),
                ReportColumn(key="pending", label="Pending (₹)"),
                ReportColumn(key="payment_status", label="Status"),
            ],
            rows=rows,
            summary={
                "total_commission": float(total_commission),
                "total_paid": float(total_paid),
                "total_pending": float(total_pending),
            },
            filters_applied=self._filters_applied(filters),
        )

    @staticmethod
    def _policy_expiry_row(policy) -> dict:
        """Shared row shape for expiry-oriented reports.

        Missing relations render as "". A missing policy number falls back to "#<id>".
        """
        return {
            "customer": policy.customer.name if policy.customer else "",
            "vehicle": policy.vehicle.registration_number if policy.vehicle else "",
            "company": policy.insurance_company.name if policy.insurance_company else "",
            "expiry_date": policy.policy_end_date.isoformat() if policy.policy_end_date else "",
            "policy_number": policy.policy_number or f"#{policy.id}",
        }
