"""
Categorizer — orchestrates evidence building, scoring, and DB writes
for a single product within a categorization batch.
"""

import json
import logging
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

import psycopg2

from services import evidence_builder, confidence_scorer, llm_categorizer

# Module logger, named after this module so handlers and levels can target it.
log = logging.getLogger(name=__name__)


_LLM_SIGNAL_WEIGHT = 0.85  # fixed weight assigned to the LLM-derived evidence signal
_DEFAULT_EVIDENCE_WEIGHT = 0.001  # stored weight for a signal with no (or a None) weight


def categorize_product(
    conn,
    product: Dict[str, Any],
    batch_id: int,
    threshold: float,
    claude_api_key: Optional[str] = None,
    categories: Optional[List[Dict[str, Any]]] = None,
    research_sites: Optional[List[Dict[str, Any]]] = None,
) -> bool:
    """
    Categorize a single product and write prediction + evidence to DB.

    Inserts one ``predictions`` row, one ``evidence`` row per signal, and sets
    ``products.status = 'predicted'`` for the product.  It does NOT update the
    batch's processed_products counter; that is left to the caller.  Caller is
    responsible for committing.

    Args:
        conn:           Open psycopg2 connection (autocommit=False).  Rows are
                        read by column name (``row["id"]``, ``dict(row)``), so
                        its cursors must return mapping rows — presumably a
                        RealDictCursor-style cursor_factory; confirm in caller.
        product:        Product DB row as a dict; must contain "id".
        batch_id:       ID of the parent categorization_batches row.
        threshold:      Confidence threshold from system_settings.
        claude_api_key: Optional Anthropic API key; enables LLM categorization
                        (only when ``categories`` is also non-empty).
        categories:     Optional list of {code, label} dicts for LLM selection.
        research_sites: Optional list of manufacturer research site rows for web lookup.

    Returns:
        True when a prediction was written; False if the product was skipped
        because it already has a prediction for this batch.

    Raises:
        Exception: DB errors, and any error raised by the evidence builder,
            confidence scorer or LLM categorizer, propagate unchanged.
    """
    product_id = product["id"]

    # Retry-safe skip: this product already has a prediction in this batch.
    if _has_prediction(conn, product_id, batch_id):
        return False

    attributes = _fetch_attributes(conn, product_id)

    # Rules-based evidence always runs; it supports (or stands in for) the LLM.
    # Each helper receives its own dict copies so neither can mutate the other's
    # input; list(...) gives us a list we own (safe to insert into) and
    # tolerates a None return.
    signals = list(
        evidence_builder.build_evidence(dict(product), [dict(a) for a in attributes])
        or []
    )

    if claude_api_key and categories:
        llm_signal = _llm_signal(
            product_id, product, attributes, claude_api_key, categories, research_sites
        )
        if llm_signal is not None:
            # Prepend the LLM signal ahead of the rules-based ones.
            signals.insert(0, llm_signal)

    if not signals:
        signals = [evidence_builder.fallback_signal()]

    # Pick category and compute score from all signals combined.
    category = evidence_builder.pick_category(signals)
    score = confidence_scorer.compute_score(signals)
    level = confidence_scorer.derive_level(score, threshold)
    # One timestamp shared by every row written for this product.
    # NOTE(review): this is a naive "YYYY-MM-DD HH:MM:SS" string in UTC; if the
    # target columns are timestamptz, the session TimeZone decides how it is
    # interpreted — confirm the column types.
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    _write_prediction(conn, product_id, batch_id, category, score, level, signals, now)
    return True


def _has_prediction(conn, product_id: Any, batch_id: int) -> bool:
    """Return True if a prediction row exists for this product in this batch."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT id FROM predictions WHERE product_id = %s AND batch_id = %s LIMIT 1",
            (product_id, batch_id),
        )
        return cur.fetchone() is not None


def _fetch_attributes(conn, product_id: Any) -> List[Any]:
    """Return the product's attribute_name/attribute_value rows (possibly empty)."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT attribute_name, attribute_value FROM product_attributes WHERE product_id = %s",
            (product_id,),
        )
        return cur.fetchall() or []


def _llm_signal(
    product_id: Any,
    product: Dict[str, Any],
    attributes: List[Any],
    api_key: str,
    categories: List[Dict[str, Any]],
    research_sites: Optional[List[Dict[str, Any]]],
) -> Optional[Dict[str, Any]]:
    """Ask the LLM categorizer for a category; return an evidence signal, or None if it gave no result."""
    llm_result = llm_categorizer.categorize(
        api_key=api_key,
        product=dict(product),
        attributes=[dict(a) for a in attributes],
        categories=categories,
        research_sites=research_sites,
    )
    if not llm_result:
        return None

    log.info(
        "LLM categorized product %s → %s (confidence=%.2f)",
        product_id,
        llm_result["code"],
        llm_result["confidence"],
    )
    return {
        "source_type":              "llm_categorization",
        "source_label":             "Claude AI",
        "evidence_value":           llm_result["reasoning"],
        "weight":                   _LLM_SIGNAL_WEIGHT,
        "signal_strength":          llm_result["confidence"],
        "suggested_category_code":  llm_result["code"],
        "suggested_category_label": llm_result["label"],
    }


def _evidence_weight(sig: Dict[str, Any]) -> float:
    """Return the signal's weight rounded to 3 places; a missing or None weight becomes the default."""
    weight = sig.get("weight")
    if weight is None:
        weight = _DEFAULT_EVIDENCE_WEIGHT
    return round(weight, 3)


def _write_prediction(
    conn,
    product_id: Any,
    batch_id: int,
    category: Dict[str, Any],
    score: Any,
    level: Any,
    signals: List[Dict[str, Any]],
    now: str,
) -> None:
    """Insert the prediction and its evidence rows, then mark the product 'predicted' (no commit)."""
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO predictions
                (product_id, batch_id, suggested_category_code, suggested_category_label,
                 confidence_score, confidence_level, status, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, 'predicted', %s, %s)
            RETURNING id
            """,
            (
                product_id,
                batch_id,
                category["code"],
                category["label"],
                score,
                level,
                now,
                now,
            ),
        )
        prediction_id = cur.fetchone()["id"]

        cur.executemany(
            """
            INSERT INTO evidence
                (prediction_id, source_type, source_label, evidence_value, weight, created_at)
            VALUES (%s, %s, %s, %s, %s, %s)
            """,
            [
                (
                    prediction_id,
                    sig["source_type"],
                    sig["source_label"],
                    sig["evidence_value"],
                    _evidence_weight(sig),
                    now,
                )
                for sig in signals
            ],
        )

        cur.execute(
            "UPDATE products SET status = 'predicted', updated_at = %s WHERE id = %s",
            (now, product_id),
        )
