from datetime import date, timedelta

from sqlalchemy import exists, func, or_
from sqlalchemy.orm import Session, aliased

from app.models.customer import Customer
from app.models.enums import PaymentStatus, PolicyStatus
from app.models.insurance_company import InsuranceCompany
from app.models.policy import Policy
from app.models.policy_payment import PolicyPayment
from app.models.vehicle import Vehicle
from app.schemas.reports import ReportFilters


class ReportRepository:
    """Read-only reporting queries over an agency's policies.

    Every report starts from ``_base_policy_query``, so results are always
    limited to one agency, skip policies whose ``deleted_at`` is set, and honour
    the optional ``company_id`` / ``payment_status`` filters.
    """

    def __init__(self, db: Session):
        self.db = db

    def _base_policy_query(self, agency_id: int, filters: ReportFilters):
        """Build the shared Policy query for ``agency_id``.

        Policies are inner-joined to their customer (policies without a matching
        customer row are dropped) and outer-joined to vehicle and insurance
        company, so callers can group or order by columns of those tables.
        """
        query = (
            self.db.query(Policy)
            .join(Customer, Policy.customer_id == Customer.id)
            .outerjoin(Vehicle, Policy.vehicle_id == Vehicle.id)
            .outerjoin(InsuranceCompany, Policy.insurance_company_id == InsuranceCompany.id)
            .filter(Policy.agency_id == agency_id, Policy.deleted_at.is_(None))
        )
        if filters.company_id:
            query = query.filter(Policy.insurance_company_id == filters.company_id)
        if filters.payment_status:
            # NOTE(review): presumably validated by the schema; an unknown value
            # makes the PaymentStatus lookup raise here.
            query = query.filter(Policy.payment_status == PaymentStatus(filters.payment_status))
        return query

    @staticmethod
    def _apply_date_range(query, column, filters: ReportFilters):
        """Restrict ``column`` to the inclusive ``[date_from, date_to]`` window.

        Each bound is applied only when it is set on ``filters``. ``column`` may be
        a plain column or an expression such as ``func.date(Policy.created_at)``.
        """
        if filters.date_from:
            query = query.filter(column >= filters.date_from)
        if filters.date_to:
            query = query.filter(column <= filters.date_to)
        return query

    def total_policies_by_status(self, agency_id: int, filters: ReportFilters) -> list[tuple]:
        """Return ``(status, policy_count)`` rows ordered by status.

        The date window applies to the calendar date of ``Policy.created_at``.
        """
        query = self._apply_date_range(
            self._base_policy_query(agency_id, filters), func.date(Policy.created_at), filters
        )
        return (
            query.with_entities(Policy.status, func.count(Policy.id))
            .group_by(Policy.status)
            .order_by(Policy.status)
            .all()
        )

    def company_wise_policies(self, agency_id: int, filters: ReportFilters) -> list[tuple]:
        """Return ``(company_name, policy_count, premium_total)`` rows per company.

        Rows are ordered by policy count, highest first. Because the company is
        outer-joined, policies without a company aggregate into one row whose
        name is ``None``. ``premium_total`` is 0 rather than NULL when every
        premium in the group is NULL. The date window applies to the calendar
        date of ``Policy.created_at``.
        """
        query = self._apply_date_range(
            self._base_policy_query(agency_id, filters), func.date(Policy.created_at), filters
        )
        return (
            query.with_entities(
                InsuranceCompany.name,
                func.count(Policy.id),
                func.coalesce(func.sum(Policy.premium_amount), 0),
            )
            .group_by(InsuranceCompany.id, InsuranceCompany.name)
            .order_by(func.count(Policy.id).desc())
            .all()
        )

    def expired_policies(self, agency_id: int, filters: ReportFilters) -> list[Policy]:
        """Return lapsed policies that were neither renewed nor cancelled.

        A policy counts as lapsed when its status is EXPIRED or its end date is
        before today (``date.today()``, i.e. the server's local date). The date
        window applies to ``policy_end_date``. Results are ordered by end date,
        newest first.
        """
        today = date.today()
        query = self._base_policy_query(agency_id, filters).filter(
            Policy.status.notin_([PolicyStatus.RENEWED, PolicyStatus.CANCELLED]),
            or_(Policy.status == PolicyStatus.EXPIRED, Policy.policy_end_date < today),
        )
        query = self._apply_date_range(query, Policy.policy_end_date, filters)
        return query.order_by(Policy.policy_end_date.desc()).all()

    def renewed_policies(self, agency_id: int, filters: ReportFilters) -> list[tuple[Policy, Policy | None]]:
        """Return ``(renewed_policy, renewal_policy_or_None)`` pairs.

        ``renewed_policy`` has status RENEWED. The renewal is the non-deleted
        policy whose ``renewed_from_policy_id`` points back at it. If several
        renewals reference the same policy, each produces its own row. The date
        window applies to the calendar date of the renewed policy's
        ``updated_at``. Results are ordered by that timestamp, newest first.
        """
        NewPolicy = aliased(Policy)
        query = (
            self._base_policy_query(agency_id, filters)
            .filter(Policy.status == PolicyStatus.RENEWED)
            # The successor carries the back-reference to the policy it renews;
            # soft-deleted successors are excluded just like in the base query.
            .outerjoin(
                NewPolicy,
                (NewPolicy.renewed_from_policy_id == Policy.id) & NewPolicy.deleted_at.is_(None),
            )
            .add_entity(NewPolicy)
        )
        query = self._apply_date_range(query, func.date(Policy.updated_at), filters)
        return query.order_by(Policy.updated_at.desc()).all()

    def expiring_soon_policies(self, agency_id: int, filters: ReportFilters) -> list[Policy]:
        """Return open policies whose end date falls within ``filters.expiry_days``.

        The window runs from today through ``today + expiry_days``, inclusive at
        both ends, and excludes RENEWED and CANCELLED policies. The optional date
        window further restricts ``policy_end_date``. Results are ordered by end
        date, soonest first.
        """
        today = date.today()
        end = today + timedelta(days=filters.expiry_days)
        query = self._base_policy_query(agency_id, filters).filter(
            Policy.status.notin_([PolicyStatus.RENEWED, PolicyStatus.CANCELLED]),
            Policy.policy_end_date.isnot(None),
            Policy.policy_end_date.between(today, end),
        )
        query = self._apply_date_range(query, Policy.policy_end_date, filters)
        return query.order_by(Policy.policy_end_date.asc()).all()

    def payment_summary(self, agency_id: int, filters: ReportFilters) -> list[Policy]:
        """Return policies ordered by pending amount (highest first), then customer name.

        When ``filters.agent_id`` is set, only policies with at least one payment
        whose ``created_by`` equals that agent are included. The date window
        applies to the calendar date of ``Policy.created_at``.
        """
        query = self._base_policy_query(agency_id, filters)
        if filters.agent_id:
            # A correlated EXISTS yields each policy once without SELECT DISTINCT,
            # which would otherwise need the ORDER BY columns in the select list
            # and equality support on every policy column.
            query = query.filter(
                exists()
                .where(PolicyPayment.policy_id == Policy.id)
                .where(PolicyPayment.created_by == filters.agent_id)
            )
        query = self._apply_date_range(query, func.date(Policy.created_at), filters)
        return query.order_by(Policy.pending_amount.desc(), Customer.name.asc()).all()
