# Created by aaronkueh on 10/14/2025
# aom/tools/rag/pdf_extract.py
# Page-level extraction with DocLayout-YOLO masking (OCR removed).

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import List, Tuple, Optional
from pathlib import Path

import cv2
import numpy as np
import fitz  # PyMuPDF
from PIL import Image

from aom.definitions import CONFIG_DIR
from aom.utils.utilities import config_param
from doclayout_yolo import YOLOv10


# -------------------------- Data structures --------------------------
@dataclass
class BBox:
    """Axis-aligned detection box plus its class label and confidence score."""

    x0: float  # first x edge
    y0: float  # first y edge
    x1: float  # second x edge
    y1: float  # second y edge
    cls: str  # class label attached to the box
    score: float  # confidence attached to the box

    @property
    def w(self) -> float:
        """Horizontal extent; never negative, an inverted box reports 0."""
        span = self.x1 - self.x0
        return span if span > 0.0 else 0.0

    @property
    def h(self) -> float:
        """Vertical extent; never negative, an inverted box reports 0."""
        span = self.y1 - self.y0
        return span if span > 0.0 else 0.0

    @property
    def area(self) -> float:
        """Product of the clamped width and height."""
        width, height = self.w, self.h
        return width * height


# -------------------------- Small utilities --------------------------
def ensure_dir(p: Path) -> None:
    """Create directory `p` (and any missing parents); an existing directory is left alone."""
    p.mkdir(
        exist_ok=True,
        parents=True,
    )

def page_to_image(doc: fitz.Document, page_index: int, dpi: int) -> Tuple[fitz.Page, np.ndarray, float]:
    """
    Rasterize one page of `doc` and return it together with the page object.

    The render scale is ``dpi / 72.0`` on both axes and the pixmap is requested
    without an alpha channel. The pixel buffer is viewed as a uint8 array of
    shape (height, width, 3).

    Returns:
        (page, image array, scale factor used for rendering)
    """
    page = doc[page_index]
    scale = dpi / 72.0
    pixmap = page.get_pixmap(matrix=fitz.Matrix(scale, scale), alpha=False)
    flat = np.frombuffer(pixmap.samples, dtype=np.uint8)
    raster = flat.reshape(pixmap.h, pixmap.w, 3)
    return page, raster, scale

def clip_text(page: fitz.Page, rect_pts: Tuple[float, float, float, float]) -> str:
    """
    Return the text of all blocks inside `rect_pts` on `page`.

    Each block's text (element 4 of the block tuple) is stripped; empty blocks
    are dropped and the rest are joined with newlines.
    """
    region = fitz.Rect(*rect_pts)
    stripped = (
        str(block[4]).strip()
        for block in page.get_text("blocks", clip=region)
        if len(block) >= 5
    )
    return "\n".join(piece for piece in stripped if piece).strip()

def load_detector(model_path: str) -> YOLOv10:
    """Build a YOLOv10 detector from the weights at `model_path`."""
    detector = YOLOv10(model_path)
    return detector

def to_pts(xyxy_pixels: Tuple[float, float, float, float], zoom: float) -> Tuple[float, float, float, float]:
    """
    Scale a pixel-space (x0, y0, x1, y1) box back by the render `zoom`.

    The zoom is floored at 1e-6 so a zero or negative value cannot cause a
    division by zero.
    """
    left, top, right, bottom = xyxy_pixels
    scale = 1.0 / max(zoom, 1e-6)
    return (
        left * scale,
        top * scale,
        right * scale,
        bottom * scale,
    )

def crop_image(arr: np.ndarray, box: BBox) -> Image.Image:
    """
    Cut the region covered by `box` out of `arr` and return it as a PIL image.

    Box coordinates are truncated to integers and clamped to the array bounds
    (rows are axis 0, columns are axis 1).
    """
    left, top, right, bottom = int(box.x0), int(box.y0), int(box.x1), int(box.y1)
    rows, cols = arr.shape[0], arr.shape[1]
    left = max(0, left)
    top = max(0, top)
    right = min(cols, right)
    bottom = min(rows, bottom)
    return Image.fromarray(arr[top:bottom, left:right, :])

def find_nearby_caption_text(page: fitz.Page, pts_box: Tuple[float, float, float, float], search_margin_pts: float = 12.0) -> Optional[str]:
    """
    Look for a text block mentioning 'table' directly above or below a box.

    Two bands, each ``2 * search_margin_pts`` tall and as wide as the box, are
    scanned (above first, then below). The first block whose text contains
    'table' (case-insensitive) is reduced to alphanumerics, spaces, '_' and
    '-', spaces become underscores, and the result is cut to 60 characters.

    Returns:
        The filename-safe caption, or None when no block matches.
    """
    left, top, right, bottom = pts_box
    bands = (
        fitz.Rect(left, max(0, top - search_margin_pts * 2), right, top),
        fitz.Rect(left, bottom, right, bottom + search_margin_pts * 2),
    )

    candidates = []
    for band in bands:
        for block in page.get_text("blocks", clip=band):
            if len(block) < 5:
                continue
            snippet = str(block[4]).strip()
            if snippet:
                candidates.append(snippet)

    hit = next((c for c in candidates if "table" in c.lower()), None)
    if hit is None:
        return None

    kept = "".join(ch for ch in hit if ch.isalnum() or ch in " _-")
    return kept.strip().replace(" ", "_")[:60]

def y_center(b: BBox) -> float:
    """Return the midpoint between the box's y0 and y1."""
    total = b.y0 + b.y1
    return 0.5 * total

def _split_and_sort_by_x(text_boxes: List[BBox], split_x: float) -> List[BBox]:
    """Split text boxes at split_x into left/right, sort each by (y0, x0), return left+right."""
    left  = [b for b in text_boxes if ((b.x0 + b.x1) * 0.5) <  split_x]
    right = [b for b in text_boxes if ((b.x0 + b.x1) * 0.5) >= split_x]
    left.sort(key=lambda b: (b.y0, b.x0))
    right.sort(key=lambda b: (b.y0, b.x0))
    return left + right


# -------------------------- Column ordering (config-driven) --------------------------

def order_text_boxes(
    boxes: List[BBox],
    page_width: int,
    *,
    mode: str = "auto",
    two_col_gap_pct: float = 0.06,
    vertical_two_split_pct: float = 0.50,
) -> List[BBox]:
    """
    Select the text-like boxes and put them in reading order.

    Only boxes whose class contains 'title', 'heading' or 'paragraph',
    contains both 'plain' and 'text', or equals 'text' are kept.

    Modes (case-insensitive, surrounding whitespace ignored, empty -> auto):
      - natural:      one stream sorted by (y0, x0)
      - vertical_two: two columns split at vertical_two_split_pct * page_width
      - anything else (auto): sort the x-centers, take the widest gap between
        neighbours; if it is at least two_col_gap_pct * page_width, split in
        the middle of that gap into two columns, otherwise use one stream.

    Returns the ordered text boxes (an empty list when none qualify).
    """
    def _looks_like_text(label: str) -> bool:
        if any(tag in label for tag in ("title", "heading")):
            return True
        if "plain" in label and "text" in label:
            return True
        return "paragraph" in label or label == "text"

    def _single_stream(items: List[BBox]) -> List[BBox]:
        return sorted(items, key=lambda bx: (bx.y0, bx.x0))

    layout = (mode or "auto").strip().lower()

    candidates = [bx for bx in boxes if _looks_like_text(bx.cls)]
    if not candidates:
        return []

    if layout == "natural":
        return _single_stream(candidates)

    if layout == "vertical_two":
        return _split_and_sort_by_x(
            candidates, float(page_width) * float(vertical_two_split_pct)
        )

    # auto: a wide enough hole between neighbouring x-centers means two columns
    centers = sorted((bx.x0 + bx.x1) * 0.5 for bx in candidates)
    min_gap = float(two_col_gap_pct) * float(page_width)
    gap_sizes = [nxt - cur for cur, nxt in zip(centers, centers[1:])]

    if gap_sizes:
        # first index wins on ties, matching max() over (gap, index) pairs
        widest = max(range(len(gap_sizes)), key=gap_sizes.__getitem__)
        if gap_sizes[widest] >= min_gap:
            return _split_and_sort_by_x(
                candidates, centers[widest] + gap_sizes[widest] * 0.5
            )

    return _single_stream(candidates)


# -------------------------- Core --------------------------
def run_layout(det: YOLOv10, img: np.ndarray, imgsz: int, conf: float, device: str) -> List[BBox]:
    """
    Run the layout detector on one page image and convert the result to BBox objects.

    Args:
        det: detector exposing ``predict(img, imgsz=..., conf=..., device=..., verbose=...)``.
        img: page raster passed straight to the detector.
        imgsz: inference image size forwarded to the detector.
        conf: confidence threshold forwarded to the detector.
        device: device string forwarded to the detector.

    Returns:
        One BBox per detection, with the class name lower-cased. When the
        detector's name table is missing, the numeric class id is used as the
        name. Parsing is best-effort: on failure the error is printed and the
        boxes collected so far (possibly none) are returned.
    """
    res = det.predict(img, imgsz=imgsz, conf=conf, device=device, verbose=False)
    if not res:
        # No result object at all: nothing to parse.
        return []

    r = res[0]
    raw_names = getattr(r, "names", None)
    names = raw_names if isinstance(raw_names, dict) else {}

    boxes: List[BBox] = []
    try:
        xyxy = r.boxes.xyxy.cpu().numpy()
        cls_ids = r.boxes.cls.cpu().numpy().astype(int)
        confs = r.boxes.conf.cpu().numpy()
        for coords, raw_id, score in zip(xyxy, cls_ids, confs):
            cls_id = int(raw_id)
            cname = str(names.get(cls_id, cls_id)).lower()
            x0, y0, x1, y1 = (float(v) for v in coords)
            boxes.append(BBox(x0, y0, x1, y1, cname, float(score)))
    except Exception as e:
        # print() does not interpolate "%s" the way logging does, so format explicitly.
        print(f"Failed to parse YOLO output: {e}")
    return boxes

def save_page_image(img: np.ndarray, out_dir: Path, page_num: int) -> Path:
    """
    Write the page raster as ``page_NNN.png`` inside `out_dir`.

    The directory is created first if needed. Returns the path of the PNG.
    """
    ensure_dir(out_dir)
    target = out_dir / f"page_{page_num:03d}.png"
    picture = Image.fromarray(img)
    picture.save(target, format="PNG")
    return target


def _read_extraction_cfg() -> dict:
    """
    Use provided config_param(config_file, field) to read thresholds from rag_config.yaml.
    Expected YAML shape:
      profiles:
        <profile>:
          pdf_extract:
            extraction:
              column_layout: auto|vertical_two|natural
              two_col_gap_pct: 0.06
              vertical_two_split_pct: 0.5
    """
    try:
        cfg_path = Path(CONFIG_DIR) / "rag_config.yaml"
        profiles = config_param(config_file=str(cfg_path), field="profiles") or {}
        prof = profiles.get('eqp_manuals', {}) if isinstance(profiles, dict) else {}

        pdf_extract = prof.get("pdf_extract", {}) if isinstance(prof, dict) else {}
        extraction = pdf_extract.get("extraction", {}) if isinstance(pdf_extract, dict) else {}

        column_layout = (extraction.get("column_layout") or "auto").strip().lower()
        two_col_gap_pct = float(extraction.get("two_col_gap_pct", 0.06))
        vertical_two_split_pct = float(extraction.get("vertical_two_split_pct", 0.50))

        return {
            "column_layout": column_layout,
            "two_col_gap_pct": two_col_gap_pct,
            "vertical_two_split_pct": vertical_two_split_pct,
        }
    except Exception as e:
        print(f"Config read failed, using defaults: {e}")
        return {
            "column_layout": "auto",
            "two_col_gap_pct": 0.06,
            "vertical_two_split_pct": 0.50,
        }


def process_pdf(
    pdf_path: str,
    out_root: str,
    model_path: str,
    dpi: int = 200,
    device: str = "cpu",
    imgsz: int = 1024,
    det_conf: float = 0.25,
    tiny_figure_area_ratio: float = 0.002,  # skip figures smaller than 0.2% of page pixels
) -> Path:
    """
    Extract rasters, layout boxes, table/figure crops and Markdown text from a PDF.

    Output tree, created under ``<out_root>/<pdf stem>/``:
      page_img/page_NNN.png            page raster
      layout/page_NNN.json             raw detector boxes
      layout/page_NNN_annotated.jpg    raster with boxes drawn on it
      table_img/page_NNN_<name>.png    one crop per table box
      figure_img/page_NNN_figureMM.png one crop per figure box
      text/page_NNN.md                 titles/headings as '# ...', other text as-is
      manifest.json                    run settings and per-page counts

    Args:
        pdf_path: path of the PDF to process.
        out_root: root directory that receives the per-document folder.
        model_path: detector weights passed to load_detector().
        dpi: render resolution for the page rasters.
        device: device string forwarded to the detector.
        imgsz: inference size forwarded to the detector.
        det_conf: confidence threshold forwarded to the detector.
        tiny_figure_area_ratio: figure boxes covering a smaller share of the
            page pixels than this are not exported.

    Returns:
        The per-document output directory.

    Raises:
        FileNotFoundError: if `pdf_path` does not exist.
    """
    pdf = Path(pdf_path)
    if not pdf.exists():
        raise FileNotFoundError(str(pdf))

    # Load the column-layout settings once for the whole document
    xcfg = _read_extraction_cfg()
    column_layout = xcfg["column_layout"]
    two_col_gap_pct = xcfg["two_col_gap_pct"]
    vertical_two_split_pct = xcfg["vertical_two_split_pct"]

    base = Path(out_root) / pdf.stem
    page_img_dir = base / "page_img"
    table_img_dir = base / "table_img"
    figure_img_dir = base / "figure_img"
    text_dir = base / "text"
    layout_dir = base / "layout"
    for d in (page_img_dir, table_img_dir, figure_img_dir, text_dir, layout_dir):
        ensure_dir(d)

    det = load_detector(model_path)

    manifest = {
        "pdf": str(pdf.resolve()),
        "dpi": dpi,
        "detector": {"model": str(model_path), "imgsz": imgsz, "conf": det_conf, "device": device},
        "extraction": {
            "column_layout": column_layout,
            "two_col_gap_pct": two_col_gap_pct,
            "vertical_two_split_pct": vertical_two_split_pct,
            "tiny_figure_area_ratio": tiny_figure_area_ratio,
        },
        "pages": []
    }

    with fitz.open(str(pdf)) as doc:
        page_count = doc.page_count
        for i in range(page_count):
            page_num = i + 1
            page, img, zoom = page_to_image(doc, i, dpi)
            h, w = img.shape[:2]
            page_area = float(h * w)

            # Store raster
            save_page_image(img, page_img_dir, page_num)

            # Detect layout
            boxes = run_layout(det, img, imgsz=imgsz, conf=det_conf, device=device)

            # Persist raw layout
            layout_json_path = layout_dir / f"page_{page_num:03d}.json"
            layout_json_path.write_text(json.dumps([b.__dict__ for b in boxes], indent=2), encoding="utf-8")

            # Annotated image. The raster is RGB (PIL saves it as such above) but
            # cv2.imwrite expects BGR, so reverse the channel axis into a fresh
            # writable array before drawing; the colour tuples below are BGR.
            ann = np.ascontiguousarray(img[:, :, ::-1])
            for b in boxes:
                color = (0, 255, 0) if "table" in b.cls else (255, 0, 0) if ("figure" in b.cls or "chart" in b.cls) else (0, 0, 255)
                cv2.rectangle(ann, (int(b.x0), int(b.y0)), (int(b.x1), int(b.y1)), color, 2)
                cv2.putText(ann, f"{b.cls}:{b.score:.2f}", (int(b.x0), max(0, int(b.y0) - 4)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)
            cv2.imwrite(str(layout_dir / f"page_{page_num:03d}_annotated.jpg"), ann)

            # --- Export tables and figures ---
            table_count = 0
            figure_count = 0
            used_table_names: set = set()  # guards against same-caption overwrites on a page

            # Sort by top-left y then x for deterministic numbering
            for b in sorted(boxes, key=lambda b: (b.y0, b.x0)):
                cls = b.cls
                # Ignore abandon class completely
                if "abandon" in cls:
                    continue
                # Caption/footnote classes (e.g. 'table_caption', 'figure_caption')
                # contain 'table'/'figure' as substrings but are not the objects
                # themselves, so they must not be cropped or counted.
                if ("caption" in cls) or ("footnote" in cls):
                    continue

                # TABLE
                if "table" in cls:
                    table_count += 1
                    pts = to_pts((b.x0, b.y0, b.x1, b.y1), zoom)
                    cap = find_nearby_caption_text(page, pts)
                    # Always carry the page prefix, even when the caption text
                    # itself happens to start with "page_".
                    if cap:
                        fname = f"page_{page_num:03d}_{cap}"
                    else:
                        fname = f"page_{page_num:03d}_table{table_count:02d}"
                    if fname in used_table_names:
                        # Two tables resolved to the same caption: keep both files.
                        fname = f"{fname}_{table_count:02d}"
                    used_table_names.add(fname)
                    crop = crop_image(img, b)
                    out_path = table_img_dir / f"{fname}.png"
                    crop.save(out_path, format="PNG")

                # FIGURE (non-table, non-text)
                elif ("figure" in cls) or ("chart" in cls) or ("image" in cls) or ("picture" in cls):
                    if b.area / page_area < float(tiny_figure_area_ratio):
                        continue
                    figure_count += 1
                    crop = crop_image(img, b)
                    out_path = figure_img_dir / f"page_{page_num:03d}_figure{figure_count:02d}.png"
                    crop.save(out_path, format="PNG")

            # --- Extract title vs. plain_text as Markdown ---
            md_parts: List[str] = []

            # Order text boxes according to mode + thresholds from config
            ordered_text_boxes = order_text_boxes(
                boxes, w,
                mode=column_layout,
                two_col_gap_pct=two_col_gap_pct,
                vertical_two_split_pct=vertical_two_split_pct,
            )

            for b in ordered_text_boxes:
                cls = b.cls
                is_heading = ("title" in cls) or ("heading" in cls)
                is_body = ("plain" in cls and "text" in cls) or ("paragraph" in cls) or (cls == "text")
                if not (is_heading or is_body):
                    continue
                pts = to_pts((b.x0, b.y0, b.x1, b.y1), zoom)
                content = clip_text(page, pts).strip()
                if content:
                    md_parts.append(f"# {content}" if is_heading else content)

            # Fallback: if nothing captured, store the whole page text so we don't lose content
            if not md_parts:
                whole = page.get_text("text").strip()
                if whole:
                    md_parts.append(whole)

            # Preserve the chosen column order when writing the Markdown
            md = "\n\n".join(md_parts).strip()
            (text_dir / f"page_{page_num:03d}.md").write_text(md, encoding="utf-8")

            manifest["pages"].append({
                "page": page_num,
                "tables": table_count,
                "figures": figure_count,
                "text_file": f"text/page_{page_num:03d}.md"
            })

    # Save manifest at the end
    out_dir = base
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    # print() does not interpolate "%s"; format the path explicitly.
    print(f"Extraction complete: {out_dir}")
    return out_dir


if __name__ == "__main__":
    # Load profile config from rag_config.yaml (hard-coded profile: eqp_manuals)
    cfg_path = Path(CONFIG_DIR) / "rag_config.yaml"
    profiles = config_param(config_file=str(cfg_path), field="profiles") or {}
    profile_name = "eqp_manuals"  # ← hard-coded as requested
    prof = profiles.get(profile_name, {}) if isinstance(profiles, dict) else {}
    pdfx = prof.get("pdf_extract", {}) if isinstance(prof, dict) else {}
    if not isinstance(pdfx, dict):
        pdfx = {}

    # The three paths have no sensible default: fail with a clear message
    # instead of passing the string "None" on to process_pdf().
    missing = [key for key in ("model_path", "pdf_path", "out_root") if not pdfx.get(key)]
    if missing:
        raise SystemExit(
            f"[pdf_extract] profile={profile_name}: missing required setting(s) "
            f"in {cfg_path}: {', '.join(missing)}"
        )

    def _setting(key, cast, default):
        """Return pdfx[key] converted with `cast`, or `default` when absent/null."""
        raw = pdfx.get(key)
        return default if raw is None else cast(raw)

    # Pull values with safe fallbacks (defaults mirror process_pdf's signature)
    MODEL_PATH = str(pdfx["model_path"])
    PDF_PATH   = str(pdfx["pdf_path"])
    OUT_ROOT   = str(pdfx["out_root"])
    DEVICE     = _setting("device", str, "cpu")
    DPI        = _setting("dpi", int, 200)
    IMGSZ      = _setting("imgsz", int, 1024)
    CONF       = _setting("det_conf", float, 0.25)
    TINY_FIG   = _setting("tiny_figure_area_ratio", float, 0.002)

    print(f"[pdf_extract] profile={profile_name}")
    print(f"  model_path={MODEL_PATH}")
    print(f"  pdf_path={PDF_PATH}")
    print(f"  out_root={OUT_ROOT}")
    print(f"  device={DEVICE}, dpi={DPI}, imgsz={IMGSZ}, det_conf={CONF}, tiny_figure_area_ratio={TINY_FIG}")

    process_pdf(
        pdf_path=PDF_PATH,
        out_root=OUT_ROOT,
        model_path=MODEL_PATH,
        dpi=DPI,
        device=DEVICE,
        imgsz=IMGSZ,
        det_conf=CONF,
        tiny_figure_area_ratio=TINY_FIG,
    )
    print("Extraction complete")
