from datetime import date, datetime, timezone
from decimal import Decimal
from enum import Enum

from sqlalchemy.orm import Session

from app.repositories.audit_repository import AuditRepository


def _serialize(value) -> str | None:
    if value is None:
        return None
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, Decimal):
        return str(value)
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, date):
        return value.isoformat()
    return str(value)


# Policy attribute names captured by AuditService.snapshot_policy (missing
# attributes read as None). Presumably also passed as ``tracked_fields`` to
# the AuditService.log_* methods; list order is the order entries are built.
POLICY_AUDIT_FIELDS: list[str] = [
    "policy_number",
    "policy_start_date",
    "policy_end_date",
    "premium_amount",
    "total_commission",
    "total_paid",
    "pending_amount",
    "status",
    "payment_status",
    "insurance_company_id",
    "vehicle_id",
    "coverage_type",
    "notes",
]

# Customer attribute names captured by AuditService.snapshot_customer, with
# the same conventions as POLICY_AUDIT_FIELDS.
CUSTOMER_AUDIT_FIELDS: list[str] = [
    "name",
    "mobile",
    "email",
    "address",
    "city",
    "state",
    "pincode",
    "whatsapp_opt_out",
]


class AuditService:
    """Write and read field-level audit-log rows through ``AuditRepository``.

    Create/update events produce one row per tracked field; a delete produces
    a single row with no field information. Values are stored in the string
    form produced by ``_serialize``.
    """

    def __init__(self, db: Session):
        self.db = db
        self.repo = AuditRepository(db)

    @staticmethod
    def _entry(
        *,
        agency_id: int,
        user_id: int | None,
        entity_type: str,
        entity_id: int,
        action: str,
        created_at: datetime,
        field_name: str | None = None,
        old_value: str | None = None,
        new_value: str | None = None,
    ) -> dict:
        """Build one audit row dict; the single definition of the row's keys."""
        return {
            "agency_id": agency_id,
            "user_id": user_id,
            "entity_type": entity_type,
            "entity_id": entity_id,
            "action": action,
            "field_name": field_name,
            "old_value": old_value,
            "new_value": new_value,
            "created_at": created_at,
        }

    def log_create(
        self,
        *,
        agency_id: int,
        user_id: int | None,
        entity_type: str,
        entity_id: int,
        fields: dict,
        tracked_fields: list[str],
    ) -> None:
        """Record the initial values of a newly created entity.

        Writes one ``"create"`` row per name in ``tracked_fields`` whose value
        in ``fields`` is not ``None``. All rows share one UTC timestamp. Nothing
        is written when no tracked field has a value.

        Args:
            agency_id: Agency id stored on every row.
            user_id: Acting user id stored on every row; may be ``None``.
            entity_type: Entity type label stored on every row.
            entity_id: Primary key of the created entity.
            fields: Mapping of field name to initial value.
            tracked_fields: Field names to record, in output order.
        """
        now = datetime.now(timezone.utc)
        entries = []
        for field in tracked_fields:
            value = fields.get(field)
            if value is None:
                # Unset fields are not recorded on creation.
                continue
            entries.append(
                self._entry(
                    agency_id=agency_id,
                    user_id=user_id,
                    entity_type=entity_type,
                    entity_id=entity_id,
                    action="create",
                    created_at=now,
                    field_name=field,
                    new_value=_serialize(value),
                )
            )
        if entries:
            self.repo.create_entries(entries)

    def log_update(
        self,
        *,
        agency_id: int,
        user_id: int | None,
        entity_type: str,
        entity_id: int,
        before: dict,
        after: dict,
        tracked_fields: list[str],
    ) -> None:
        """Record the tracked fields whose value changed in an update.

        A field is considered only if it is a key of ``after``, so partial
        ``after`` mappings log only the fields they contain. Old and new values
        are compared after ``_serialize``, so values with the same serialized
        form do not produce a row. All rows share one UTC timestamp; nothing
        is written when nothing changed.

        Args:
            agency_id: Agency id stored on every row.
            user_id: Acting user id stored on every row; may be ``None``.
            entity_type: Entity type label stored on every row.
            entity_id: Primary key of the updated entity.
            before: Field values prior to the update (missing keys read as None).
            after: Field values after the update.
            tracked_fields: Field names to compare, in output order.
        """
        now = datetime.now(timezone.utc)
        entries = []
        for field in tracked_fields:
            if field not in after:
                continue
            old_val = _serialize(before.get(field))
            new_val = _serialize(after.get(field))
            if old_val == new_val:
                continue
            entries.append(
                self._entry(
                    agency_id=agency_id,
                    user_id=user_id,
                    entity_type=entity_type,
                    entity_id=entity_id,
                    action="update",
                    created_at=now,
                    field_name=field,
                    old_value=old_val,
                    new_value=new_val,
                )
            )
        if entries:
            self.repo.create_entries(entries)

    def log_delete(
        self,
        *,
        agency_id: int,
        user_id: int | None,
        entity_type: str,
        entity_id: int,
    ) -> None:
        """Record a single ``"delete"`` row with no field/value information."""
        self.repo.create_entries(
            [
                self._entry(
                    agency_id=agency_id,
                    user_id=user_id,
                    entity_type=entity_type,
                    entity_id=entity_id,
                    action="delete",
                    created_at=datetime.now(timezone.utc),
                )
            ]
        )

    def list_entity_history(
        self,
        agency_id: int,
        entity_type: str,
        entity_id: int,
        *,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[dict], int]:
        """Return one page of audit rows for an entity plus the total count.

        Args:
            agency_id: Agency scope for the lookup.
            entity_type: Entity type label to filter on.
            entity_id: Entity primary key to filter on.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``(rows, total)`` where each row is the dict from ``_to_dict`` and
            ``total`` is the count reported by the repository.
        """
        # Clamp so bad paging input cannot yield a negative OFFSET/LIMIT.
        # These are presumably forwarded to SQL, where e.g. PostgreSQL rejects
        # negative values.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size
        items, total = self.repo.list_for_entity(
            agency_id, entity_type, entity_id, limit=page_size, offset=offset
        )
        return [self._to_dict(row) for row in items], total

    @staticmethod
    def _to_dict(row) -> dict:
        """Flatten an audit row into a plain dict, including the user's name."""
        return {
            "id": row.id,
            "entity_type": row.entity_type,
            "entity_id": row.entity_id,
            "action": row.action,
            "field_name": row.field_name,
            "old_value": row.old_value,
            "new_value": row.new_value,
            "user_id": row.user_id,
            # NOTE(review): row.user may trigger a lazy load per row unless the
            # repository eager-loads it — confirm to avoid N+1 queries.
            "user_name": row.user.full_name if row.user else None,
            "created_at": row.created_at,
        }

    @staticmethod
    def snapshot_policy(policy) -> dict:
        """Capture POLICY_AUDIT_FIELDS from ``policy``; missing attributes are None."""
        return {field: getattr(policy, field, None) for field in POLICY_AUDIT_FIELDS}

    @staticmethod
    def snapshot_customer(customer) -> dict:
        """Capture CUSTOMER_AUDIT_FIELDS from ``customer``; missing attributes are None."""
        return {field: getattr(customer, field, None) for field in CUSTOMER_AUDIT_FIELDS}
