"""
Staging pipeline — Story 9.1.

Orchestrates: upload → parse → normalize → dedup → stage → review → commit.

Architecture constraints (AR-11):
  - Uses a SEPARATE SQLite engine for staging (instance/import_staging.db)
  - Main DB is touched only to: create ImportBatch, check existing hashes, commit transactions
  - Staging rows are soft-deleted (committed_at set) after successful commit
  - Staging DB is managed by this module — callers never open it directly
"""
from __future__ import annotations

import json
import os
from datetime import datetime, date
from decimal import Decimal
from typing import TYPE_CHECKING

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

from app.models.staged_transaction import StagingBase, StagedTransactionModel
from app.services.pdf_parsers.base import StagedTransaction, ParseError, compute_dedup_hash
from app.services import merchant_normalizer
from app.services.duplicate_detector import flag_duplicates

if TYPE_CHECKING:
    from app.models.import_batch import ImportBatch


# ── staging DB engine ─────────────────────────────────────────────────────────

def _get_staging_engine(staging_db_path: str):
    """Create an engine for the staging SQLite file and make sure its tables exist.

    Every call builds a new engine for ``staging_db_path`` and issues
    ``create_all`` for the ``StagingBase`` metadata against it before
    returning the engine to the caller.
    """
    url = f"sqlite:///{staging_db_path}"
    staging_engine = create_engine(url, echo=False)
    StagingBase.metadata.create_all(staging_engine)
    return staging_engine


def _default_staging_path() -> str:
    """Return the default staging DB path, creating its ``instance`` directory if missing.

    The base directory is found by applying ``os.path.dirname`` three times to
    this module's ``__file__``; the database file is ``import_staging.db``
    inside the ``instance`` directory under that base.
    """
    base = __file__
    for _ in range(3):
        base = os.path.dirname(base)
    target_dir = os.path.join(base, "instance")
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, "import_staging.db")


# ── public API ────────────────────────────────────────────────────────────────

def stage_transactions(
    staged_txns: list[StagedTransaction],
    parse_errors: list[ParseError],
    import_batch_id: int,
    issuer: str,
    staging_db_path: str | None = None,
) -> dict:
    """
    Stage parsed transactions for user review.

    - Applies user-confirmed merchant mappings on top of the parser's normalized name
    - Checks for duplicates against existing transaction hashes in the main DB
    - Writes one staging row per parsed transaction in a single staging-DB
      transaction (flagged duplicates are written too, with status='duplicate')
    - Disposes the per-call staging engine before returning, so its pooled
      SQLite connection is not left open
    - Returns summary dict with counts

    Args:
        staged_txns: Parsed transactions to stage.
        parse_errors: Errors collected by the parser; they are only counted here.
        import_batch_id: Batch ID stored on every staging row.
        issuer: Issuer stored on every staging row (``txn.issuer`` is not read).
        staging_db_path: Staging SQLite file; defaults to the instance staging DB.

    Returns:
        {"staged": int, "duplicates_flagged": int, "parse_errors": int}
    """
    from app.extensions import db
    from app.models.transaction import Transaction
    from app.models.merchant_mapping import MerchantMapping

    path = staging_db_path or _default_staging_path()
    engine = _get_staging_engine(path)

    try:
        # Load user-confirmed merchant mappings for normalization override
        user_mappings = [
            {"raw_pattern": m.raw_pattern, "normalized": m.normalized}
            for m in MerchantMapping.query.filter_by(user_confirmed=True).all()
        ]

        # Build list of existing dedup hashes for duplicate detection
        # (3-day window handled by detector)
        existing_hashes = [
            row[0] for row in
            db.session.query(Transaction.dedup_hash).filter(Transaction.dedup_hash.isnot(None)).all()
        ]

        new_hashes = [t.dedup_hash for t in staged_txns]
        duplicate_flags = flag_duplicates(new_hashes, existing_hashes)

        staged_count = 0
        dup_count = sum(1 for f in duplicate_flags if f)

        with Session(engine) as session:
            for txn, is_dup in zip(staged_txns, duplicate_flags):
                # A user mapping, when one applies, wins over the parser's name
                user_name = merchant_normalizer.apply_user_mappings(txn.merchant_raw, user_mappings)
                normalized = user_name or txn.merchant_normalized

                model = StagedTransactionModel(
                    date=txn.date.isoformat(),
                    merchant_raw=txn.merchant_raw,
                    merchant_normalized=normalized,
                    amount=txn.amount,
                    is_credit=txn.is_credit,
                    issuer=issuer,
                    confidence_score=txn.confidence_score,
                    dedup_hash=txn.dedup_hash,
                    raw_text=txn.raw_text,
                    import_batch_id=import_batch_id,
                    status="duplicate" if is_dup else "pending",
                )
                session.add(model)
                staged_count += 1
            session.commit()
    finally:
        # The engine is created per call; release its connection pool explicitly
        # instead of waiting for garbage collection.
        engine.dispose()

    return {
        "staged": staged_count,
        "duplicates_flagged": dup_count,
        "parse_errors": len(parse_errors),
    }


def get_staged_transactions(
    import_batch_id: int,
    staging_db_path: str | None = None,
    include_duplicates: bool = False,
) -> list[StagedTransactionModel]:
    """
    Load not-yet-committed staged transactions for a batch from the staging DB.

    Rows are ordered by the ``date`` column descending and detached from their
    session so they can be read after this function returns. The per-call
    staging engine is disposed before returning so its pooled SQLite
    connection is not left open.

    Args:
        import_batch_id: Batch whose staging rows are loaded.
        staging_db_path: Staging SQLite file; defaults to the instance staging DB.
        include_duplicates: When False (default), rows whose status is
            'rejected' are filtered out; when True, no status filter is applied.

    Returns:
        Detached StagedTransactionModel rows whose committed_at is unset.
    """
    path = staging_db_path or _default_staging_path()
    engine = _get_staging_engine(path)
    try:
        with Session(engine) as session:
            q = session.query(StagedTransactionModel).filter(
                StagedTransactionModel.import_batch_id == import_batch_id,
                StagedTransactionModel.committed_at.is_(None),
            )
            if not include_duplicates:
                # NOTE(review): despite the flag's name, this excludes 'rejected'
                # rows, not 'duplicate' ones — confirm which is intended before
                # changing; callers may rely on the current behavior.
                q = q.filter(StagedTransactionModel.status != "rejected")
            rows = q.order_by(StagedTransactionModel.date.desc()).all()
            # Detach from session for use outside
            session.expunge_all()
            return rows
    finally:
        # The engine is created per call; release its connection pool explicitly.
        engine.dispose()


def update_staged_transaction(
    staged_id: int,
    merchant_normalized: str | None = None,
    amount: Decimal | None = None,
    category_id: int | None = None,
    account_id: int | None = None,
    status: str | None = None,
    staging_db_path: str | None = None,
) -> None:
    """
    Update a single staged transaction row.

    Only the fields passed as non-None are written; None means "leave
    unchanged", so a field cannot be reset to NULL through this function.
    The per-call staging engine is disposed on every exit path, including
    the not-found error.

    Args:
        staged_id: Primary key of the staging row.
        merchant_normalized: New normalized merchant name, if given.
        amount: New amount, if given.
        category_id: New category ID, if given.
        account_id: New account ID, if given.
        status: New status value, if given (not validated here).
        staging_db_path: Staging SQLite file; defaults to the instance staging DB.

    Raises:
        ValueError: If no staging row with ``staged_id`` exists.
    """
    path = staging_db_path or _default_staging_path()
    engine = _get_staging_engine(path)
    try:
        with Session(engine) as session:
            row = session.get(StagedTransactionModel, staged_id)
            if row is None:
                raise ValueError(f"Staged transaction {staged_id} not found")
            if merchant_normalized is not None:
                row.merchant_normalized = merchant_normalized
            if amount is not None:
                row.amount = amount
            if category_id is not None:
                row.category_id = category_id
            if account_id is not None:
                row.account_id = account_id
            if status is not None:
                row.status = status
            session.commit()
    finally:
        # The engine is created per call; release its connection pool explicitly.
        engine.dispose()


def commit_staged(
    import_batch_id: int,
    default_account_id: int,
    staging_db_path: str | None = None,
) -> int:
    """
    Promote accepted staged transactions to the main Transaction table.

    - Only commits rows with status='pending' or 'accepted' (not rejected/duplicate)
    - Commits the main DB first; staging rows get committed_at/status='committed'
      only after that commit succeeds, so a failed main commit never leaves
      staging rows soft-deleted without a matching Transaction
    - Updates ImportBatch (status, committed_at, row_count) in the same main-DB
      transaction as the new Transaction rows
    - Returns count of committed transactions

    Args:
        import_batch_id: Batch whose staging rows are promoted.
        default_account_id: Account used for rows that have no account_id set.
        staging_db_path: Staging SQLite file; defaults to the instance staging DB.

    Returns:
        Number of staging rows promoted.

    Raises:
        Exception: Any error raised while building or committing the main-DB
            rows is re-raised after the main session is rolled back; the
            staging rows are left untouched in that case.
    """
    from app.extensions import db
    from app.models.transaction import Transaction as TxnModel
    from app.models.import_batch import ImportBatch

    path = staging_db_path or _default_staging_path()
    engine = _get_staging_engine(path)

    committed = 0
    try:
        with Session(engine) as staging_session:
            rows = staging_session.query(StagedTransactionModel).filter(
                StagedTransactionModel.import_batch_id == import_batch_id,
                StagedTransactionModel.committed_at.is_(None),
                StagedTransactionModel.status.in_(["pending", "accepted"]),
            ).all()

            # One timestamp for the staging rows and the batch record.
            now = datetime.utcnow()
            try:
                for row in rows:
                    acct_id = row.account_id or default_account_id
                    new_txn = TxnModel(
                        # NOTE(review): staging stores the date as an ISO string
                        # (see stage_transactions) — confirm the Transaction
                        # date column accepts a string.
                        date=row.date,
                        merchant_normalized=row.merchant_normalized,
                        merchant_raw=row.merchant_raw,
                        amount=row.amount,
                        is_credit=row.is_credit,
                        is_manual=False,
                        issuer=row.issuer,
                        dedup_hash=row.dedup_hash,
                        confidence_score=row.confidence_score,
                        account_id=acct_id,
                        category_id=row.category_id,
                        import_batch_id=import_batch_id,
                    )
                    db.session.add(new_txn)
                    # Pending change only; persisted by staging_session.commit() below.
                    row.committed_at = now
                    row.status = "committed"
                    committed += 1

                batch = ImportBatch.query.get(import_batch_id)
                if batch:
                    batch.status = "committed"
                    batch.committed_at = now
                    batch.row_count = committed

                # Main DB first: if this fails, the staging changes above are
                # discarded when the staging session closes without a commit.
                db.session.commit()
            except Exception:
                db.session.rollback()
                raise

            # The two databases cannot commit atomically. If this commit fails,
            # the Transaction rows already exist while the staging rows remain
            # uncommitted, which is recoverable, unlike the reverse order.
            staging_session.commit()
    finally:
        # The engine is created per call; release its connection pool explicitly.
        engine.dispose()

    return committed


def staged_txn_to_model(txn: StagedTransaction, import_batch_id: int) -> StagedTransactionModel:
    """Build an unsaved staging ORM row from a parsed StagedTransaction.

    The date is stored in ISO format, the issuer is taken from ``txn`` itself,
    and the new row always starts with status 'pending'.
    """
    fields = {
        "date": txn.date.isoformat(),
        "merchant_raw": txn.merchant_raw,
        "merchant_normalized": txn.merchant_normalized,
        "amount": txn.amount,
        "is_credit": txn.is_credit,
        "issuer": txn.issuer,
        "confidence_score": txn.confidence_score,
        "dedup_hash": txn.dedup_hash,
        "raw_text": txn.raw_text,
        "import_batch_id": import_batch_id,
        "status": "pending",
    }
    return StagedTransactionModel(**fields)


def auto_categorize(
    merchant_normalized: str,
    user_mappings: list[dict],
    default_category_map: dict[str, int],
) -> int | None:
    """
    Story 9.5: Resolve category_id for a merchant.

    User mappings are consulted first: the first entry whose 'normalized' name
    equals ``merchant_normalized`` case-insensitively and that carries a truthy
    'category_id' wins. Otherwise the built-in map is consulted with an exact
    (case-sensitive) key lookup.

    user_mappings: list of dicts with 'normalized', 'category_id'; a missing or
        None 'normalized' is treated as an empty name instead of raising
    default_category_map: {merchant_normalized: category_id} for built-in rules
    Returns category_id or None if no match.
    """
    # Lower-case the lookup key once instead of on every iteration
    wanted = merchant_normalized.lower()

    # Check user-confirmed mappings first
    for m in user_mappings:
        # `or ""` also covers a key that is present but holds None
        name = m.get("normalized") or ""
        category_id = m.get("category_id")
        if category_id and name.lower() == wanted:
            return category_id

    # Fall back to built-in default map
    return default_category_map.get(merchant_normalized)
