"""
Bad spending pattern flags — Story 10.5.

FR-3.4: Surface automatically when:
  1. A category exceeds its budget by >20% for 2 consecutive months
  2. A subscription has no matching transaction in 60+ days
  3. Total dining/entertainment > 15% of monthly income
"""
from __future__ import annotations

from dataclasses import dataclass
from datetime import date, timedelta
from decimal import Decimal

from sqlalchemy import func

from app.extensions import db
from app.models.transaction import Transaction
from app.models.budget import Budget
from app.models.category import Category


_DINING_ENTERTAINMENT = {"Dining", "Entertainment"}
_BUDGET_OVERAGE_THRESHOLD = Decimal("1.20")   # >20% over budget
_INCOME_DINING_THRESHOLD = Decimal("0.15")    # >15% of income


@dataclass
class SpendingFlag:
    """
    A single warning surfaced to the user by the spending-flag checks.

    Only ``flag_type`` and ``message`` are required; the remaining fields are
    filled in when the check has a category and/or numeric comparison to
    report.
    """

    # Which check produced the flag. This module emits "over_budget",
    # "inactive_subscription" or "dining_entertainment_ratio".
    flag_type: str
    # Human-readable explanation of the flag.
    message: str
    # Category (or combined category label) the flag refers to, if any.
    category_name: str | None = None
    # Observed spending that triggered the flag, if applicable.
    amount: Decimal | None = None
    # Limit the observed spending was compared against, if applicable.
    threshold: Decimal | None = None


def get_spending_flags(
    monthly_income: Decimal | None,
    subscription_merchants: list[str] | None = None,
) -> list[SpendingFlag]:
    """
    Compute all three flag types for the current month context.

    monthly_income: from Settings (passed as plain value). A ``Decimal`` is
        expected; ints, floats and numeric strings are also accepted and are
        converted with ``Decimal(str(...))``. A missing, zero, negative or
        non-finite income skips the dining/entertainment ratio check.
    subscription_merchants: list of merchant names detected as subscriptions.
        When empty or None, the inactive-subscription check is skipped.

    Returns the flags in check order: over-budget, inactive subscriptions,
    dining/entertainment ratio.
    """
    today = date.today()
    flags: list[SpendingFlag] = []

    # ── Flag 1: Over-budget >20% for 2 consecutive months ─────────────────
    flags.extend(_check_consecutive_over_budget(today))

    # ── Flag 2: Inactive subscription (>60 days without charge) ───────────
    if subscription_merchants:
        flags.extend(_check_inactive_subscriptions(subscription_merchants, today))

    # ── Flag 3: Dining/Entertainment > 15% of monthly income ──────────────
    if monthly_income:
        # The ratio check multiplies by a Decimal constant, which raises
        # TypeError for a float; normalise plain numeric values first.
        # str() avoids importing binary float artifacts into the Decimal.
        income = (
            monthly_income
            if isinstance(monthly_income, Decimal)
            else Decimal(str(monthly_income))
        )
        # is_finite() rejects NaN/Infinity, which would otherwise raise on
        # comparison or produce a meaningless threshold.
        if income.is_finite() and income > 0:
            flags.extend(_check_dining_entertainment_ratio(income, today))

    return flags


def _over_budget_categories(year: int, month: int) -> dict:
    """
    Return ``{category_id: (spent, threshold)}`` for one calendar month.

    A category is included when it has a positive budget for (year, month)
    and its debit spending (``is_credit`` False) in that month exceeds
    ``budget.amount * _BUDGET_OVERAGE_THRESHOLD``. Budgets <= 0 are ignored.
    """
    budgets = Budget.query.filter_by(month=month, year=year).all()
    if not budgets:
        return {}

    # NOTE(review): the LIKE prefix assumes Transaction.date is stored as
    # "YYYY-MM-DD..." text — confirm against the model.
    month_prefix = f"{year}-{month:02d}-%"
    spent_rows = (
        db.session.query(
            Transaction.category_id,
            func.sum(Transaction.amount).label("total"),
        )
        .filter(Transaction.date.like(month_prefix), Transaction.is_credit == False)  # noqa: E712
        .group_by(Transaction.category_id)
        .all()
    )
    # Decimal(str(...)) so a float sum from the driver does not leak into
    # Decimal arithmetic.
    spent_map = {r.category_id: Decimal(str(r.total or 0)) for r in spent_rows}

    over: dict = {}
    for budget in budgets:
        if budget.amount <= 0:
            continue
        threshold = budget.amount * _BUDGET_OVERAGE_THRESHOLD
        spent = spent_map.get(budget.category_id, Decimal("0"))
        if spent > threshold:
            # setdefault: keep the first matching budget if a category
            # somehow has more than one for the same month.
            over.setdefault(budget.category_id, (spent, threshold))
    return over


def _check_consecutive_over_budget(today: date) -> list[SpendingFlag]:
    """
    Flag categories over budget by more than 20% in two consecutive months.

    The two months are the previous calendar month and the current one
    (``today``'s month, which is still in progress). A category must have a
    positive budget in both months and exceed it by the overage threshold in
    both to be flagged.

    Each flag's ``amount`` and ``threshold`` describe the previous month.
    A category that no longer exists is reported as "Unknown".
    """
    cur_year, cur_month = today.year, today.month
    if cur_month == 1:
        prev_year, prev_month = cur_year - 1, 12
    else:
        prev_year, prev_month = cur_year, cur_month - 1

    prev_over = _over_budget_categories(prev_year, prev_month)
    if not prev_over:
        return []
    cur_over = _over_budget_categories(cur_year, cur_month)

    # Preserve the previous month's budget order for deterministic output.
    flagged_ids = [cid for cid in prev_over if cid in cur_over]
    if not flagged_ids:
        return []

    # One query for all names instead of one lookup per flagged category.
    names = {
        c.id: c.name
        for c in Category.query.filter(Category.id.in_(flagged_ids)).all()
    }
    # Derive the percentage from the constant so the message cannot drift.
    overage_pct = f"{_BUDGET_OVERAGE_THRESHOLD - 1:.0%}"

    flags = []
    for cid in flagged_ids:
        spent, threshold = prev_over[cid]
        cat_name = names.get(cid, "Unknown")
        flags.append(SpendingFlag(
            flag_type="over_budget",
            message=f"{cat_name} has exceeded its budget by >{overage_pct} for 2 consecutive months.",
            category_name=cat_name,
            amount=spent,
            threshold=threshold,
        ))
    return flags


def _check_inactive_subscriptions(
    subscription_merchants: list[str],
    today: date,
) -> list[SpendingFlag]:
    """
    Flag subscription merchants with no debit charge in the last 60 days.

    A merchant (matched on ``Transaction.merchant_normalized``) is flagged when
    its most recent debit transaction (``is_credit`` False) is dated strictly
    before ``today - 60 days``, or when it has no debit transaction at all.
    A merchant listed more than once produces one flag per occurrence.

    All merchants are resolved with a single grouped query rather than one
    query per merchant.
    """
    if not subscription_merchants:
        return []

    # NOTE(review): Transaction.date is compared with an isoformat() string,
    # so this relies on dates being stored as ISO-8601 text — confirm.
    cutoff = (today - timedelta(days=60)).isoformat()

    last_charge_rows = (
        db.session.query(
            Transaction.merchant_normalized,
            func.max(Transaction.date).label("last_date"),
        )
        .filter(
            Transaction.merchant_normalized.in_(subscription_merchants),
            Transaction.is_credit == False,  # noqa: E712
        )
        .group_by(Transaction.merchant_normalized)
        .all()
    )
    last_charge = {r.merchant_normalized: r.last_date for r in last_charge_rows}

    flags = []
    for merchant in subscription_merchants:
        last_date = last_charge.get(merchant)
        if last_date is None or last_date < cutoff:
            flags.append(SpendingFlag(
                flag_type="inactive_subscription",
                message=f"Subscription '{merchant}' has no charge in 60+ days.",
                category_name=None,
                amount=None,
            ))
    return flags


def _check_dining_entertainment_ratio(
    monthly_income: Decimal,
    today: date,
) -> list[SpendingFlag]:
    """
    Flag when this month's dining + entertainment spend exceeds a share of income.

    Sums debit transactions (``is_credit`` False) dated in ``today``'s month
    whose category is an active category named in ``_DINING_ENTERTAINMENT``.
    The result is compared with ``monthly_income * _INCOME_DINING_THRESHOLD``.

    Returns a one-element list with a ``dining_entertainment_ratio`` flag when
    the total is strictly above the threshold. Otherwise, or when no matching
    active category exists, returns an empty list.
    """
    active_cats = Category.query.filter(
        Category.name.in_(_DINING_ENTERTAINMENT),
        Category.is_active == True,  # noqa: E712
    ).all()
    if not active_cats:
        return []
    cat_ids = [c.id for c in active_cats]

    # NOTE(review): assumes Transaction.date is "YYYY-MM-DD..." text.
    month_prefix = f"{today.year}-{today.month:02d}-%"
    total_spent = (
        db.session.query(func.sum(Transaction.amount))
        .filter(
            Transaction.date.like(month_prefix),
            Transaction.category_id.in_(cat_ids),
            Transaction.is_credit == False,  # noqa: E712
        )
        .scalar()
    )
    total = Decimal(str(total_spent or 0))
    threshold = monthly_income * _INCOME_DINING_THRESHOLD

    if total <= threshold:
        return []

    # Derive the percentage from the constant so the message cannot drift.
    income_pct = f"{_INCOME_DINING_THRESHOLD:.0%}"
    return [SpendingFlag(
        flag_type="dining_entertainment_ratio",
        message=(
            f"Dining + Entertainment this month (${total:.2f}) exceeds "
            f"{income_pct} of monthly income (${threshold:.2f})."
        ),
        category_name="Dining + Entertainment",
        amount=total,
        threshold=threshold,
    )]
