# Created by aaronkueh on 10/14/2025
# aom/tools/rag/retrieve.py

from __future__ import annotations

import re
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, Iterable
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny
from langchain_core.tools import tool
from sentence_transformers import CrossEncoder

from aom.utils.utilities import qdrant_cfg
from aom.tools.rag.rag_helper import load_local_st, list_doc_ids, load_rag_profile


# -------------------- Read profile + retrieve block --------------------
# Active RAG profile loaded from rag_config.yaml via the project helper.
_PROF = load_rag_profile("rag_config.yaml")
# Embedding/search settings; a missing "retrieve" section falls back to {}.
_RETRIEVE = _PROF.get("retrieve", {})
# Qdrant collection searched by every query in this module.
_COLLECTION = _PROF.get("collection", "mas_manuals")

# NOTE(review): none of the keys below has a default. If max_seq_len, top_k or
# min_score is absent, int()/float() raise TypeError at import time; a missing
# embed_device becomes the literal string "None"; a missing embed_model_dir is
# caught later by the explicit RuntimeError check. Confirm rag_config.yaml
# always provides all five.
_EMBED_MODEL_DIR = _RETRIEVE.get("embed_model_dir")
_EMBED_DEVICE = str(_RETRIEVE.get("embed_device"))
_MAX_SEQ_LEN = int(_RETRIEVE.get("max_seq_len"))
_TOP_K_DEFAULT = int(_RETRIEVE.get("top_k"))
_MIN_SCORE_DEFAULT = float(_RETRIEVE.get("min_score"))

# Qdrant connection settings. Only the URL and API key are used in this module;
# _QDRANT_DISTANCE is unpacked but not referenced elsewhere here.
_QDRANT_URL, _QDRANT_API_KEY, _QDRANT_DISTANCE = qdrant_cfg(map_distance=False)

# -------------------- Re-ranking config (optional) --------------------
# Expected rag_config.yaml structure under the same profile:
# rerank:
#   enabled: true
#   model_dir: "<cross-encoder-model-path-or-id>"
#   device: "cpu"            # or "cuda"
#   top_k: 5                 # default final top_k for rerank helper
#   initial_k: 30            # how many to fetch from Qdrant before rerank
# _RERANK = _PROF.get("rerank", {}) or {}
# _RERANK_ENABLED: bool = bool(_RERANK.get("enabled", False))
# _RERANK_MODEL_DIR: Optional[str] = _RERANK.get("model_dir")
# _RERANK_DEVICE: str = str(_RERANK.get("device", _EMBED_DEVICE or "cpu"))
# _RERANK_TOP_K: int = int(_RERANK.get("top_k", _TOP_K_DEFAULT))
# _RERANK_INITIAL_K: int = int(
#     _RERANK.get(
#         "initial_k",
#         max(_RERANK_TOP_K, _TOP_K_DEFAULT * 3)
#     )
# )
#
# _RERANK_MODEL: Optional[CrossEncoder] = None
# if _RERANK_ENABLED and _RERANK_MODEL_DIR:
#     try:
#         _RERANK_MODEL = CrossEncoder(_RERANK_MODEL_DIR, device=_RERANK_DEVICE)
#         # _RERANK_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
#     except Exception as e:
#         print(f"[retrieve] Failed to load rerank model from '{_RERANK_MODEL_DIR}': {e}")
#         _RERANK_MODEL = None


# Optional hygiene: drop numeric-heavy table dumps (rare if ingest kept them separate).
# The pattern matches 12 or more consecutive groups, each a digit followed by up to
# 8 further digits or non-word characters (\W covers spaces and punctuation). In
# effect: a stretch containing at least 12 digits, with no letters, where
# consecutive digits are at most 8 non-word characters apart.
# NOTE(review): this also matches a single 12+ digit part/serial number or a
# timestamp such as "2023-10-14 12:30:45" (14 digits), so an otherwise useful
# chunk containing one is dropped entirely by _hygiene. Confirm this is acceptable.
_NUMERIC_HEAVY = re.compile(r"(?:\d[\d\W]{0,8}){12,}")


# -------------------- Embedding model (local) --------------------

# Fail fast at import time: the embedding model location is mandatory.
# NOTE(review): the message names the profile "eqp_manuals", but the profile is
# selected inside load_rag_profile — confirm the name still matches the config.
if not _EMBED_MODEL_DIR:
    raise RuntimeError("profiles.eqp_manuals.retrieve.embed_model_dir is required in rag_config.yaml")

# Import-time singletons shared by all retrieval calls in this module:
# ST encodes queries (via load_local_st; presumably a SentenceTransformer — confirm),
# QDR is the Qdrant client used for both doc_id listing and vector search.
ST = load_local_st(_EMBED_MODEL_DIR, device=_EMBED_DEVICE, max_seq_len=_MAX_SEQ_LEN)
# api_key is only passed when one is configured; otherwise the client gets the URL alone.
QDR = QdrantClient(url=_QDRANT_URL, api_key=_QDRANT_API_KEY) if _QDRANT_API_KEY else QdrantClient(url=_QDRANT_URL)


# -------------------- Qdrant helpers --------------------
def _page_filter(
    doc_ids: Optional[Union[str, Iterable[str]]] = None,
    pages: Optional[Iterable[int]] = None,
) -> Filter:
    """
    Build the Qdrant filter used by every search in this module.

    The filter always restricts ``type`` to {'page', 'table'} and, when
    ``doc_ids`` names at least one document, also restricts ``doc_id``.

    Args:
        doc_ids: a single doc_id (e.g. "ABB-ops"), any iterable of doc_ids
            (e.g. ["ABB-ops", "Yokogawa-maint"], a tuple, set or generator),
            or None. None, "" and empty iterables add NO doc_id condition,
            i.e. every document is eligible.
        pages: accepted for API compatibility but currently IGNORED — the
            page condition below is disabled.

    Returns:
        qdrant_client.models.Filter
    """
    # Materialize once so every iterable kind behaves the same. Previously an
    # empty generator passed the truthiness check and produced MatchAny(any=[]),
    # while an empty list produced no condition at all.
    if isinstance(doc_ids, str):
        doc_list = [doc_ids] if doc_ids else []
    else:
        doc_list = list(doc_ids) if doc_ids is not None else []

    must = []

    # doc_id: exact match for one id, MatchAny for several, nothing for none.
    if len(doc_list) == 1:
        must.append(FieldCondition(key="doc_id", match=MatchValue(value=doc_list[0])))
    elif doc_list:
        must.append(FieldCondition(key="doc_id", match=MatchAny(any=doc_list)))

    # type ∈ {"page","table"}
    must.append(FieldCondition(key="type", match=MatchAny(any=["page", "table"])))

    # Page filtering is disabled for now; `pages` is intentionally unused.
    # # pages (optional, single or multiple)
    # if pages:
    #     page_list = list(pages)
    #     if len(page_list) == 1:
    #         must.append(FieldCondition(key="page", match=MatchValue(value=page_list[0])))
    #     else:
    #         must.append(FieldCondition(key="page", match=MatchAny(any=page_list)))

    return Filter(must=must)


def _search_points(qvec: List[float], limit: int, doc_id: Optional[Union[str, Iterable[str]]]) -> List[Any]:
    """
    Run one filtered vector query against the configured Qdrant collection.

    The collection is expected to use COSINE distance (as set up at ingest),
    and ``qvec`` to be normalized at encode time, so the returned score is a
    relevance score.

    Args:
        qvec: the query embedding as a plain list of floats.
        limit: maximum number of points to return.
        doc_id: one doc_id, an iterable of doc_ids, or None; turned into a
            payload filter by ``_page_filter``.

    Returns:
        The ``points`` of the query response, with payloads and without vectors.
    """
    payload_filter = _page_filter(doc_id)
    response = QDR.query_points(
        collection_name=_COLLECTION,
        query=qvec,
        query_filter=payload_filter,
        limit=limit,
        with_payload=True,
        with_vectors=False,
    )
    return response.points


def _hygiene(points: List[Any]) -> List[Any]:
    """
    Filter out points whose payload text looks like a numeric-heavy table dump.

    A point is discarded when ``_NUMERIC_HEAVY`` finds a match anywhere in its
    payload's "text" field; a missing payload or text counts as "". Surviving
    points keep their original order.
    """
    def _is_numeric_dump(point: Any) -> bool:
        payload = point.payload or {}
        text = payload.get("text", "") or ""
        return bool(_NUMERIC_HEAVY.search(text))

    return [point for point in points if not _is_numeric_dump(point)]


def filter_doc_ids(doc_ids: List[str], maker: Union[str, Iterable[str], None]) -> List[str]:
    """
    Return the doc_ids that contain any of the maker substring(s), case-insensitively.

    Args:
        doc_ids: candidate document ids.
        maker: one maker name, any iterable of maker names, or None. Empty or
            whitespace-only names are ignored. The caller's object is never
            modified.

    Returns:
        Sorted, de-duplicated list of matching doc_ids. Ids containing
        "pump-guide" (case-insensitive) are always included — a temporary
        special case for the PUB demo.
    """
    # Always build a fresh list: the demo append below used to mutate the
    # caller's list, so it grew by one 'Pump-Guide' on every call.
    makers = [maker] if isinstance(maker, str) else list(maker or [])

    # Special handling for PUB Demo only - To be removed
    makers.append('Pump-Guide')

    # Strip BEFORE the emptiness check: a blank key ("") is a substring of
    # every doc_id and would make everything match.
    keys = [key for key in (m.casefold().strip() for m in makers if m) if key]

    return sorted({
        doc_id for doc_id in doc_ids
        if any(key in doc_id.casefold() for key in keys)
    })


# -------------------- Programmatic API --------------------
def retrieve(
    query: str,
    top_k: int = _TOP_K_DEFAULT,
    doc_id: Optional[Union[str, Iterable[str]]] = None,
    maker: Optional[Union[str, List[str]]] = None,
) -> List[Dict[str, Any]]:
    """
    Encode the query with the local model (normalized), search Qdrant, and return hits.

    Args:
        query: natural-language query text.
        top_k: number of points requested from Qdrant. Fewer may be returned,
            because numeric-heavy chunks are dropped by ``_hygiene`` AFTER the
            search.
        doc_id: a doc_id or iterable of doc_ids to search. When None, the
            search covers every doc_id listed in the collection, narrowed by
            ``maker`` if given.
        maker: maker name(s) used to select doc_ids by case-insensitive
            substring match (see ``filter_doc_ids``). Ignored when ``doc_id``
            is given. If no doc_id matches, an empty list is returned.

    Returns:
        A list of dicts, in the order returned by Qdrant, each shaped like::

            {
              "score": float,            # relevance score (COSINE-based)
              "doc_id": str | None,      # payload doc_id
              "page": int | None,        # payload page number
              "title": str,              # "" when missing
              "text": str,               # "" when missing
              "page_image": str | None,  # direct link to the PAGE IMAGE
              "chunk_id": str | None,    # original chunk id
              "type": "page" | "table",
              "point_id": str,
            }
    """
    if doc_id is None:
        # Resolve the doc_id scope only when the caller did not pin one, so
        # explicit doc_id calls skip the list_doc_ids lookup entirely.
        all_doc_ids = list_doc_ids(client=QDR, collection=_COLLECTION)
        if maker is not None:
            doc_id = filter_doc_ids(all_doc_ids, maker)
            if not doc_id:
                # An empty list makes _page_filter drop the doc_id condition,
                # which would search ALL manuals; honour the maker restriction.
                return []
        else:
            doc_id = all_doc_ids

    qv = ST.encode_query(query, normalize_embeddings=True)
    if isinstance(qv, np.ndarray):
        # Send a plain Python list of floats in the Qdrant request.
        qv = qv.tolist()

    points = _hygiene(_search_points(qv, limit=top_k, doc_id=doc_id))

    out: List[Dict[str, Any]] = []
    for p in points:
        pl = p.payload or {}
        out.append({
            "score": float(p.score),
            "doc_id": pl.get("doc_id"),
            "page": pl.get("page"),
            "title": pl.get("title") or "",
            "text": pl.get("text") or "",
            "page_image": pl.get("page_image"),
            "chunk_id": pl.get("chunk_id"),
            "type": pl.get("type"),
            "point_id": str(p.id),
        })
    return out


# -------------------- Re-ranking helpers & API --------------------
# def _apply_rerank(
#     query: str,
#     hits: List[Dict[str, Any]],
#     top_k: int,
# ) -> List[Dict[str, Any]]:
#     """
#     Apply cross-encoder re-ranking on existing hits.
#     If rerank model is not available, just return the first top_k hits.
#     """
#     if not hits:
#         return hits
#
#     if not _RERANK_MODEL:
#         # Reranking not configured or failed to load; keep original order.
#         return hits[:top_k]
#
#     # Build (query, passage) pairs
#     pairs = [(query, h.get("text") or "") for h in hits]
#
#     try:
#         scores = _RERANK_MODEL.predict(pairs)
#     except Exception as e:
#         print(f"[retrieve] Rerank failed: {e}")
#         return hits[:top_k]
#
#     enriched: List[Dict[str, Any]] = []
#     for h, s in zip(hits, scores):
#         h2 = dict(h)
#         h2["rerank_score"] = float(s)
#         enriched.append(h2)
#
#     enriched.sort(key=lambda x: x["rerank_score"], reverse=True)
#     return enriched[:top_k]


# def retrieve_rerank(
#     query: str,
#     top_k: int = _RERANK_TOP_K,
#     doc_id: Optional[Union[str, Iterable[str]]] = None,
#     maker: List[str] = None,
#     initial_k: Optional[int] = None,
# ) -> List[Dict[str, Any]]:
#     """
#     Two-stage retrieval:
#       1) Use existing `retrieve` to fetch an initial candidate set (embedding search).
#       2) Re-rank candidates with a cross-encoder if configured.
#     """
#     if initial_k is None:
#         initial_k = max(top_k, _RERANK_INITIAL_K, _TOP_K_DEFAULT)
#
#     # Stage 1: vector search (reuse original behavior)
#     base_hits = retrieve(query, top_k=initial_k, doc_id=doc_id, maker=maker)
#
#     # Stage 2: rerank (no change to original retrieve)
#     return _apply_rerank(query, base_hits, top_k=top_k)


def _format_context(hits: List[Dict[str, Any]]) -> str:
    """
    Build a compact plain-text context for the orchestrator summarizer.
    Each item contains a 'CITE:' line with doc_id, page, title, and page image link.
    """
    if not hits:
        return ""
    parts: List[str] = []
    for h in hits:
        header = (
            f"[{(h.get('type') or 'page').upper()}] "
            f"doc_id={h.get('doc_id','')} · page={h.get('page')} · "
            f"score={float(h.get('score',0.0)):.3f}"
        )
        title = (h.get("title") or "").strip()
        text  = (h.get("text")  or "").strip()
        img   = h.get("page_image") or ""
        cite  = (
            "CITE: "
            f"doc_id={h.get('doc_id','')}; "
            f"page={h.get('page')}; "
            f"title={title}; "
            f"image={img}"
        )
        parts.append("\n".join([header, f"Title: {title}", text, cite]))
    return "\n\n---\n\n".join(parts)


# -------------------- LangChain tool wrappers --------------------
@tool("retrieve_context")
def retrieve_context(
    prompt: str,
    doc_id: Optional[Union[str, Iterable[str]]] = None,
    top_k: int = _TOP_K_DEFAULT,
    min_score: float = _MIN_SCORE_DEFAULT,
    maker: Optional[List[str]] = None,
) -> str:
    """
    Retrieve relevant context (page and summarized tables) for a prompt using cosine similarity.
    Returns a plain-text block ready for the orchestrator summarizer.
    Each chunk includes a mandatory 'CITE:' line with doc_id, page, title, and page image link.
    Only results with score >= min_score are included.

    Args:
        prompt: the natural-language question to search for.
        doc_id: optional document id, or list of ids, to restrict the search to.
        top_k: maximum number of chunks to fetch before the score floor is applied.
        min_score: minimum relevance score a chunk must reach to be returned.
        maker: optional list of maker names; only documents whose id contains
            one of them are searched. Ignored when doc_id is given.

    Returns an empty string when no chunk passes the filters.
    """
    # NOTE: this docstring is also the tool description LangChain shows the LLM,
    # and the signature annotations become its args schema. `maker` is Optional
    # so an explicit null from the model validates instead of being rejected.
    hits = retrieve(prompt, top_k=top_k, doc_id=doc_id, maker=maker)
    # Score floor applied after retrieval, so fewer than top_k chunks may remain.
    hits = [h for h in hits if float(h.get("score", 0.0)) >= float(min_score)]
    return _format_context(hits)


# @tool("retrieve_context_rerank")
# def retrieve_context_rerank(
#     prompt: str,
#     doc_id: Optional[Union[str, Iterable[str]]] = None,
#     top_k: int = _RERANK_TOP_K,
#     min_score: float = _MIN_SCORE_DEFAULT,
#     maker: List[str] = None,
#     initial_k: Optional[int] = None,
# ) -> str:
#     """
#     Same as retrieve_context, but with an optional re-ranking stage.
#     Uses `retrieve_rerank` (vector search + cross-encoder rerank) when configured.
#     Falls back to plain vector search ordering if rerank model is unavailable.
#     """
#     hits = retrieve_rerank(prompt, top_k=top_k, doc_id=doc_id, maker=maker, initial_k=initial_k)
#     # Still apply a floor on original vector score to avoid very weak candidates.
#     hits = [h for h in hits if float(h.get("score", 0.0)) >= float(min_score)]
#     return _format_context(hits)


# -------------------- CLI / Demo --------------------
def _print_results(
    query: str,
    top_k: int = _TOP_K_DEFAULT,
    doc_id: Optional[Union[str, Iterable[str]]] = None,
) -> None:
    """
    Print retrieval hits for ``query`` to stdout (manual / debug helper).

    Args:
        query: natural-language query text.
        top_k: number of points requested from Qdrant.
        doc_id: a doc_id or iterable of doc_ids to restrict the search to
            (same accepted types as ``retrieve``); None searches every
            listed doc_id.
    """
    hits = retrieve(query, top_k=top_k, doc_id=doc_id)
    if not hits:
        print("No results.")
        return
    for i, h in enumerate(hits, 1):
        # One numbered summary line per hit, then the optional image link and
        # the chunk text flattened onto a single line.
        print(f"{i:02d}. score={h['score']:.3f}  [Page {h['page']}]  {h['doc_id']}")
        if h.get("page_image"):
            print("    page_image:", h["page_image"])
        snippet = (h.get("text") or "").replace("\n", " ")
        print("    ", snippet, "\n")


if __name__ == "__main__":
    # Manual smoke test against the configured Qdrant collection:
    # run one query restricted to a fixed set of pump manuals.
    EXAMPLE_QUERY = "Show me the solutions for misalignment issue"
    DOC_ID_FILTER = [
        "Pump-Guide-Web-1",
        "Pump-Guide-1",
        "Pump-Guide-2",
        "EBARA-operation-maintenance",
        "EBARA-operation-maintenance-2",
    ]
    _print_results(query=EXAMPLE_QUERY, doc_id=DOC_ID_FILTER, top_k=_TOP_K_DEFAULT)
