Improved chunking strategy.

Nik
2025-10-26 21:05:58 +05:30
parent d3318988a9
commit a2e081542e


@@ -16,6 +16,11 @@
"from openai import OpenAI\n",
"import gradio as gr\n",
"\n",
"from pathlib import Path\n",
"from typing import List, Tuple\n",
"from transformers import AutoTokenizer\n",
"\n",
"\n",
"# ---- load env ----\n",
"load_dotenv(override=True)\n",
"\n",
@@ -101,21 +106,51 @@
"metadata": {},
"outputs": [],
"source": [
"def chunk_text(text: str, size: int = 900, overlap: int = 150) -> List[str]:\n",
" \"\"\"Greedy fixed-size chunking with overlap (simple & fast).\"\"\"\n",
" out, i, n = [], 0, len(text)\n",
" while i < n:\n",
" j = min(i + size, n)\n",
" out.append(text[i:j])\n",
" i = max(j - overlap, j)\n",
"\n",
"# MiniLM embedding model & tokenizer (BERT WordPiece)\n",
"EMBED_MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
"\n",
"# Use the model's practical window with 50% overlap\n",
"MAX_TOKENS = 256 # all-MiniLM-L6-v2 effective limit used by Sentence-Transformers\n",
"OVERLAP_RATIO = 0.50 # 50% sliding window overlap\n",
"\n",
"TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)\n",
"\n",
"def chunk_text(\n",
" text: str,\n",
" tokenizer: AutoTokenizer = TOKENIZER,\n",
" max_tokens: int = MAX_TOKENS,\n",
" overlap_ratio: float = OVERLAP_RATIO,\n",
") -> List[str]:\n",
" \"\"\"\n",
" Token-aware sliding window chunking for MiniLM.\n",
" - Windows of `max_tokens`\n",
" - Step = max_tokens * (1 - overlap_ratio) -> 50% overlap by default\n",
" \"\"\"\n",
" ids = tokenizer.encode(text, add_special_tokens=False)\n",
" if not ids:\n",
" return []\n",
"\n",
" step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
" out: List[str] = []\n",
" for start in range(0, len(ids), step):\n",
" window = ids[start : start + max_tokens]\n",
" if not window:\n",
" break\n",
" toks = tokenizer.convert_ids_to_tokens(window)\n",
" chunk = tokenizer.convert_tokens_to_string(toks).strip()\n",
" if chunk:\n",
" out.append(chunk)\n",
" if start + max_tokens >= len(ids):\n",
" break\n",
" return out\n",
"\n",
"def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
" \"\"\"Return list of (source_id, text). source_id is filename stem.\"\"\"\n",
" \"\"\"Return list of (source_id, text). `source_id` is filename stem.\"\"\"\n",
" base = Path(root)\n",
" if not base.exists():\n",
" raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
" pairs = []\n",
" pairs: List[Tuple[str, str]] = []\n",
" for p in sorted(base.glob(\"*.txt\")):\n",
" pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
" if not pairs:\n",
@@ -123,7 +158,8 @@
" return pairs\n",
"\n",
"acts_raw = load_bare_acts()\n",
"print(\"Bare Acts loaded:\", [s for s,_ in acts_raw])\n"
"print(\"Bare Acts loaded:\", [s for s, _ in acts_raw])\n",
"print(f\"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%\")\n"
]
},
{
@@ -136,18 +172,32 @@
"import chromadb\n",
"from chromadb import PersistentClient\n",
"from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
"from transformers import AutoTokenizer\n",
"from typing import Dict, List, Tuple\n",
"\n",
"class BareActsIndex:\n",
" \"\"\"Owns the vector DB lifecycle & retrieval.\"\"\"\n",
" def __init__(self, db_path: str = \"vector_db\", collection: str = \"bare_acts\",\n",
" embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"):\n",
" \"\"\"Owns the vector DB lifecycle & retrieval (token-aware chunking).\"\"\"\n",
" def __init__(\n",
" self,\n",
" db_path: str = \"vector_db\",\n",
" collection: str = \"bare_acts\",\n",
" embed_model: str = EMBED_MODEL_NAME,\n",
" max_tokens: int = MAX_TOKENS,\n",
" overlap_ratio: float = OVERLAP_RATIO,\n",
" ):\n",
" self.db_path = db_path\n",
" self.collection_name = collection\n",
" self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=embed_model)\n",
" self.embed_model = embed_model\n",
" self.max_tokens = max_tokens\n",
" self.overlap_ratio = overlap_ratio\n",
"\n",
" self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)\n",
" self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)\n",
"\n",
" self.client: PersistentClient = PersistentClient(path=db_path)\n",
" self.col = self.client.get_or_create_collection(\n",
" name=self.collection_name,\n",
" embedding_function=self.embed_fn\n",
" embedding_function=self.embed_fn,\n",
" )\n",
"\n",
" def rebuild(self, docs: List[Tuple[str, str]]):\n",
@@ -156,19 +206,33 @@
" self.client.delete_collection(self.collection_name)\n",
" except Exception:\n",
" pass\n",
"\n",
" self.col = self.client.get_or_create_collection(\n",
" name=self.collection_name,\n",
" embedding_function=self.embed_fn\n",
" embedding_function=self.embed_fn,\n",
" )\n",
"\n",
" ids, texts, metas = [], [], []\n",
" for src, text in docs:\n",
" for idx, ch in enumerate(chunk_text(text)):\n",
" for idx, ch in enumerate(\n",
" chunk_text(\n",
" text,\n",
" tokenizer=self.tokenizer,\n",
" max_tokens=self.max_tokens,\n",
" overlap_ratio=self.overlap_ratio,\n",
" )\n",
" ):\n",
" ids.append(f\"{src}-{idx}\")\n",
" texts.append(ch)\n",
" metas.append({\"source\": src, \"chunk_id\": idx})\n",
" self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
" print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name}\")\n",
"\n",
" if ids:\n",
" self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
"\n",
" print(\n",
" f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} \"\n",
" f\"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)\"\n",
" )\n",
"\n",
" def query(self, q: str, k: int = 6) -> List[Dict]:\n",
" res = self.col.query(query_texts=[q], n_results=k)\n",