# Created by aaronkueh on 10/14/2025
# aom/tools/rag/ingest.py
# Ingest DocLayout-YOLO extracted outputs into Qdrant.
# - Page markdown -> split by headings -> one point per section
# - Table images -> summarize with table_summarizer -> one point per table
# - Points use UUIDv5 (deterministic)
# - Distance is read from top-level rag_config.yaml: qdrant.distance (e.g., COSINE)
# - Page image URLs use rag_config.yaml profiles.eqp_manuals.ingest.page_image_url_base
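#
# Illustrative rag_config.yaml shape (key names as read below; values are
# deployment-specific, shown with this module's defaults where it has them):
#   qdrant:
#     distance: COSINE
#   profiles:
#     eqp_manuals:
#       collection: mas_manuals
#       ingest:
#         extracted_dir: /path/to/<pdf_parent>/extracted/<pdf_stem>
#         page_image_url_base: https://...
#         embed_model_dir: /path/to/model
#         embed_dim: 768
#         embed_device: cpu
#         max_seq_len: 256
#         batch_size: 16
#         delete_existing_by_doc_id: true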
from __future__ import annotations

import re
from pathlib import Path
from typing import List, Dict, Tuple
from uuid import uuid5, NAMESPACE_URL

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from sentence_transformers import SentenceTransformer

from aom.definitions import CONFIG_DIR
from aom.utils.utilities import config_param, qdrant_cfg
from aom.tools.table_summarizer import summarize_table
from aom.tools.rag.rag_helper import load_local_st, count_tokens


# ---------------------- Config helpers ----------------------
def _load_profile_dict() -> dict:
    profiles = config_param(config_file=f'{CONFIG_DIR}/rag_config.yaml', field="profiles") or {}
    prof = profiles.get("eqp_manuals", {}) if isinstance(profiles, dict) else {}
    if not isinstance(prof, dict):
        raise ValueError("Invalid profile structure in rag_config.yaml for 'eqp_manuals'.")
    return prof

def _paths_from_ingest(ing: dict) -> Tuple[Path, Path, Path]:
    """
    Returns (extracted_dir, page_img_dir, text_dir) from ingest.extracted_dir.
    """
    extracted_dir = Path(ing["extracted_dir"])
    page_img_dir = extracted_dir / "page_img"
    text_dir = extracted_dir / "text"
    return extracted_dir, page_img_dir, text_dir

def _build_page_image_url(base: str, extracted_dir: Path, page: int) -> str:
    """
    URL pattern:
      <base>/<pdf_parent>/extracted/<pdf_stem>/page_img/page_XXX.png

    Where:
      pdf_stem = extracted_dir.name
      pdf_parent = extracted_dir.parent.parent.name (e.g., ".../<pdf_parent>/extracted/<pdf_stem>")
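
    Illustrative example (hypothetical paths):
      base="https://files.local/manuals", extracted_dir=".../EQP-42/extracted/manual_v2", page=7
      -> "https://files.local/manuals/EQP-42/extracted/manual_v2/page_img/page_007.png"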
    """
    pdf_stem = extracted_dir.name
    pdf_parent = extracted_dir.parent.parent.name  # Path.parent is never None, so no guard is needed
    base = (base or "").rstrip("/")
    return f"{base}/{pdf_parent}/extracted/{pdf_stem}/page_img/page_{page:03d}.png"


# ---------------------- Qdrant helpers ----------------------

def _ensure_collection(c: QdrantClient, collection: str, dim: int, distance: Distance) -> None:
    """
    Ensure the collection exists. Uses collection_exists + create_collection.
    Does NOT recreate/drop if it already exists.
    """
    if not c.collection_exists(collection):
        c.create_collection(
            collection_name=collection,
            vectors_config=VectorParams(size=dim, distance=distance),
        )


def _delete_doc(c: QdrantClient, collection: str, doc_id: str) -> None:
    flt = Filter(must=[FieldCondition(key="doc_id", match=MatchValue(value=doc_id))])
    c.delete(collection_name=collection, points_selector=flt)


# ---------------------- File helpers ----------------------
HEADING_RE = re.compile(r"^\s{0,3}#{1,6}\s+(?P<title>.+?)\s*$", re.M)
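# Matches ATX markdown headings ("# Title" through "###### Title") and captures the title text.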

def _read_page_md(text_dir: Path) -> List[Tuple[int, str]]:
    pages: List[Tuple[int, str]] = []
    for md in sorted(text_dir.glob("page_*.md")):
        m = re.match(r"page_(\d+)\.md$", md.name)
        if not m:
            continue
        page = int(m.group(1))
        pages.append((page, md.read_text(encoding="utf-8", errors="ignore")))
    return pages

def _split_sections(md_text: str) -> List[Tuple[str, str]]:
    lines = md_text.splitlines()
    out: List[Tuple[str, List[str]]] = []
    cur = None
    buf: List[str] = []
    for ln in lines:
        m = HEADING_RE.match(ln)
        if m:
            # Flush the previous section; keep any non-empty preamble that
            # appears before the first heading as an "untitled" section.
            if cur is not None or any(s.strip() for s in buf):
                out.append((cur or "untitled", buf))
            cur = m.group("title").strip()
            buf = []
        else:
            buf.append(ln)
    if cur is not None:
        out.append((cur, buf))
    if not out:
        return [("untitled", md_text.strip())]
    return [(t, "\n".join(b).strip()) for t, b in out]

def _slugify(s: str, maxlen: int = 80) -> str:
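    # e.g. _slugify("2.3 Safety / Warnings") -> "23_Safety_Warnings" (illustrative)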
    s = re.sub(r"[^0-9a-zA-Z\- _]+", "", (s or "").strip()).replace(" ", "_")
    s = re.sub(r"_+", "_", s)
    return s[:maxlen] if s else "untitled"

def _embed(st: SentenceTransformer, texts: List[str], batch_size: int) -> np.ndarray:
    if not texts:
        return np.zeros((0, st.get_sentence_embedding_dimension()), dtype=np.float32)
    vecs = st.encode_document(
        texts,
        batch_size=batch_size,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )
    return vecs.astype(np.float32)

def _batched(iterable, n: int):
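    # Yields successive lists of up to n items,
    # e.g. _batched("abcde", 2) -> ["a", "b"], ["c", "d"], ["e"]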
    buf = []
    for x in iterable:
        buf.append(x)
        if len(buf) >= n:
            yield buf
            buf = []
    if buf:
        yield buf

def _point_id_from_chunk_id(chunk_id: str) -> str:
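    # UUIDv5 of the chunk_id: the same chunk_id always maps to the same point
    # id, so re-running the ingest overwrites points in place rather than
    # duplicating them.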
    return str(uuid5(NAMESPACE_URL, chunk_id))  # deterministic


# ---------------------- Ingest core ----------------------
def ingest_extracted_dir(prof: dict) -> None:
    # Use whatever is present in rag_config.yaml (no validation). Expected keys:
    # extracted_dir, delete_existing_by_doc_id, embed_model_dir, embed_dim,
    # embed_device, max_seq_len, batch_size, page_image_url_base
    ing = prof.get("ingest", {})
    collection = prof.get("collection", "mas_manuals")

    extracted_dir, page_img_dir, text_dir = _paths_from_ingest(ing)
    if not extracted_dir.exists():
        raise FileNotFoundError(f"Extracted dir not found: {extracted_dir}")

    # Page image URL base
    url_base = str(ing.get("page_image_url_base", "")).rstrip("/")

    # Qdrant from TOP-LEVEL qdrant block (env indirection)
    qdrant_url, qdrant_api_key, qdrant_distance = qdrant_cfg(map_distance=True)

    # Toggles/embedding settings
    delete_existing = bool(ing.get("delete_existing_by_doc_id", True))
    max_seq_len = int(ing.get("max_seq_len", 256))
    batch_size = int(ing.get("batch_size", 16))
    device = str(ing.get("embed_device", "cpu"))
    model_dir = Path(ing["embed_model_dir"])  # required; a missing key fails here with a clear KeyError
    embed_dim_cfg = int(ing.get("embed_dim", 768))

    # Embeddings and collection
    st = load_local_st(str(model_dir), device=device, max_seq_len=max_seq_len)
    embed_dim = embed_dim_cfg  # honor config
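    # Note: embed_dim must match the model's actual output dimension
    # (st.get_sentence_embedding_dimension()); Qdrant rejects vectors whose
    # length differs from the collection's configured size.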
    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key) if qdrant_api_key else QdrantClient(url=qdrant_url)
    _ensure_collection(client, collection, embed_dim, distance=qdrant_distance)

    doc_id = extracted_dir.name
    if delete_existing:
        _delete_doc(client, collection, doc_id)

    chunks: List[Dict] = []

    # --- Pages: split by headings ---
    if text_dir.exists():
        for page, md in _read_page_md(text_dir):
            page_image_url = _build_page_image_url(url_base, extracted_dir, page)
            for title, body in _split_sections(md):
                body = (body or "").strip()
                if not body:
                    continue
                slug = _slugify(title)
                chunks.append({
                    "chunk_id": f"{doc_id}:{page:03d}:{slug}",
                    "doc_id": doc_id,
                    "page": page,
                    "title": title,
                    "type": "page",
                    "text": body,
                    "page_image": page_image_url,
                    "num_token": count_tokens(body, model_dir),
                })

    # --- Tables: summarize with LLM ---
    table_img_dir = extracted_dir / "table_img"
    if table_img_dir.exists():
        for img in sorted(table_img_dir.glob("*.png")):
            m = re.match(r"^page_(\d+)_([^.]+)\.png$", img.name)
            if not m:
                continue
            page = int(m.group(1))
            raw = m.group(2)  # caption/label part
            slug = _slugify(raw.replace("table", "table_"))
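            # e.g. "page_012_table1.png" -> raw="table1" -> slug="table_1" (illustrative)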

            try:
                summary = summarize_table(str(img))
            except Exception as e:
                summary = f"[TABLE {raw}] summarization_error: {e}"

            page_image_url = _build_page_image_url(url_base, extracted_dir, page)

            chunks.append({
                "chunk_id": f"{doc_id}:{page:03d}:{slug}",
                "doc_id": doc_id,
                "page": page,
                "title": raw,
                "type": "table",
                "text": summary,
                "page_image": page_image_url,
                "num_token": count_tokens(summary, model_dir),
            })

    if not chunks:
        print("No chunks found to ingest.")
        return

    # --- Embed & upsert ---
    vecs = _embed(st, [c["text"] for c in chunks], batch_size=batch_size)

    points: List[PointStruct] = []
    for c, v in zip(chunks, vecs):
        points.append(PointStruct(
            id=_point_id_from_chunk_id(c["chunk_id"]),
            vector=v.tolist(),
            payload=c,
        ))

    # Upsert in batches
    total = 0
    for batch in _batched(points, 2048):
        client.upsert(collection_name=collection, points=batch)
        total += len(batch)
    dist_str = getattr(qdrant_distance, "value", str(qdrant_distance))
    print(f"Upserted {len(points)} chunks to '{collection}' for doc_id='{doc_id}'. "
          f"(delete_existing={delete_existing}, device={device}, distance={dist_str})")


# ---------------------- Direct run ----------------------
if __name__ == "__main__":
    ing_prof = _load_profile_dict()
    ingest = ing_prof.get("ingest", {})
    top_qd_url, top_qd_key, top_qd_dist = qdrant_cfg(map_distance=True)
    extracted_folder, _, _ = _paths_from_ingest(ingest)

    print("[ingest] profile=eqp_manuals")
    print(f"  collection={ing_prof.get('collection')}")
    print(f"  qdrant_url={top_qd_url}, api_key={'yes' if top_qd_key else 'no'}")
    print(f"  extracted_dir={extracted_folder}")
    print(f"  embed_model_dir={ingest.get('embed_model_dir')}")
    print(f"  embed_dim={ingest.get('embed_dim')}, embed_device={ingest.get('embed_device')}")
    print(f"  max_seq_len={ingest.get('max_seq_len')}, batch_size={ingest.get('batch_size')}")
    print(f"  page_image_url_base={ingest.get('page_image_url_base')}")
    print(f"  delete_existing_by_doc_id={ingest.get('delete_existing_by_doc_id')}")

    ingest_extracted_dir(ing_prof)
