"""
Trend analytics — Stories 10.1, 10.2, 10.3.

Monthly spend by category (stacked bar + trend lines) and MoM/YoY comparisons.
SQL aggregation only — no Python loops over raw rows.
"""
from __future__ import annotations

from dataclasses import dataclass
from decimal import Decimal

from sqlalchemy import func

from app.extensions import db
from app.models.transaction import Transaction
from app.models.category import Category


@dataclass()
class MonthlySpend:
    """Aggregated spend for one category in one calendar month.

    Instances are built by ``spending_by_category_monthly`` from SQL
    aggregate rows.
    """

    year: int  # year number parsed from the transaction date string
    month: int  # month number parsed from the transaction date string
    category_name: str  # category display name ("Uncategorized" fallback)
    total: Decimal  # summed Transaction.amount for this year/month/category


@dataclass()
class MonthSummary:
    """Aggregated spend across all categories for one calendar month.

    Instances are built by ``monthly_totals`` from SQL aggregate rows.
    """

    year: int  # year number parsed from the transaction date string
    month: int  # month number parsed from the transaction date string
    total: Decimal  # summed Transaction.amount for this year/month


def spending_by_category_monthly(months_back: int = 12) -> list[MonthlySpend]:
    """Per-category totals for each of the last N distinct months.

    Sums ``Transaction.amount`` for non-credit transactions, grouped in SQL
    by year, month (both parsed from the leading characters of
    ``Transaction.date``) and category name. Python only trims the
    already-aggregated rows down to the most recent months.

    Args:
        months_back: How many of the most recent months that have any
            non-credit spend to keep. ``0`` or a negative value yields an
            empty list.

    Returns:
        ``MonthlySpend`` rows ordered by (year, month). Transactions with no
        matching ``Category`` are reported under ``"Uncategorized"``.
    """
    if months_back <= 0:
        # A negative slice bound below would drop the *oldest* months rather
        # than select "the last N", so treat non-positive input as "none".
        return []

    rows = (
        db.session.query(
            func.cast(func.substr(Transaction.date, 1, 4), db.Integer).label("year"),
            func.cast(func.substr(Transaction.date, 6, 2), db.Integer).label("month"),
            # Resolve the fallback name in SQL and group on it, so spend with
            # no category and a category literally named "Uncategorized" end
            # up in one row per month instead of two rows with the same name.
            func.coalesce(Category.name, "Uncategorized").label("category_name"),
            func.sum(Transaction.amount).label("total"),
        )
        .join(Category, Transaction.category_id == Category.id, isouter=True)
        .filter(Transaction.is_credit == False)  # noqa: E712
        .group_by("year", "month", "category_name")
        .order_by("year", "month")
        .all()
    )

    # Work over the aggregated (not raw) rows to find the newest N months.
    recent_months = set(
        sorted({(r.year, r.month) for r in rows}, reverse=True)[:months_back]
    )
    return [
        MonthlySpend(
            year=r.year,
            month=r.month,
            # `or` still covers an empty-string category name.
            category_name=r.category_name or "Uncategorized",
            total=Decimal(str(r.total or 0)),
        )
        for r in rows
        if (r.year, r.month) in recent_months
    ]


def monthly_totals(months_back: int = 24) -> list[MonthSummary]:
    """Total monthly spend for the last N months.

    Sums ``Transaction.amount`` for non-credit transactions, grouped in SQL
    by year and month (both parsed from the leading characters of
    ``Transaction.date``).

    Args:
        months_back: How many of the most recent months that have any
            non-credit spend to return. ``0`` or a negative value yields an
            empty list.

    Returns:
        ``MonthSummary`` rows ordered by (year, month), oldest first.
    """
    if months_back <= 0:
        # rows[-0:] would return *every* month, and a negative bound would
        # select from the wrong end, so handle non-positive input explicitly.
        return []

    rows = (
        db.session.query(
            func.cast(func.substr(Transaction.date, 1, 4), db.Integer).label("year"),
            func.cast(func.substr(Transaction.date, 6, 2), db.Integer).label("month"),
            func.sum(Transaction.amount).label("total"),
        )
        .filter(Transaction.is_credit == False)  # noqa: E712
        .group_by("year", "month")
        .order_by("year", "month")
        .all()
    )
    # One aggregated row per month, sorted ascending, so the most recent
    # months are the tail of the list. Previously months_back was ignored
    # and the full history was returned.
    recent = rows[-months_back:]
    return [
        MonthSummary(year=r.year, month=r.month, total=Decimal(str(r.total or 0)))
        for r in recent
    ]


def mom_comparison(year: int, month: int) -> dict:
    """Compare non-credit spend in a month against the month before it.

    January compares against December of the previous year. Months are
    matched with a SQL ``LIKE`` on the ``"YYYY-MM-"`` prefix of
    ``Transaction.date``.

    Returns:
        A dict with ``current_year``, ``current_month``, ``current_total``,
        ``prior_year``, ``prior_month``, ``prior_total``, ``delta``
        (current - prior) and ``pct_change`` (percent change as a float, or
        ``0.0`` when the prior total is not positive).
    """
    def spend_matching(pattern):
        # Sum of non-credit amounts whose date matches the LIKE pattern.
        summed = (
            db.session.query(func.sum(Transaction.amount))
            .filter(Transaction.date.like(pattern), Transaction.is_credit == False)  # noqa: E712
            .scalar()
        )
        return Decimal(str(summed or 0))

    if month == 1:
        prev_year, prev_month = year - 1, 12
    else:
        prev_year, prev_month = year, month - 1

    this_total = spend_matching(f"{year}-{month:02d}-%")
    last_total = spend_matching(f"{prev_year}-{prev_month:02d}-%")
    change = this_total - last_total

    pct = 0.0
    if last_total > 0:
        pct = float(change / last_total * 100)

    return {
        "current_year": year,
        "current_month": month,
        "current_total": this_total,
        "prior_year": prev_year,
        "prior_month": prev_month,
        "prior_total": last_total,
        "delta": change,
        "pct_change": pct,
    }


def yoy_comparison(year: int, month: int) -> dict:
    """Compare non-credit spend in a month against the same month a year earlier.

    Months are matched with a SQL ``LIKE`` on the ``"YYYY-MM-"`` prefix of
    ``Transaction.date``.

    Returns:
        A dict with ``current_year``, ``prior_year``, ``month``,
        ``current_total``, ``prior_total``, ``delta`` (current - prior) and
        ``pct_change`` (percent change as a float, or ``0.0`` when the prior
        total is not positive).
    """
    def spend_matching(pattern):
        # Sum of non-credit amounts whose date matches the LIKE pattern.
        summed = (
            db.session.query(func.sum(Transaction.amount))
            .filter(Transaction.date.like(pattern), Transaction.is_credit == False)  # noqa: E712
            .scalar()
        )
        return Decimal(str(summed or 0))

    last_year = year - 1
    this_total = spend_matching(f"{year}-{month:02d}-%")
    last_total = spend_matching(f"{last_year}-{month:02d}-%")
    change = this_total - last_total

    pct = 0.0
    if last_total > 0:
        pct = float(change / last_total * 100)

    return {
        "current_year": year,
        "prior_year": last_year,
        "month": month,
        "current_total": this_total,
        "prior_total": last_total,
        "delta": change,
        "pct_change": pct,
    }
