From a2e081542e9cc7908a27c50cad35e2f8160917ad Mon Sep 17 00:00:00 2001
From: Nik
Date: Sun, 26 Oct 2025 21:05:58 +0530
Subject: [PATCH] Improved chunking strategy.

---
 .../legal_qna_with_rag_on_bare_acts.ipynb | 102 ++++++++++++++----
 1 file changed, 83 insertions(+), 19 deletions(-)

diff --git a/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb b/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb
index 313ea6f..1162771 100644
--- a/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb
+++ b/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb
@@ -16,6 +16,11 @@
     "from openai import OpenAI\n",
     "import gradio as gr\n",
     "\n",
+    "from pathlib import Path\n",
+    "from typing import List, Tuple\n",
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "\n",
     "# ---- load env ----\n",
     "load_dotenv(override=True)\n",
     "\n",
@@ -101,21 +106,51 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def chunk_text(text: str, size: int = 900, overlap: int = 150) -> List[str]:\n",
-    "    \"\"\"Greedy fixed-size chunking with overlap (simple & fast).\"\"\"\n",
-    "    out, i, n = [], 0, len(text)\n",
-    "    while i < n:\n",
-    "        j = min(i + size, n)\n",
-    "        out.append(text[i:j])\n",
-    "        i = max(j - overlap, j)\n",
+    "\n",
+    "# MiniLM embedding model & tokenizer (BERT WordPiece)\n",
+    "EMBED_MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
+    "\n",
+    "# Use the model's practical window with 50% overlap\n",
+    "MAX_TOKENS = 256  # all-MiniLM-L6-v2 effective limit used by Sentence-Transformers\n",
+    "OVERLAP_RATIO = 0.50  # 50% sliding window overlap\n",
+    "\n",
+    "TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)\n",
+    "\n",
+    "def chunk_text(\n",
+    "    text: str,\n",
+    "    tokenizer: AutoTokenizer = TOKENIZER,\n",
+    "    max_tokens: int = MAX_TOKENS,\n",
+    "    overlap_ratio: float = OVERLAP_RATIO,\n",
+    ") -> List[str]:\n",
+    "    \"\"\"\n",
+    "    Token-aware sliding window chunking for MiniLM.\n",
+    "    - Windows of `max_tokens`\n",
+    "    - Step = max_tokens * (1 - overlap_ratio) -> 50% overlap by default\n",
+    "    \"\"\"\n",
+    "    ids = tokenizer.encode(text, add_special_tokens=False)\n",
+    "    if not ids:\n",
+    "        return []\n",
+    "\n",
+    "    step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
+    "    out: List[str] = []\n",
+    "    for start in range(0, len(ids), step):\n",
+    "        window = ids[start : start + max_tokens]\n",
+    "        if not window:\n",
+    "            break\n",
+    "        toks = tokenizer.convert_ids_to_tokens(window)\n",
+    "        chunk = tokenizer.convert_tokens_to_string(toks).strip()\n",
+    "        if chunk:\n",
+    "            out.append(chunk)\n",
+    "        if start + max_tokens >= len(ids):\n",
+    "            break\n",
     "    return out\n",
     "\n",
     "def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
-    "    \"\"\"Return list of (source_id, text). source_id is filename stem.\"\"\"\n",
+    "    \"\"\"Return list of (source_id, text). `source_id` is filename stem.\"\"\"\n",
     "    base = Path(root)\n",
     "    if not base.exists():\n",
     "        raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
-    "    pairs = []\n",
+    "    pairs: List[Tuple[str, str]] = []\n",
     "    for p in sorted(base.glob(\"*.txt\")):\n",
     "        pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
     "    if not pairs:\n",
@@ -123,7 +158,8 @@
     "    return pairs\n",
     "\n",
     "acts_raw = load_bare_acts()\n",
-    "print(\"Bare Acts loaded:\", [s for s,_ in acts_raw])\n"
+    "print(\"Bare Acts loaded:\", [s for s, _ in acts_raw])\n",
+    "print(f\"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%\")\n"
    ]
   },
   {
@@ -136,18 +172,32 @@
     "import chromadb\n",
     "from chromadb import PersistentClient\n",
     "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
+    "from transformers import AutoTokenizer\n",
+    "from typing import Dict, List, Tuple\n",
     "\n",
     "class BareActsIndex:\n",
-    "    \"\"\"Owns the vector DB lifecycle & retrieval.\"\"\"\n",
-    "    def __init__(self, db_path: str = \"vector_db\", collection: str = \"bare_acts\",\n",
-    "                 embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"):\n",
+    "    \"\"\"Owns the vector DB lifecycle & retrieval (token-aware chunking).\"\"\"\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        db_path: str = \"vector_db\",\n",
+    "        collection: str = \"bare_acts\",\n",
+    "        embed_model: str = EMBED_MODEL_NAME,\n",
+    "        max_tokens: int = MAX_TOKENS,\n",
+    "        overlap_ratio: float = OVERLAP_RATIO,\n",
+    "    ):\n",
     "        self.db_path = db_path\n",
     "        self.collection_name = collection\n",
-    "        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=embed_model)\n",
+    "        self.embed_model = embed_model\n",
+    "        self.max_tokens = max_tokens\n",
+    "        self.overlap_ratio = overlap_ratio\n",
+    "\n",
+    "        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)\n",
+    "        self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)\n",
+    "\n",
     "        self.client: PersistentClient = PersistentClient(path=db_path)\n",
     "        self.col = self.client.get_or_create_collection(\n",
     "            name=self.collection_name,\n",
-    "            embedding_function=self.embed_fn\n",
+    "            embedding_function=self.embed_fn,\n",
     "        )\n",
     "\n",
     "    def rebuild(self, docs: List[Tuple[str, str]]):\n",
@@ -156,19 +206,33 @@
     "            self.client.delete_collection(self.collection_name)\n",
     "        except Exception:\n",
     "            pass\n",
+    "\n",
     "        self.col = self.client.get_or_create_collection(\n",
     "            name=self.collection_name,\n",
-    "            embedding_function=self.embed_fn\n",
+    "            embedding_function=self.embed_fn,\n",
     "        )\n",
     "\n",
     "        ids, texts, metas = [], [], []\n",
     "        for src, text in docs:\n",
-    "            for idx, ch in enumerate(chunk_text(text)):\n",
+    "            for idx, ch in enumerate(\n",
+    "                chunk_text(\n",
+    "                    text,\n",
+    "                    tokenizer=self.tokenizer,\n",
+    "                    max_tokens=self.max_tokens,\n",
+    "                    overlap_ratio=self.overlap_ratio,\n",
+    "                )\n",
+    "            ):\n",
     "                ids.append(f\"{src}-{idx}\")\n",
     "                texts.append(ch)\n",
     "                metas.append({\"source\": src, \"chunk_id\": idx})\n",
-    "        self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
-    "        print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name}\")\n",
+    "\n",
+    "        if ids:\n",
+    "            self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
+    "\n",
+    "        print(\n",
+    "            f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} \"\n",
+    "            f\"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)\"\n",
+    "        )\n",
     "\n",
     "    def query(self, q: str, k: int = 6) -> List[Dict]:\n",
     "        res = self.col.query(query_texts=[q], n_results=k)\n",