"""Track suggested vs user-submitted field values for parser training feedback."""

from __future__ import annotations

import re
from decimal import Decimal, InvalidOperation
from typing import Any

from sqlalchemy.orm import Session

from app.models.extraction_result import ExtractionResult
from app.models.parser_correction import ParserCorrection
from app.schemas.upload import ConfirmExtractionRequest
from app.utils.date_parser import parse_flexible_date

# Field names whose suggested value is compared against the value the user
# finally submitted. The order of this tuple fixes the order of the per-field
# entries built for each correction and the tie-break order of the per-field
# statistics (which are stable-sorted), so reordering it changes output order.
TRACKED_FIELDS: tuple[str, ...] = (
    "customer_name",
    "mobile_number",
    "email_address",
    "address",
    "city",
    "state",
    "pincode",
    "nominee_name",
    "vehicle_registration_number",
    "engine_number",
    "chassis_number",
    "vehicle_make",
    "vehicle_model",
    "vehicle_type",
    "body_type",
    "fuel_type",
    "manufacturing_year",
    "insurance_company_id",
    "policy_number",
    "policy_type",
    "coverage_type",
    "policy_start_date",
    "policy_end_date",
    "premium_amount",
    "idv",
    "ncb",
)

# Known values for the ``source`` stored on each ParserCorrection.
# SOURCE_PARSER_LAB doubles as the fallback reported for rows whose stored
# source is empty.
SOURCE_PARSER_LAB = "parser_lab"
SOURCE_UPLOAD_REVIEW = "upload_review"


def _display(value: Any) -> str | None:
    """Render *value* as a trimmed display string, or ``None`` when empty.

    ``Decimal`` values are written in fixed-point notation with insignificant
    fractional zeros removed (``Decimal("10.50")`` -> ``"10.5"``,
    ``Decimal("1E+3")`` -> ``"1000"``); integral digits are never trimmed.
    Any other value goes through ``str()`` and is stripped of surrounding
    whitespace; a blank result is reported as ``None``.
    """
    if value is None:
        return None
    if isinstance(value, Decimal):
        text = format(value, "f")
        # Only trim zeros that follow a decimal point: calling rstrip("0") on an
        # integral string such as "100" would wrongly produce "1".
        if "." in text:
            text = text.rstrip("0").rstrip(".")
        return text or "0"
    text = str(value).strip()
    return text or None


def normalize_for_compare(field_name: str, value: Any) -> str:
    """Return a canonical string used to decide whether two values of *field_name* match.

    ``None`` and blank values normalize to ``""``. Field-specific rules:

    * ``mobile_number`` - digits only, keeping the last 10.
    * ``policy_start_date`` / ``policy_end_date`` - ``isoformat()`` of the result of
      ``parse_flexible_date``; text it cannot parse is lower-cased instead.
    * ``premium_amount`` / ``idv`` - commas removed, then rendered as a fixed-point
      decimal without insignificant fractional zeros; non-numeric text is lower-cased.
    * ``vehicle_registration_number`` - spaces and hyphens removed, upper-cased.
    * ``vehicle_type`` / ``body_type`` / ``coverage_type`` / ``fuel_type`` /
      ``policy_type`` - runs of whitespace/underscores become one space, lower-cased.
    * ``manufacturing_year`` - the first 4-digit run, else the text unchanged.
    * anything else - whitespace runs collapsed to one space, lower-cased.
    """
    text = _display(value) or ""
    if not text:
        return ""

    if field_name == "mobile_number":
        # Compare only the trailing 10 digits (presumably so that country-code or
        # trunk prefixes and punctuation are not counted as corrections).
        return re.sub(r"\D", "", text)[-10:]

    if field_name in {"policy_start_date", "policy_end_date"}:
        # NOTE(review): parse_flexible_date presumably returns a date or None.
        parsed = parse_flexible_date(text)
        return parsed.isoformat() if parsed else text.lower()

    if field_name in {"premium_amount", "idv"}:
        try:
            amount = Decimal(text.replace(",", ""))
        except InvalidOperation:
            # Not a number (e.g. free text); fall back to a case-insensitive compare.
            return text.lower()
        formatted = format(amount, "f")
        # Trim zeros only after a decimal point; stripping an integral string
        # would make "1000" and "1" (or "100" and "1") compare equal.
        if "." in formatted:
            formatted = formatted.rstrip("0").rstrip(".")
        return formatted or "0"

    if field_name == "vehicle_registration_number":
        return re.sub(r"[\s\-]", "", text).upper()

    if field_name in {"vehicle_type", "body_type", "coverage_type", "fuel_type", "policy_type"}:
        return re.sub(r"[\s_]+", " ", text).lower()

    if field_name == "manufacturing_year":
        match = re.search(r"\d{4}", text)
        return match.group(0) if match else text

    return re.sub(r"\s+", " ", text).lower()


def values_differ(field_name: str, suggested: Any, submitted: Any) -> bool:
    """Return True when *suggested* and *submitted* are not equivalent for *field_name*.

    Both values are reduced with ``normalize_for_compare`` before comparing, so
    purely cosmetic differences (case, spacing, formatting) do not count.
    """
    baseline = normalize_for_compare(field_name, suggested)
    candidate = normalize_for_compare(field_name, submitted)
    return baseline != candidate


def suggested_values_from_defaults(form_defaults: dict[str, Any], extraction: ExtractionResult) -> dict[str, Any]:
    """Collect the suggested value of every tracked field.

    Each entry in TRACKED_FIELDS is looked up in *form_defaults* (``None`` when
    absent). ``insurance_company_id`` is then overridden with the value stored on
    *extraction* rather than taken from the form defaults.
    """
    suggestions: dict[str, Any] = {}
    for name in TRACKED_FIELDS:
        suggestions[name] = form_defaults.get(name)
    suggestions["insurance_company_id"] = extraction.insurance_company_id
    return suggestions


def submitted_values_from_payload(payload: ConfirmExtractionRequest) -> dict[str, Any]:
    """Read the user-submitted value of each tracked field from *payload*.

    Returns a dict keyed by field name, in the same order as the tracked
    fields, holding the raw attribute values of the confirmation request.
    """
    attribute_names = (
        "customer_name",
        "mobile_number",
        "email_address",
        "address",
        "city",
        "state",
        "pincode",
        "nominee_name",
        "vehicle_registration_number",
        "engine_number",
        "chassis_number",
        "vehicle_make",
        "vehicle_model",
        "vehicle_type",
        "body_type",
        "fuel_type",
        "manufacturing_year",
        "insurance_company_id",
        "policy_number",
        "policy_type",
        "coverage_type",
        "policy_start_date",
        "policy_end_date",
        "premium_amount",
        "idv",
        "ncb",
    )
    return {attr: getattr(payload, attr) for attr in attribute_names}


def build_field_corrections(suggested: dict[str, Any], submitted: dict[str, Any]) -> list[dict]:
    """Pair each tracked field's suggested and submitted values into a plain-dict entry.

    Returns one entry per name in TRACKED_FIELDS, in that order, holding the
    display forms of both values and a ``was_corrected`` flag from
    ``values_differ``. Names missing from either mapping are treated as ``None``.
    """

    def _compare(name: str) -> dict:
        before = suggested.get(name)
        after = submitted.get(name)
        return {
            "field_name": name,
            "extracted_value": _display(before),
            "corrected_value": _display(after),
            "was_corrected": values_differ(name, before, after),
        }

    return [_compare(name) for name in TRACKED_FIELDS]


def summarize_field_corrections(fields: list[dict]) -> dict[str, int]:
    """Count how many per-field entries were compared, changed and left unchanged.

    An entry counts as changed when its ``was_corrected`` value is truthy.
    """
    total = len(fields)
    changed = len([entry for entry in fields if entry.get("was_corrected")])
    unchanged = total - changed
    return {
        "fields_compared": total,
        "fields_changed": changed,
        "fields_unchanged": unchanged,
    }


class ExtractionCorrectionService:
    """Store parser corrections (suggested vs submitted values) and aggregate them."""

    def __init__(self, db: Session):
        # Session used for all reads and writes; this service flushes but never commits.
        self.db = db

    @staticmethod
    def _company_code_from_raw(raw: Any) -> Any:
        """Return the parser company code recorded in an extraction's ``raw_data``.

        Prefers the top-level ``company_code`` key and falls back to
        ``company_detection["code"]``. A missing, null or non-dict
        ``company_detection`` yields ``None`` instead of raising.
        """
        data = raw or {}
        code = data.get("company_code")
        if code:
            return code
        detection = data.get("company_detection")
        return detection.get("code") if isinstance(detection, dict) else None

    def record_correction(
        self,
        *,
        extraction: ExtractionResult,
        agency_id: int,
        user_id: int,
        suggested: dict[str, Any],
        submitted: dict[str, Any],
        source: str,
        notes: str | None = None,
    ) -> ParserCorrection:
        """Create a ParserCorrection comparing *suggested* with *submitted* values.

        Args:
            extraction: Extraction the values came from; supplies id, filename
                and the company code found in its ``raw_data``.
            agency_id: Agency the correction belongs to.
            user_id: User who submitted the values (stored as ``corrected_by``).
            suggested: Field name -> suggested value.
            submitted: Field name -> value the user submitted.
            source: Origin of the review (e.g. SOURCE_PARSER_LAB / SOURCE_UPLOAD_REVIEW).
            notes: Optional free-text notes.

        Returns:
            The new record, added to the session and flushed but not committed.
        """
        company_code = self._company_code_from_raw(extraction.raw_data)
        fields = build_field_corrections(suggested, submitted)

        record = ParserCorrection(
            extraction_id=extraction.id,
            agency_id=agency_id,
            corrected_by=user_id,
            company_code=company_code,
            # parser_key currently mirrors company_code.
            parser_key=company_code,
            original_filename=extraction.original_filename,
            fields=fields,
            notes=notes,
            source=source,
        )
        self.db.add(record)
        # Flush (no commit) so the INSERT runs inside the caller's transaction and
        # database-generated columns such as the primary key are populated.
        self.db.flush()
        return record

    def correction_stats(self, agency_id: int | None = None, *, company_code: str | None = None) -> dict:
        """Aggregate stored corrections into overall, per-field, per-source and per-company rates.

        Args:
            agency_id: Restrict to one agency when not ``None``.
            company_code: Restrict to one company code when truthy.

        Returns:
            A dict with ``total_reviews``, ``total_field_comparisons``,
            ``total_field_changes``, ``overall_change_rate``, ``by_field`` (one
            entry per TRACKED_FIELDS name, most-changed first), ``by_source``
            (keyed by source, sorted by name) and ``by_company`` (list, most-changed
            first). Totals count every stored field entry, while ``by_field`` only
            reports names in TRACKED_FIELDS. Rates are rounded to 4 places and are
            0.0 when nothing was compared.
        """

        def rate(changes: int, comparisons: int) -> float:
            # Avoid division by zero for buckets with no compared fields.
            return round(changes / comparisons, 4) if comparisons else 0.0

        query = self.db.query(ParserCorrection)
        if agency_id is not None:
            query = query.filter(ParserCorrection.agency_id == agency_id)
        if company_code:
            query = query.filter(ParserCorrection.company_code == company_code)
        # Newest first. This order also decides tie-breaking in ``by_company``:
        # companies are inserted in first-seen order and the sort below is stable.
        rows = query.order_by(ParserCorrection.created_at.desc()).all()

        by_field: dict[str, dict[str, int]] = {}
        by_source: dict[str, dict[str, int]] = {}
        by_company: dict[str, dict[str, int]] = {}
        total_comparisons = 0
        total_changes = 0

        for row in rows:
            # Rows without a stored source/company are grouped under a fallback key.
            source = row.source or SOURCE_PARSER_LAB
            company = row.company_code or "unknown"
            by_source.setdefault(source, {"reviews": 0, "comparisons": 0, "changes": 0})
            by_source[source]["reviews"] += 1
            by_company.setdefault(company, {"reviews": 0, "comparisons": 0, "changes": 0})
            by_company[company]["reviews"] += 1

            for field in row.fields or []:
                name = field.get("field_name")
                if not name:
                    continue
                by_field.setdefault(name, {"comparisons": 0, "changes": 0})
                by_field[name]["comparisons"] += 1
                total_comparisons += 1
                by_source[source]["comparisons"] += 1
                by_company[company]["comparisons"] += 1
                if field.get("was_corrected"):
                    by_field[name]["changes"] += 1
                    total_changes += 1
                    by_source[source]["changes"] += 1
                    by_company[company]["changes"] += 1

        field_stats = []
        for name in TRACKED_FIELDS:
            stats = by_field.get(name, {"comparisons": 0, "changes": 0})
            comparisons = stats["comparisons"]
            changes = stats["changes"]
            field_stats.append(
                {
                    "field_name": name,
                    "comparisons": comparisons,
                    "changes": changes,
                    "change_rate": rate(changes, comparisons),
                }
            )
        # Most-changed fields first; ties keep TRACKED_FIELDS order (stable sort).
        field_stats.sort(key=lambda item: (-item["changes"], -item["change_rate"]))

        return {
            "total_reviews": len(rows),
            "total_field_comparisons": total_comparisons,
            "total_field_changes": total_changes,
            "overall_change_rate": rate(total_changes, total_comparisons),
            "by_field": field_stats,
            "by_source": {
                source: {**stats, "change_rate": rate(stats["changes"], stats["comparisons"])}
                for source, stats in sorted(by_source.items())
            },
            "by_company": [
                {
                    "company_code": company,
                    **stats,
                    "change_rate": rate(stats["changes"], stats["comparisons"]),
                }
                for company, stats in sorted(by_company.items(), key=lambda item: -item[1]["changes"])
            ],
        }

    @staticmethod
    def correction_to_dict(row: ParserCorrection | None) -> dict | None:
        """Serialize a ParserCorrection row to a plain dict, or return ``None`` for no row.

        Adds per-row counts from ``summarize_field_corrections`` and reports
        SOURCE_PARSER_LAB when the stored source is empty.
        """
        if not row:
            return None
        summary = summarize_field_corrections(row.fields or [])
        return {
            "id": row.id,
            "extraction_id": row.extraction_id,
            "company_code": row.company_code,
            "parser_key": row.parser_key,
            "original_filename": row.original_filename,
            "fields": row.fields,
            "notes": row.notes,
            "source": row.source or SOURCE_PARSER_LAB,
            "changed_count": summary["fields_changed"],
            "fields_compared": summary["fields_compared"],
            "fields_unchanged": summary["fields_unchanged"],
            "created_at": row.created_at.isoformat() if row.created_at else None,
        }
