"""Platform-wide metrics and activity for super admins."""

from __future__ import annotations

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.agency import Agency
from app.models.enums import ExtractionStatus
from app.models.extraction_result import ExtractionResult
from app.models.parser_correction import ParserCorrection
from app.models.policy import Policy
from app.models.user import User
from app.services.parser_performance_service import ParserPerformanceService


class SuperAdminService:
    """Aggregate platform-wide metrics and activity feeds for super-admin views.

    All queries issued directly by this class go through the SQLAlchemy session
    passed to the constructor and only read data; per-agency parser metrics are
    delegated to ``ParserPerformanceService``. Per-agency summary counts are
    computed with grouped queries so their cost does not grow with the number
    of agencies.
    """

    # Agency excluded from the platform-wide parser performance figures.
    PLATFORM_AGENCY_NAME = "Platform Operations"

    def __init__(self, db: Session):
        self.db = db

    def get_overview(self) -> dict:
        """Return platform totals, per-agency summaries and recent activity.

        Returns:
            A dict with ``"overview"`` (platform-wide counts, parser performance
            rates and an extraction status breakdown), ``"agencies"`` (one
            summary per agency, ordered by name) and ``"recent_activity"`` (up
            to 40 of the newest extraction/correction events).
        """
        agencies = self.db.query(Agency).order_by(Agency.name.asc()).all()
        agency_ids = [agency.id for agency in agencies]

        user_count = self.db.query(User).filter(User.deleted_at.is_(None), User.is_active.is_(True)).count()
        policy_count = self.db.query(Policy).count()
        correction_count = self.db.query(ParserCorrection).count()

        # Aggregate in SQL rather than loading every extraction row into memory.
        status_breakdown = {
            ExtractionStatus.SUCCESS.value: 0,
            ExtractionStatus.PARTIAL.value: 0,
            ExtractionStatus.FAILED.value: 0,
        }
        extraction_count = 0
        status_rows = (
            self.db.query(ExtractionResult.status, func.count())
            .group_by(ExtractionResult.status)
            .all()
        )
        for status, count in status_rows:
            status_breakdown[status.value] = status_breakdown.get(status.value, 0) + count
            extraction_count += count

        # An extraction is pending review until it has been linked to a policy.
        pending_reviews = self.db.query(ExtractionResult).filter(ExtractionResult.policy_id.is_(None)).count()

        platform_performance = self._platform_parser_performance(agency_ids)
        agency_items = self._agency_overviews(agencies)
        recent_activity = self._recent_activity(limit=40)

        return {
            "overview": {
                "agency_count": len(agencies),
                "user_count": user_count,
                "policy_count": policy_count,
                "extraction_count": extraction_count,
                "pending_reviews": pending_reviews,
                "correction_reviews": correction_count,
                **platform_performance,
                "status_breakdown": status_breakdown,
            },
            "agencies": agency_items,
            "recent_activity": recent_activity,
        }

    def list_agencies(self) -> list[dict]:
        """Return one summary dict per agency, ordered by agency name."""
        agencies = self.db.query(Agency).order_by(Agency.name.asc()).all()
        return self._agency_overviews(agencies)

    def list_activity(self, *, limit: int = 50, agency_id: int | None = None) -> list[dict]:
        """Return up to *limit* recent activity items, optionally for one agency."""
        return self._recent_activity(limit=limit, agency_id=agency_id)

    def _agency_overview(self, agency: Agency) -> dict:
        """Return the summary dict for a single agency."""
        return self._agency_overviews([agency])[0]

    def _agency_overviews(self, agencies: list[Agency]) -> list[dict]:
        """Build summary dicts for *agencies*, in the same order.

        Each dict holds the agency's contact fields, active-user / policy /
        extraction / pending-review counts, and the newest extraction or
        correction timestamp as ``last_activity_at`` (ISO string or ``None``).
        One grouped query is issued per metric instead of several per agency.
        """
        agency_ids = [agency.id for agency in agencies]
        if not agency_ids:
            return []

        user_counts = dict(
            self.db.query(User.agency_id, func.count())
            .filter(User.agency_id.in_(agency_ids), User.deleted_at.is_(None), User.is_active.is_(True))
            .group_by(User.agency_id)
            .all()
        )
        policy_counts = dict(
            self.db.query(Policy.agency_id, func.count())
            .filter(Policy.agency_id.in_(agency_ids))
            .group_by(Policy.agency_id)
            .all()
        )
        extraction_rows = (
            self.db.query(
                ExtractionResult.agency_id,
                func.count(),
                # COUNT(column) skips NULLs, so COUNT(*) - COUNT(policy_id) is the
                # number of extractions not yet linked to a policy.
                func.count(ExtractionResult.policy_id),
                func.max(ExtractionResult.created_at),
            )
            .filter(ExtractionResult.agency_id.in_(agency_ids))
            .group_by(ExtractionResult.agency_id)
            .all()
        )
        extraction_stats = {
            row_agency_id: (total, total - linked, last_created)
            for row_agency_id, total, linked, last_created in extraction_rows
        }
        last_corrections = dict(
            self.db.query(ParserCorrection.agency_id, func.max(ParserCorrection.created_at))
            .filter(ParserCorrection.agency_id.in_(agency_ids))
            .group_by(ParserCorrection.agency_id)
            .all()
        )

        items: list[dict] = []
        for agency in agencies:
            extraction_count, pending_reviews, last_extraction = extraction_stats.get(agency.id, (0, 0, None))
            last_correction = last_corrections.get(agency.id)
            last_activity = last_extraction
            if last_correction and (not last_activity or last_correction > last_activity):
                last_activity = last_correction

            items.append(
                {
                    "id": agency.id,
                    "name": agency.name,
                    "phone": agency.phone,
                    "email": agency.email,
                    "user_count": user_counts.get(agency.id, 0),
                    "policy_count": policy_counts.get(agency.id, 0),
                    "extraction_count": extraction_count,
                    "pending_reviews": pending_reviews,
                    "last_activity_at": last_activity.isoformat() if last_activity else None,
                }
            )
        return items

    def _platform_parser_performance(self, agency_ids: list[int]) -> dict:
        """Combine per-agency parser metrics into platform-wide rates.

        Agency ids that do not resolve to an agency, and the agency named
        ``PLATFORM_AGENCY_NAME``, are skipped.

        Returns:
            ``extraction_success_rate``: (success + partial) / total extractions
            summed over agencies; ``field_accuracy_rate``: unweighted mean of the
            per-agency ``field_accuracy_rate`` over agencies with at least one
            review; ``overall_parser_score``: mean of the two. Values are rounded
            to 4 decimals and are 0.0 when there is no data.
        """
        if not agency_ids:
            return {
                "extraction_success_rate": 0.0,
                "field_accuracy_rate": 0.0,
                "overall_parser_score": 0.0,
            }

        # One query for all agencies instead of one lookup per id.
        agencies_by_id = self._rows_by_id(Agency, agency_ids)
        performance_service = ParserPerformanceService(self.db)

        total_success = 0
        total_extractions = 0
        accuracy_rates: list[float] = []

        for agency_id in agency_ids:
            agency = agencies_by_id.get(agency_id)
            if agency is None or agency.name == self.PLATFORM_AGENCY_NAME:
                continue
            summary = performance_service.get_performance(agency_id)["summary"]
            if summary["total_extractions"]:
                total_extractions += summary["total_extractions"]
                breakdown = summary["status_breakdown"]
                total_success += breakdown["success"] + breakdown["partial"]
            if summary["total_reviews"]:
                accuracy_rates.append(summary["field_accuracy_rate"])

        extraction_success_rate = round(total_success / total_extractions, 4) if total_extractions else 0.0
        field_accuracy_rate = round(sum(accuracy_rates) / len(accuracy_rates), 4) if accuracy_rates else 0.0
        overall_parser_score = round((extraction_success_rate + field_accuracy_rate) / 2, 4)
        return {
            "extraction_success_rate": extraction_success_rate,
            "field_accuracy_rate": field_accuracy_rate,
            "overall_parser_score": overall_parser_score,
        }

    def _recent_activity(self, *, limit: int, agency_id: int | None = None) -> list[dict]:
        """Return the newest extraction and correction events, merged.

        Up to *limit* rows of each kind are fetched (optionally restricted to
        one agency), converted to activity dicts, and the *limit* newest overall
        are returned. Items without ``created_at`` sort last.
        """
        extraction_query = self.db.query(ExtractionResult).order_by(ExtractionResult.created_at.desc())
        correction_query = self.db.query(ParserCorrection).order_by(ParserCorrection.created_at.desc())
        if agency_id is not None:
            extraction_query = extraction_query.filter(ExtractionResult.agency_id == agency_id)
            correction_query = correction_query.filter(ParserCorrection.agency_id == agency_id)

        extractions = extraction_query.limit(limit).all()
        corrections = correction_query.limit(limit).all()

        # Resolve display names only for the agencies/users these rows reference,
        # instead of loading the entire Agency and User tables.
        referenced_agencies = [row.agency_id for row in extractions] + [row.agency_id for row in corrections]
        referenced_users = [row.uploaded_by for row in extractions] + [row.corrected_by for row in corrections]
        agency_map = {key: agency.name for key, agency in self._rows_by_id(Agency, referenced_agencies).items()}
        user_map = {key: user.full_name for key, user in self._rows_by_id(User, referenced_users).items()}

        items: list[dict] = []
        for row in extractions:
            items.append(
                {
                    "id": f"extraction-{row.id}",
                    "activity_type": "extraction",
                    "agency_id": row.agency_id or 0,
                    "agency_name": agency_map.get(row.agency_id, "Unknown"),
                    "user_id": row.uploaded_by,
                    "user_name": user_map.get(row.uploaded_by),
                    "title": row.original_filename,
                    "subtitle": "PDF extracted",
                    "status": row.status.value,
                    "created_at": row.created_at.isoformat() if row.created_at else None,
                }
            )

        for row in corrections:
            changed = sum(1 for field in row.fields or [] if field.get("was_corrected"))
            items.append(
                {
                    "id": f"correction-{row.id}",
                    "activity_type": "correction",
                    "agency_id": row.agency_id,
                    "agency_name": agency_map.get(row.agency_id, "Unknown"),
                    "user_id": row.corrected_by,
                    "user_name": user_map.get(row.corrected_by),
                    "title": row.original_filename,
                    "subtitle": f"{changed} field(s) corrected · {row.source or 'parser_lab'}",
                    "status": row.company_code,
                    "created_at": row.created_at.isoformat() if row.created_at else None,
                }
            )

        # Sort on the ISO strings (missing timestamps become "" and sort last).
        items.sort(key=lambda item: item["created_at"] or "", reverse=True)
        return items[:limit]

    def _rows_by_id(self, model: type, ids: list) -> dict:
        """Load instances of *model* whose ``id`` is in *ids*, keyed by ``id``.

        ``None`` entries are ignored. No query is issued when no ids remain,
        which avoids emitting an empty ``IN`` clause.
        """
        wanted = list({value for value in ids if value is not None})
        if not wanted:
            return {}
        return {row.id: row for row in self.db.query(model).filter(model.id.in_(wanted)).all()}
