"""
Budget analytics and merchant summary — Stories 10.2, 10.6, 10.7.

History-based budget recommendations using 3-month rolling averages.
Merchant spend summary top-20.
"""
from __future__ import annotations

from dataclasses import dataclass
from decimal import Decimal, ROUND_HALF_UP

from sqlalchemy import func, or_

from app.extensions import db
from app.models.transaction import Transaction
from app.models.category import Category
from app.services.insights.budget_recommender import BudgetRecommendation


@dataclass
class MerchantSpend:
    """Aggregated spend for one merchant, as returned by merchant_spend_summary().

    Attributes:
        merchant: Merchant display name.
        total: Summed transaction amount for the merchant.
        transaction_count: Number of transactions included in ``total``.
    """

    # Field order is part of the positional constructor signature; keep it.
    merchant: str
    total: Decimal
    transaction_count: int


def merchant_spend_summary(
    limit: int = 20,
    date_from: str | None = None,
    date_to: str | None = None,
) -> list[MerchantSpend]:
    """Return the top merchants by total spend (FR-3.6).

    Only non-credit transactions are considered. Rows are grouped by
    ``Transaction.merchant_normalized`` and ordered by summed amount,
    largest first.

    Args:
        limit: Maximum number of merchants to return.
        date_from: Optional inclusive lower bound, compared directly against
            ``Transaction.date``. Ignored when falsy.
        date_to: Optional inclusive upper bound, compared directly against
            ``Transaction.date``. Ignored when falsy.

    Returns:
        One ``MerchantSpend`` per merchant group. A missing/empty merchant
        name is reported as ``"Unknown"`` and a NULL sum as ``Decimal("0")``.
    """
    total_spend = func.sum(Transaction.amount)

    q = (
        db.session.query(
            Transaction.merchant_normalized,
            total_spend.label("total"),
            # Deliberately NOT labelled "count": on SQLAlchemy 1.4+ a result
            # Row is a Sequence, so ``row.count`` resolves to the inherited
            # ``count()`` method instead of the column value.
            func.count(Transaction.id).label("txn_count"),
        )
        .filter(Transaction.is_credit == False)  # noqa: E712
    )
    if date_from:
        q = q.filter(Transaction.date >= date_from)
    if date_to:
        q = q.filter(Transaction.date <= date_to)

    rows = (
        q.group_by(Transaction.merchant_normalized)
        .order_by(total_spend.desc())
        .limit(limit)
        .all()
    )

    # Unpack positionally so the result does not depend on attribute-name
    # resolution of the row type (works for legacy tuples and Row alike).
    return [
        MerchantSpend(
            merchant=merchant or "Unknown",
            # str() round-trip avoids binary-float artefacts if the driver
            # hands back a float for the SUM.
            total=Decimal(str(total or 0)),
            transaction_count=txn_count,
        )
        for merchant, total, txn_count in rows
    ]


def from_history(monthly_income: Decimal) -> list[BudgetRecommendation]:
    """
    Story 10.7: Budget recommendations from 3-month rolling average spend.

    Completes FR-1.12. Same interface as fifty_thirty_twenty().
    Rounds suggested amounts to nearest $5.

    The "3 months" are the three most recent distinct year/month values that
    have at least one dated, non-credit transaction. If fewer than three such
    months exist, the average is taken over however many were found.

    Args:
        monthly_income: Not read by this function; accepted so the signature
            matches fifty_thirty_twenty().

    Returns:
        One ``BudgetRecommendation`` per category (``bucket="history"``),
        sorted by category name. Empty list when there is no spend history.
        Transactions without a matching Category row are excluded by the
        inner join.
    """
    # Year / month are sliced out of the date text (chars 1-4 and 6-7).
    # NOTE(review): this assumes Transaction.date is stored as 'YYYY-MM-DD'
    # text — confirm against the model.
    year_expr = func.cast(func.substr(Transaction.date, 1, 4), db.Integer)
    month_expr = func.cast(func.substr(Transaction.date, 6, 2), db.Integer)

    # Find 3 most recent distinct months
    month_rows = (
        db.session.query(
            year_expr.label("year"),
            month_expr.label("month"),
        )
        .filter(
            Transaction.is_credit == False,  # noqa: E712
            # Undated rows would form a NULL year/month group; formatting that
            # group below raises TypeError and it would also inflate the
            # divisor, so exclude them up front.
            Transaction.date.isnot(None),
        )
        .group_by("year", "month")
        .order_by(year_expr.desc(), month_expr.desc())
        .limit(3)
        .all()
    )

    if not month_rows:
        return []

    month_count = len(month_rows)
    month_prefixes = [f"{r.year}-{r.month:02d}-%" for r in month_rows]

    # Sum per category across those months using or_() for safe multi-month filter
    date_filter = or_(*[Transaction.date.like(p) for p in month_prefixes])

    category_rows = (
        db.session.query(
            Category.name,
            func.sum(Transaction.amount).label("total"),
        )
        .join(Category, Transaction.category_id == Category.id)
        .filter(
            Transaction.is_credit == False,  # noqa: E712
            date_filter,
        )
        .group_by(Category.name)
        .all()
    )

    def _round5(d: Decimal) -> Decimal:
        # Scale down by 5, round half-up to a whole number, scale back up.
        return (d / 5).quantize(Decimal("1"), rounding=ROUND_HALF_UP) * 5

    recommendations = []
    for row in category_rows:
        # Divide by the number of months actually found, not a fixed 3.
        avg = Decimal(str(row.total or 0)) / month_count
        recommendations.append(BudgetRecommendation(
            category_name=row.name,
            suggested_amount=_round5(avg),
            bucket="history",
        ))

    return sorted(recommendations, key=lambda r: r.category_name)
