# Created by aaronkueh on 10/17/2025
# aom/tools/rag/rag_helper.py
from __future__ import annotations

import threading
from pathlib import Path
from typing import Any, Dict, Iterable, List, Set, Union

import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer, models
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from aom.definitions import CONFIG_DIR
from aom.utils.utilities import config_param


ModelPath = Union[str, Path]
# Process-wide tokenizer cache, guarded by a re-entrant lock for thread-safe access.
_TOKENIZER_CACHE: dict[str, PreTrainedTokenizerBase] = {}
_TOK_LOCK = threading.RLock()


def load_rag_profile(config: str) -> dict:
    """Load the 'profiles.eqp_manuals' mapping from a YAML config file under CONFIG_DIR."""
    cfg_path = Path(CONFIG_DIR) / config
    profiles = config_param(config_file=str(cfg_path), field="profiles") or {}
    prof = profiles.get("eqp_manuals", {}) if isinstance(profiles, dict) else {}
    if not isinstance(prof, dict):
        raise ValueError(f"Invalid 'profiles.eqp_manuals' structure in {config}.")
    return prof
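
# Usage sketch: the filename matches the config referenced above; the profile
# keys shown (e.g. "model_dir") are illustrative, not guaranteed by this module.
#
#     profile = load_rag_profile("rag_config.yaml")
#     model_dir = profile.get("model_dir")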


def attach_helpers(st: SentenceTransformer) -> None:
    """
    Ensure the model exposes:
      - encode_query(texts, **kw)
      - encode_document(texts, **kw)
    Both normalize embeddings (appropriate for COSINE distance in Qdrant).
    If the installed sentence-transformers already provides these methods,
    this is a no-op; otherwise the class is monkey-patched once.
    """
    if not hasattr(st, "encode_query"):
        SentenceTransformer.encode_query = lambda self, x, **kw: self.encode(
            x, normalize_embeddings=True, **kw
        )
    if not hasattr(st, "encode_document"):
        SentenceTransformer.encode_document = lambda self, x, **kw: self.encode(
            x, normalize_embeddings=True, **kw
        )
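
# Usage sketch: after attaching the helpers, queries and documents encode with
# normalized embeddings (the model path and texts below are hypothetical):
#
#     st = SentenceTransformer("/models/embedder")
#     attach_helpers(st)
#     q_vec = st.encode_query("how do I reset the loadport?")
#     d_vecs = st.encode_document(["Section 4.2: loadport reset procedure ..."])
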

def load_local_st(model_dir: str, device: str, max_seq_len: int) -> SentenceTransformer:
    """
    Load a SentenceTransformer model from a local directory. If direct loading fails,
    fall back to assembling the model from an explicit Transformer module plus a
    mean-pooling layer.

    The function also attaches the encode_query/encode_document helpers and sets
    the model's maximum sequence length.

    :param model_dir: Directory path where the model files are stored.
    :type model_dir: str
    :param device: Computation device to use (e.g., "cpu" or "cuda").
    :type device: str
    :param max_seq_len: Maximum sequence length to set on the model.
    :type max_seq_len: int
    :return: Loaded and configured SentenceTransformer model.
    :rtype: SentenceTransformer
    """
    try:
        st = SentenceTransformer(model_dir, device=device)
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading SentenceTransformer model from {model_dir}: {e}")
        # Fallback: build the model manually from a Transformer backbone plus a
        # mean-pooling head, loading local files only.
        word = models.Transformer(
            model_dir,
            model_args={"torch_dtype": torch.float32, "local_files_only": True, "low_cpu_mem_usage": True},
        )
        pool = models.Pooling(word.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)
        st = SentenceTransformer(modules=[word, pool], device=device)
    attach_helpers(st)
    st.max_seq_length = max_seq_len
    return st
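
# Usage sketch: the directory is hypothetical; device and max_seq_len would
# normally come from the loaded RAG profile.
#
#     st = load_local_st("/models/embedder", device="cpu", max_seq_len=512)
#     vecs = st.encode_document(["chunk one", "chunk two"])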


def get_tokenizer(model_dir: ModelPath) -> PreTrainedTokenizerBase:
    """Return a cached fast tokenizer for a local model directory or a HF model id."""
    # Normalize the cache key: resolve genuine local paths, keep Hugging Face model
    # ids (e.g. "google/embeddinggemma-300m") as-is. A bare "/" substring test would
    # misclassify HF ids as paths, so check for existence on disk instead.
    if isinstance(model_dir, Path) or Path(model_dir).exists():
        key = str(Path(model_dir).resolve())
    else:
        key = str(model_dir)

    if key in _TOKENIZER_CACHE:
        return _TOKENIZER_CACHE[key]

    with _TOK_LOCK:
        # Double-checked locking: another thread may have populated the cache
        # while we waited for the lock.
        if key in _TOKENIZER_CACHE:
            return _TOKENIZER_CACHE[key]
        try:
            # Prefer local files; fall back to allowing a hub download.
            tok = AutoTokenizer.from_pretrained(key, use_fast=True, local_files_only=True)
        except Exception:
            tok = AutoTokenizer.from_pretrained(key, use_fast=True)
        _TOKENIZER_CACHE[key] = tok
        return tok
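
# Usage sketch: repeated lookups return the same cached tokenizer instance
# (model id reused from the example in the comment above):
#
#     tok = get_tokenizer("google/embeddinggemma-300m")
#     assert tok is get_tokenizer("google/embeddinggemma-300m")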

def count_tokens(text: str, model_dir: ModelPath) -> int:
    """Count the tokens in a single string, excluding special tokens."""
    tok = get_tokenizer(model_dir)
    return len(tok(text, add_special_tokens=False)["input_ids"])


def count_tokens_batch(texts: Iterable[str], model_dir: ModelPath) -> List[int]:
    """Count tokens for each string in `texts`, excluding special tokens."""
    tok = get_tokenizer(model_dir)
    enc = tok(list(texts), add_special_tokens=False, padding=False, truncation=False)
    return [len(ids) for ids in enc["input_ids"]]
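
# Usage sketch: useful for enforcing a token budget on chunks before embedding
# (the texts and the 512-token limit are hypothetical):
#
#     lengths = count_tokens_batch(["first chunk", "second chunk"], "/models/embedder")
#     oversized = [i for i, n in enumerate(lengths) if n > 512]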


def list_doc_ids(client: QdrantClient, collection: str, batch: int = 1000) -> List[str]:
    """Scroll the collection and return a sorted list of distinct payload['doc_id'] values."""
    doc_ids: Set[str] = set()
    next_page = None
    # Paginate with Qdrant's scroll API: `next_page` is the offset for the next
    # page and becomes None once the collection is exhausted.
    while True:
        points, next_page = client.scroll(
            collection_name=collection,
            scroll_filter=None,
            limit=batch,
            with_payload=True,
            with_vectors=False,
            offset=next_page,
        )
        if not points:
            break
        for p in points:
            payload: Dict[str, Any] = p.payload or {}
            d = payload.get("doc_id")
            if d:
                doc_ids.add(str(d))
        if not next_page:
            break
    return sorted(doc_ids)
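
# Usage sketch: the URL and collection name are hypothetical.
#
#     client = QdrantClient(url="http://localhost:6333")
#     print(list_doc_ids(client, "eqp_manuals"))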
