Improved chunking strategy.

commit a2e081542e (parent d3318988a9)
Author: Nik
Date:   2025-10-26 21:05:58 +05:30


@@ -16,6 +16,11 @@
"from openai import OpenAI\n", "from openai import OpenAI\n",
"import gradio as gr\n", "import gradio as gr\n",
"\n", "\n",
"from pathlib import Path\n",
"from typing import List, Tuple\n",
"from transformers import AutoTokenizer\n",
"\n",
"\n",
"# ---- load env ----\n", "# ---- load env ----\n",
"load_dotenv(override=True)\n", "load_dotenv(override=True)\n",
"\n", "\n",
@@ -101,21 +106,51 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def chunk_text(text: str, size: int = 900, overlap: int = 150) -> List[str]:\n", "\n",
" \"\"\"Greedy fixed-size chunking with overlap (simple & fast).\"\"\"\n", "# MiniLM embedding model & tokenizer (BERT WordPiece)\n",
" out, i, n = [], 0, len(text)\n", "EMBED_MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
" while i < n:\n", "\n",
" j = min(i + size, n)\n", "# Use the model's practical window with 50% overlap\n",
" out.append(text[i:j])\n", "MAX_TOKENS = 256 # all-MiniLM-L6-v2 effective limit used by Sentence-Transformers\n",
" i = max(j - overlap, j)\n", "OVERLAP_RATIO = 0.50 # 50% sliding window overlap\n",
"\n",
"TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)\n",
"\n",
"def chunk_text(\n",
" text: str,\n",
" tokenizer: AutoTokenizer = TOKENIZER,\n",
" max_tokens: int = MAX_TOKENS,\n",
" overlap_ratio: float = OVERLAP_RATIO,\n",
") -> List[str]:\n",
" \"\"\"\n",
" Token-aware sliding window chunking for MiniLM.\n",
" - Windows of `max_tokens`\n",
" - Step = max_tokens * (1 - overlap_ratio) -> 50% overlap by default\n",
" \"\"\"\n",
" ids = tokenizer.encode(text, add_special_tokens=False)\n",
" if not ids:\n",
" return []\n",
"\n",
" step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
" out: List[str] = []\n",
" for start in range(0, len(ids), step):\n",
" window = ids[start : start + max_tokens]\n",
" if not window:\n",
" break\n",
" toks = tokenizer.convert_ids_to_tokens(window)\n",
" chunk = tokenizer.convert_tokens_to_string(toks).strip()\n",
" if chunk:\n",
" out.append(chunk)\n",
" if start + max_tokens >= len(ids):\n",
" break\n",
" return out\n", " return out\n",
"\n", "\n",
"def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n", "def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
" \"\"\"Return list of (source_id, text). source_id is filename stem.\"\"\"\n", " \"\"\"Return list of (source_id, text). `source_id` is filename stem.\"\"\"\n",
" base = Path(root)\n", " base = Path(root)\n",
" if not base.exists():\n", " if not base.exists():\n",
" raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n", " raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
" pairs = []\n", " pairs: List[Tuple[str, str]] = []\n",
" for p in sorted(base.glob(\"*.txt\")):\n", " for p in sorted(base.glob(\"*.txt\")):\n",
" pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n", " pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
" if not pairs:\n", " if not pairs:\n",
@@ -123,7 +158,8 @@
" return pairs\n", " return pairs\n",
"\n", "\n",
"acts_raw = load_bare_acts()\n", "acts_raw = load_bare_acts()\n",
"print(\"Bare Acts loaded:\", [s for s,_ in acts_raw])\n" "print(\"Bare Acts loaded:\", [s for s, _ in acts_raw])\n",
"print(f\"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%\")\n"
] ]
}, },
{ {
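For reference, a stand-alone sketch of the window arithmetic the new `chunk_text` uses (the toy text below is illustrative, not part of the commit): with 256-token windows and a 50% overlap ratio, the start index advances by 128 tokens and the loop stops once a window reaches the end of the token sequence.

```python
# Rough sketch of the sliding-window stride used by chunk_text above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
ids = tok.encode("section 12 of the act says something " * 100, add_special_tokens=False)

max_tokens, step = 256, int(256 * (1 - 0.50))  # 256-token windows, 128-token step
starts = []
for start in range(0, len(ids), step):
    starts.append(start)
    if start + max_tokens >= len(ids):  # same stop condition as chunk_text
        break

print(f"{len(ids)} tokens -> {len(starts)} chunks, window starts: {starts[:5]}...")
```

Neighbouring chunks therefore share half their tokens, so a sentence that straddles a window boundary remains retrievable from at least one chunk.
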
@@ -136,18 +172,32 @@
"import chromadb\n", "import chromadb\n",
"from chromadb import PersistentClient\n", "from chromadb import PersistentClient\n",
"from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n", "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
"from transformers import AutoTokenizer\n",
"from typing import Dict, List, Tuple\n",
"\n", "\n",
"class BareActsIndex:\n", "class BareActsIndex:\n",
" \"\"\"Owns the vector DB lifecycle & retrieval.\"\"\"\n", " \"\"\"Owns the vector DB lifecycle & retrieval (token-aware chunking).\"\"\"\n",
" def __init__(self, db_path: str = \"vector_db\", collection: str = \"bare_acts\",\n", " def __init__(\n",
" embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"):\n", " self,\n",
" db_path: str = \"vector_db\",\n",
" collection: str = \"bare_acts\",\n",
" embed_model: str = EMBED_MODEL_NAME,\n",
" max_tokens: int = MAX_TOKENS,\n",
" overlap_ratio: float = OVERLAP_RATIO,\n",
" ):\n",
" self.db_path = db_path\n", " self.db_path = db_path\n",
" self.collection_name = collection\n", " self.collection_name = collection\n",
" self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=embed_model)\n", " self.embed_model = embed_model\n",
" self.max_tokens = max_tokens\n",
" self.overlap_ratio = overlap_ratio\n",
"\n",
" self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)\n",
" self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)\n",
"\n",
" self.client: PersistentClient = PersistentClient(path=db_path)\n", " self.client: PersistentClient = PersistentClient(path=db_path)\n",
" self.col = self.client.get_or_create_collection(\n", " self.col = self.client.get_or_create_collection(\n",
" name=self.collection_name,\n", " name=self.collection_name,\n",
" embedding_function=self.embed_fn\n", " embedding_function=self.embed_fn,\n",
" )\n", " )\n",
"\n", "\n",
" def rebuild(self, docs: List[Tuple[str, str]]):\n", " def rebuild(self, docs: List[Tuple[str, str]]):\n",
@@ -156,19 +206,33 @@
" self.client.delete_collection(self.collection_name)\n", " self.client.delete_collection(self.collection_name)\n",
" except Exception:\n", " except Exception:\n",
" pass\n", " pass\n",
"\n",
" self.col = self.client.get_or_create_collection(\n", " self.col = self.client.get_or_create_collection(\n",
" name=self.collection_name,\n", " name=self.collection_name,\n",
" embedding_function=self.embed_fn\n", " embedding_function=self.embed_fn,\n",
" )\n", " )\n",
"\n", "\n",
" ids, texts, metas = [], [], []\n", " ids, texts, metas = [], [], []\n",
" for src, text in docs:\n", " for src, text in docs:\n",
" for idx, ch in enumerate(chunk_text(text)):\n", " for idx, ch in enumerate(\n",
" chunk_text(\n",
" text,\n",
" tokenizer=self.tokenizer,\n",
" max_tokens=self.max_tokens,\n",
" overlap_ratio=self.overlap_ratio,\n",
" )\n",
" ):\n",
" ids.append(f\"{src}-{idx}\")\n", " ids.append(f\"{src}-{idx}\")\n",
" texts.append(ch)\n", " texts.append(ch)\n",
" metas.append({\"source\": src, \"chunk_id\": idx})\n", " metas.append({\"source\": src, \"chunk_id\": idx})\n",
"\n",
" if ids:\n",
" self.col.add(ids=ids, documents=texts, metadatas=metas)\n", " self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
" print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name}\")\n", "\n",
" print(\n",
" f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} \"\n",
" f\"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)\"\n",
" )\n",
"\n", "\n",
" def query(self, q: str, k: int = 6) -> List[Dict]:\n", " def query(self, q: str, k: int = 6) -> List[Dict]:\n",
" res = self.col.query(query_texts=[q], n_results=k)\n", " res = self.col.query(query_texts=[q], n_results=k)\n",