from datetime import date, datetime, timezone

from sqlalchemy.orm import Session

from app.models.device_push_token import DevicePushToken
from app.models.enums import NotificationType, PushPlatform
from app.models.notification import Notification
from app.models.user import User
from app.models.user_notification_preference import UserNotificationPreference


class NotificationRepository:
    """Data-access helpers for notifications, notification preferences and push tokens.

    Write methods in this class only ``flush()`` the session; none of them
    commits. Committing or rolling back the surrounding transaction is left to
    whoever owns ``db``.
    """

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # Preferences
    # ------------------------------------------------------------------

    def get_preferences(self, user_id: int) -> UserNotificationPreference:
        """Return the preference row for ``user_id``, creating it if missing.

        A newly created row is added to the session and flushed, so its INSERT
        is issued inside the current transaction. Its field values are whatever
        the model's defaults are.

        NOTE(review): two concurrent first calls for the same user can both miss
        the lookup and both insert. Whether that raises depends on a unique
        constraint on ``user_id`` that is not visible here -- confirm.
        """
        prefs = (
            self.db.query(UserNotificationPreference)
            .filter(UserNotificationPreference.user_id == user_id)
            .first()
        )
        if prefs is None:
            prefs = UserNotificationPreference(user_id=user_id)
            self.db.add(prefs)
            self.db.flush()
        return prefs

    def update_preferences(self, user_id: int, data: dict) -> UserNotificationPreference:
        """Apply a partial update to ``user_id``'s preferences and return the row.

        The preference row is created first if it does not exist. Each key of
        ``data`` is assigned as an attribute on the row. Entries whose value is
        ``None`` are skipped, so this method cannot clear a field to ``None``.

        NOTE(review): keys are not validated against the model's columns here.
        An unknown key is not rejected, and identity keys such as ``user_id``
        would be written as given. Presumably ``data`` comes from a validated
        schema -- confirm at the call site.
        """
        prefs = self.get_preferences(user_id)
        for key, value in data.items():
            # None means "not provided" in this partial-update contract.
            if value is not None:
                setattr(prefs, key, value)
        self.db.flush()
        return prefs

    # ------------------------------------------------------------------
    # Notifications
    # ------------------------------------------------------------------

    def exists_notification(
        self,
        user_id: int,
        policy_id: int,
        notification_type: NotificationType,
        notification_date: date,
    ) -> bool:
        """Return whether a notification already exists for this user, policy, type and date.

        Only the ``id`` column is selected, and ``first()`` limits the query
        to one row, so this is a cheap existence probe.
        """
        return (
            self.db.query(Notification.id)
            .filter(
                Notification.user_id == user_id,
                Notification.policy_id == policy_id,
                Notification.notification_type == notification_type,
                Notification.notification_date == notification_date,
            )
            .first()
            is not None
        )

    def create_notification(self, data: dict) -> Notification:
        """Build a ``Notification`` from ``data`` keyword arguments, add it, flush it, and return it."""
        notification = Notification(**data)
        self.db.add(notification)
        self.db.flush()
        return notification

    def list_for_user(
        self,
        user_id: int,
        *,
        unread_only: bool = False,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[Notification], int]:
        """Return one page of ``user_id``'s notifications and the total match count.

        Args:
            user_id: Owner whose notifications are listed.
            unread_only: If true, only rows with ``is_read`` false are included.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip before the page starts.

        Returns:
            ``(items, total)``. ``total`` counts every row matching the filters
            and ignores ``limit``/``offset``. ``items`` is ordered newest first.
        """
        query = self.db.query(Notification).filter(Notification.user_id == user_id)
        if unread_only:
            query = query.filter(Notification.is_read.is_(False))
        total = query.count()
        items = (
            # ``id`` breaks ties between rows sharing a ``created_at`` value.
            # Without it, OFFSET/LIMIT paging over an ambiguous order can
            # repeat or skip rows between pages.
            query.order_by(Notification.created_at.desc(), Notification.id.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return items, total

    def unread_count(self, user_id: int) -> int:
        """Return how many of ``user_id``'s notifications have ``is_read`` false."""
        return (
            self.db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read.is_(False))
            .count()
        )

    def get_by_id(self, notification_id: int, user_id: int) -> Notification | None:
        """Return the notification with ``notification_id`` if it belongs to ``user_id``, else ``None``.

        Filtering on ``user_id`` as well as ``id`` means a caller cannot fetch
        another user's notification by guessing its id.
        """
        return (
            self.db.query(Notification)
            .filter(Notification.id == notification_id, Notification.user_id == user_id)
            .first()
        )

    def mark_read(self, notification: Notification) -> Notification:
        """Mark ``notification`` as read and return it.

        A notification that is already read keeps its original ``read_at``, so
        repeated calls do not move the read timestamp. If a read row has no
        ``read_at``, one is filled in. The session is flushed in every case.
        """
        if not notification.is_read or notification.read_at is None:
            notification.is_read = True
            notification.read_at = datetime.now(timezone.utc)
        self.db.flush()
        return notification

    def mark_all_read(self, user_id: int) -> int:
        """Mark every unread notification of ``user_id`` as read and return the affected row count.

        This runs as a single bulk UPDATE, and all rows share one ``read_at``
        timestamp. Because ``synchronize_session=False`` is used,
        ``Notification`` instances already loaded in this session are not
        updated in memory and may still show ``is_read`` false until they are
        refreshed or expired.
        """
        now = datetime.now(timezone.utc)
        count = (
            self.db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read.is_(False))
            .update({"is_read": True, "read_at": now}, synchronize_session=False)
        )
        self.db.flush()
        return count

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    def list_active_users(self, agency_id: int) -> list[User]:
        """Return users of ``agency_id`` whose ``is_active`` is true and whose ``deleted_at`` is NULL."""
        return (
            self.db.query(User)
            .filter(User.agency_id == agency_id, User.is_active.is_(True), User.deleted_at.is_(None))
            .all()
        )

    # ------------------------------------------------------------------
    # Device push tokens
    # ------------------------------------------------------------------

    def upsert_device_token(
        self,
        user_id: int,
        platform: PushPlatform,
        token: str,
        device_name: str | None,
    ) -> DevicePushToken:
        """Create or refresh the push-token record for the ``(user_id, token)`` pair.

        If a record already exists for that pair, its ``platform``,
        ``device_name`` and ``last_used_at`` are overwritten. Otherwise a new
        record is inserted. Matching is per user: the same token stored under a
        different ``user_id`` is neither found nor modified here.

        Returns:
            The updated or newly created record, after a flush.
        """
        existing = (
            self.db.query(DevicePushToken)
            .filter(DevicePushToken.user_id == user_id, DevicePushToken.token == token)
            .first()
        )
        now = datetime.now(timezone.utc)
        if existing:
            existing.platform = platform
            existing.device_name = device_name
            existing.last_used_at = now
            self.db.flush()
            return existing
        record = DevicePushToken(
            user_id=user_id,
            platform=platform,
            token=token,
            device_name=device_name,
            last_used_at=now,
        )
        self.db.add(record)
        self.db.flush()
        return record

    def delete_device_token(self, user_id: int, token: str) -> bool:
        """Delete ``user_id``'s record(s) for ``token``; return ``True`` if at least one row was removed."""
        deleted = (
            self.db.query(DevicePushToken)
            .filter(DevicePushToken.user_id == user_id, DevicePushToken.token == token)
            .delete(synchronize_session=False)
        )
        self.db.flush()
        return deleted > 0

    def list_device_tokens(self, user_id: int) -> list[DevicePushToken]:
        """Return all push-token records stored for ``user_id``."""
        return self.db.query(DevicePushToken).filter(DevicePushToken.user_id == user_id).all()
