#!/usr/bin/env python3
"""
AI Cats — Batch Categorization Entry Point

Usage:
    python3 python/categorize.py <batch_id>

PHP spawns this script as a background process:
    python3 python/categorize.py {batch_id} >> logs/categorize_{batch_id}.log 2>&1 &

The script reads a categorization_batches record, processes each product,
writes predictions and evidence to the DB, applies threshold routing, and
updates the batch status.  PHP polls /api/categorization-batches/{id}/status
— it does NOT read stdout/stderr from this process.

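Exit codes:
    0 — completed or completed_with_errors
    1 — error: invalid arguments, batch not found, retry limit reached, or a
        fatal processing error (details in the log; handled fatal errors mark
        the batch 'failed')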
"""

import sys
import os
import json
import logging
from datetime import datetime, timezone

# Ensure the python/ directory is on the module path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import config
from models.database import get_connection
from services import categorizer, confidence_scorer

# Everything is logged to stdout; PHP appends this process's stdout/stderr to
# logs/categorize_{batch_id}.log when it spawns the script.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)

# Highest retry_count at which a non-pending batch may still be re-run.
MAX_RETRIES = 3


def get_threshold(conn) -> float:
    """
    Return the confidence threshold used for prediction routing.

    Reads ``system_settings.value`` for ``key = 'confidence_threshold'`` and
    falls back to ``config.DEFAULT_CONFIDENCE_THRESHOLD`` when the row is
    missing, the value is empty, the value is not numeric, or the query fails.
    A failed query is rolled back so the connection stays usable, and every
    fallback caused by bad or unreadable data is logged as a warning.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT value FROM system_settings WHERE key = 'confidence_threshold' LIMIT 1"
            )
            row = cur.fetchone()
            raw = row["value"] if row else None
    except Exception:
        conn.rollback()
        log.warning(
            "Could not read confidence_threshold from system_settings; using default %s.",
            config.DEFAULT_CONFIDENCE_THRESHOLD,
        )
        return config.DEFAULT_CONFIDENCE_THRESHOLD

    if not raw:
        # Setting not configured: silently use the default.
        return config.DEFAULT_CONFIDENCE_THRESHOLD

    try:
        return float(raw)
    except (TypeError, ValueError):
        log.warning(
            "Invalid confidence_threshold %r in system_settings; using default %s.",
            raw, config.DEFAULT_CONFIDENCE_THRESHOLD,
        )
        return config.DEFAULT_CONFIDENCE_THRESHOLD


def get_claude_settings(conn) -> tuple:
    """
    Load the Claude API key, categories list, and manufacturer research sites from the DB.

    The categories and the active manufacturer research sites are only queried
    when an API key is configured.  Every failed query is rolled back (so the
    connection stays usable) and logged as a warning.

    Returns:
        (api_key: str | None, categories: list[dict], research_sites: list[dict])
        api_key is None when not configured, whitespace-only, or unreadable;
        the lists are [] when there is no key or their query fails.
    """
    api_key = None
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT value FROM system_settings WHERE key = 'claude_api_key' LIMIT 1"
            )
            row = cur.fetchone()
            if row and row["value"]:
                api_key = str(row["value"]).strip() or None
    except Exception:
        conn.rollback()
        # Previously silent: without this the caller only reports "No Claude API key".
        log.warning("Could not read claude_api_key from system_settings; using rules-based fallback.")

    if not api_key:
        return api_key, [], []

    categories = _fetch_rows_as_dicts(
        conn,
        "SELECT akeneo_code AS code, label FROM categories ORDER BY label",
        "Could not load categories for LLM; will use rules-based fallback.",
    )
    research_sites = _fetch_rows_as_dicts(
        conn,
        """SELECT manufacturer_code, manufacturer_name, name_match_keywords,
                  product_url_template, is_active
             FROM manufacturer_research_sites
            WHERE is_active = TRUE
            ORDER BY manufacturer_name""",
        "Could not load manufacturer research sites; web lookup disabled.",
    )
    return api_key, categories, research_sites


def _fetch_rows_as_dicts(conn, sql: str, failure_message: str) -> list:
    """Run a parameterless SELECT and return its rows as dicts; on error roll back, warn, and return []."""
    try:
        with conn.cursor() as cur:
            cur.execute(sql)
            rows = cur.fetchall()
            return [dict(r) for r in rows] if rows else []
    except Exception:
        conn.rollback()
        log.warning(failure_message)
        return []


def mark_batch(conn, batch_id: int, status: str, error_message: str = None, extra: dict = None):
    """
    Set a batch's ``status`` and ``completed_at`` (UTC, now) and commit.

    Args:
        conn: open DB connection (autocommit off).
        batch_id: ``categorization_batches.id`` to update.
        status: new status value.
        error_message: written to ``error_message`` when not None.
        extra: additional ``column -> value`` pairs set in the same UPDATE.
            Values are bound parameters, but column names are interpolated
            into the SQL text, so they are required to be plain identifiers.

    Raises:
        ValueError: if a key of ``extra`` is not a valid identifier.
        Any DB error from the UPDATE/commit, after rolling the transaction back
        so the connection remains usable.
    """
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    sets = ["status = %s", "completed_at = %s"]
    params = [status, now]
    if error_message is not None:
        sets.append("error_message = %s")
        params.append(error_message)
    for col, val in (extra or {}).items():
        # Guard the only non-parameterized part of the statement.
        if not (isinstance(col, str) and col.isidentifier()):
            raise ValueError(f"Invalid column name for categorization_batches: {col!r}")
        sets.append(f"{col} = %s")
        params.append(val)
    params.append(batch_id)

    try:
        with conn.cursor() as cur:
            cur.execute(
                f"UPDATE categorization_batches SET {', '.join(sets)} WHERE id = %s",
                params,
            )
        conn.commit()
    except Exception:
        conn.rollback()
        raise


def write_audit(conn, entity_id: int, action: str, metadata: dict = None):
    """
    Append a system-authored ``audit_log`` row for a categorization batch and commit.

    ``metadata`` is stored as a JSON string; a falsy value (None or an empty
    dict) is stored as NULL.  The timestamp is the current UTC time formatted
    as ``YYYY-MM-DD HH:MM:SS``.
    """
    payload = None
    if metadata:
        payload = json.dumps(metadata)
    created = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    values = (entity_id, action, payload, created)

    with conn.cursor() as cursor:
        cursor.execute(
            """
            INSERT INTO audit_log (entity_type, entity_id, action, actor_type, metadata, created_at)
            VALUES ('categorization_batch', %s, %s, 'system', %s, %s)
            """,
            values,
        )
    conn.commit()


def _mark_processing(conn, batch_id: int, retry_count: int):
    """Set the batch to 'processing' with a fresh UTC started_at and commit."""
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    if retry_count > 0:
        # NOTE(review): retry_count is only incremented when it is already > 0,
        # so a batch stored with retry_count = 0 never starts counting here and
        # the MAX_RETRIES check in _run_batch cannot trip for it.  Confirm
        # whether the PHP side bumps retry_count before re-spawning; if not,
        # this condition probably should key off the batch status instead.
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE categorization_batches SET status = 'processing', started_at = %s, retry_count = retry_count + 1 WHERE id = %s",
                (now, batch_id),
            )
        # write_audit commits, which also commits the UPDATE above.
        write_audit(conn, batch_id, "batch_retry_started", {"retry_count": retry_count + 1})
    else:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE categorization_batches SET status = 'processing', started_at = %s WHERE id = %s",
                (now, batch_id),
            )
    conn.commit()


def _process_products(conn, batch_id: int, product_ids, threshold: float,
                      claude_api_key, categories: list, research_sites: list) -> tuple:
    """
    Categorize every product id, committing after each one.

    Per-product failures are rolled back, logged, audited (best effort) and
    counted as errors; the loop continues.  Progress counters are written after
    every product so PHP's status poll sees live progress.

    Returns:
        (processed, success, errors) counts.
    """
    success = 0
    errors = 0
    processed = 0

    for pid in product_ids:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM products WHERE id = %s", (pid,))
            product = cur.fetchone()

        if not product:
            # %s, not %d: ids decoded from JSON are not guaranteed to be ints.
            log.warning("Batch %d: product %s not found, skipping.", batch_id, pid)
            errors += 1
        else:
            try:
                categorizer.categorize_product(
                    conn, dict(product), batch_id, threshold,
                    claude_api_key=claude_api_key,
                    categories=categories,
                    research_sites=research_sites,
                )
                conn.commit()
                # The return value only reports whether new work was done; a
                # product already predicted in a prior run still counts as a success.
                success += 1
            except Exception as exc:
                conn.rollback()
                log.exception("Batch %d: error categorizing product %s: %s", batch_id, pid, exc)
                errors += 1
                try:
                    write_audit(conn, batch_id, "product_categorization_failed", {
                        "product_id": pid,
                        "error": str(exc),
                    })
                except Exception:
                    # Best effort, but roll back so the failed INSERT does not
                    # leave the transaction aborted for the counter update below.
                    conn.rollback()
                    log.warning("Batch %d: could not write audit entry for product %s.", batch_id, pid)

        processed += 1
        _update_counters(conn, batch_id, processed, success, errors)

    return processed, success, errors


def _run_batch(conn, batch_id: int) -> int:
    """Load, process and finalize one batch; return the process exit code (0 or 1)."""
    # ------------------------------------------------------------------
    # Load batch record
    # ------------------------------------------------------------------
    with conn.cursor() as cur:
        cur.execute("SELECT * FROM categorization_batches WHERE id = %s", (batch_id,))
        batch = cur.fetchone()

    if not batch:
        log.error("Batch %d not found.", batch_id)
        return 1

    batch = dict(batch)
    # A NULL retry_count comes back as None; int(None) would raise.
    retry_count = int(batch.get("retry_count") or 0)

    # Enforce retry cap (pending batches are always allowed to run).
    if retry_count >= MAX_RETRIES and batch["status"] not in ("pending",):
        log.error(
            "Batch %d has reached the maximum retry limit (%d). Manual intervention required.",
            batch_id, MAX_RETRIES,
        )
        return 1

    _mark_processing(conn, batch_id, retry_count)

    # ------------------------------------------------------------------
    # Product IDs live in a JSON(B) column; decode if the driver returned text.
    # ------------------------------------------------------------------
    product_ids = batch.get("product_ids")
    if isinstance(product_ids, str):
        product_ids = json.loads(product_ids)
    if not product_ids:
        mark_batch(conn, batch_id, "failed", "No product IDs in batch.")
        log.error("Batch %d: empty product_ids.", batch_id)
        return 1

    threshold = get_threshold(conn)
    claude_api_key, categories, research_sites = get_claude_settings(conn)
    if claude_api_key:
        log.info(
            "Batch %d: Claude API key configured, %d categories, %d manufacturer research sites.",
            batch_id, len(categories), len(research_sites),
        )
    else:
        log.info("Batch %d: No Claude API key; using rules-based categorization.", batch_id)

    log.info("Batch %d: processing %d products (threshold=%.1f)", batch_id, len(product_ids), threshold)

    with conn.cursor() as cur:
        cur.execute(
            "UPDATE categorization_batches SET total_products = %s WHERE id = %s",
            (len(product_ids), batch_id),
        )
    conn.commit()

    processed, success, errors = _process_products(
        conn, batch_id, product_ids, threshold,
        claude_api_key, categories, research_sites,
    )

    # ------------------------------------------------------------------
    # Final status
    # ------------------------------------------------------------------
    final_status = "completed" if errors == 0 else "completed_with_errors"
    mark_batch(conn, batch_id, final_status, extra={
        "success_products": success,
        "error_products": errors,
        "processed_products": processed,
    })
    log.info(
        "Batch %d finished: status=%s, success=%d, errors=%d",
        batch_id, final_status, success, errors,
    )
    return 0


def main():
    """
    Categorize every product in the batch whose id is given as ``sys.argv[1]``.

    Exits with status 0 when the batch finishes (``completed`` or
    ``completed_with_errors``) and 1 on invalid arguments, a missing batch,
    an exhausted retry cap, an empty product list, or an unexpected fatal
    error.  On an unexpected error the batch is marked ``failed`` (best
    effort) so the PHP status poll does not see it stuck in ``processing``.
    The DB connection is always closed before exiting.
    """
    if len(sys.argv) < 2:
        log.error("Usage: python3 categorize.py <batch_id>")
        sys.exit(1)

    try:
        batch_id = int(sys.argv[1])
    except ValueError:
        log.error("batch_id must be an integer, got: %s", sys.argv[1])
        sys.exit(1)

    conn = get_connection()
    conn.autocommit = False
    try:
        exit_code = _run_batch(conn, batch_id)
    except Exception as exc:
        log.exception("Batch %d: fatal error; marking batch as failed.", batch_id)
        exit_code = 1
        try:
            conn.rollback()
            mark_batch(conn, batch_id, "failed", f"Fatal error: {exc}")
        except Exception:
            log.exception("Batch %d: could not mark batch as failed.", batch_id)
    finally:
        conn.close()
    sys.exit(exit_code)


def _update_counters(conn, batch_id: int, processed: int, success: int, errors: int):
    try:
        with conn.cursor() as cur:
            cur.execute(
                """UPDATE categorization_batches
                      SET processed_products = %s,
                          success_products   = %s,
                          error_products     = %s
                    WHERE id = %s""",
                (processed, success, errors, batch_id),
            )
        conn.commit()
    except Exception:
        conn.rollback()


if __name__ == "__main__":
    main()
