Improved chunking strategy.
@@ -16,6 +16,11 @@
 "from openai import OpenAI\n",
 "import gradio as gr\n",
 "\n",
+"from pathlib import Path\n",
+"from typing import List, Tuple\n",
+"from transformers import AutoTokenizer\n",
+"\n",
+"\n",
 "# ---- load env ----\n",
 "load_dotenv(override=True)\n",
 "\n",
@@ -101,21 +106,51 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def chunk_text(text: str, size: int = 900, overlap: int = 150) -> List[str]:\n",
-"    \"\"\"Greedy fixed-size chunking with overlap (simple & fast).\"\"\"\n",
-"    out, i, n = [], 0, len(text)\n",
-"    while i < n:\n",
-"        j = min(i + size, n)\n",
-"        out.append(text[i:j])\n",
-"        i = max(j - overlap, j)\n",
 "\n",
+"# MiniLM embedding model & tokenizer (BERT WordPiece)\n",
+"EMBED_MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
+"\n",
+"# Use the model's practical window with 50% overlap\n",
+"MAX_TOKENS = 256      # all-MiniLM-L6-v2 effective limit used by Sentence-Transformers\n",
+"OVERLAP_RATIO = 0.50  # 50% sliding window overlap\n",
+"\n",
+"TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)\n",
+"\n",
+"def chunk_text(\n",
+"    text: str,\n",
+"    tokenizer: AutoTokenizer = TOKENIZER,\n",
+"    max_tokens: int = MAX_TOKENS,\n",
+"    overlap_ratio: float = OVERLAP_RATIO,\n",
+") -> List[str]:\n",
+"    \"\"\"\n",
+"    Token-aware sliding window chunking for MiniLM.\n",
+"    - Windows of `max_tokens`\n",
+"    - Step = max_tokens * (1 - overlap_ratio) -> 50% overlap by default\n",
+"    \"\"\"\n",
+"    ids = tokenizer.encode(text, add_special_tokens=False)\n",
+"    if not ids:\n",
+"        return []\n",
+"\n",
+"    step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
+"    out: List[str] = []\n",
+"    for start in range(0, len(ids), step):\n",
+"        window = ids[start : start + max_tokens]\n",
+"        if not window:\n",
+"            break\n",
+"        toks = tokenizer.convert_ids_to_tokens(window)\n",
+"        chunk = tokenizer.convert_tokens_to_string(toks).strip()\n",
+"        if chunk:\n",
+"            out.append(chunk)\n",
+"        if start + max_tokens >= len(ids):\n",
+"            break\n",
+"    return out\n",
 "\n",
 "def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
-"    \"\"\"Return list of (source_id, text). source_id is filename stem.\"\"\"\n",
+"    \"\"\"Return list of (source_id, text). `source_id` is filename stem.\"\"\"\n",
 "    base = Path(root)\n",
 "    if not base.exists():\n",
 "        raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
-"    pairs = []\n",
+"    pairs: List[Tuple[str, str]] = []\n",
 "    for p in sorted(base.glob(\"*.txt\")):\n",
 "        pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
 "    if not pairs:\n",
@@ -123,7 +158,8 @@
 "    return pairs\n",
 "\n",
 "acts_raw = load_bare_acts()\n",
-"print(\"Bare Acts loaded:\", [s for s,_ in acts_raw])\n"
+"print(\"Bare Acts loaded:\", [s for s, _ in acts_raw])\n",
+"print(f\"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%\")\n"
 ]
 },
 {
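
With the defaults above, MAX_TOKENS=256 and OVERLAP_RATIO=0.50 give a step of 128 tokens, so consecutive windows share half their tokens. A minimal sketch of that window arithmetic, with toy integer ids standing in for tokenizer.encode output (no model or tokenizer needed):

    # Toy stand-ins for the real values (scaled down for illustration).
    ids = list(range(10))   # pretend tokenizer.encode returned 10 token ids
    max_tokens = 4
    overlap_ratio = 0.50
    step = max(1, int(max_tokens * (1.0 - overlap_ratio)))  # -> 2

    windows = []
    for start in range(0, len(ids), step):
        window = ids[start : start + max_tokens]
        if not window:
            break
        windows.append(window)
        # Same early exit as the committed chunk_text: once a window
        # reaches the end of ids, stop instead of emitting a tail fragment.
        if start + max_tokens >= len(ids):
            break

    print(windows)
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
    # Each consecutive pair shares 2 of 4 ids, i.e. 50% overlap.
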
@@ -136,18 +172,32 @@
 "import chromadb\n",
 "from chromadb import PersistentClient\n",
 "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
+"from transformers import AutoTokenizer\n",
+"from typing import Dict, List, Tuple\n",
 "\n",
 "class BareActsIndex:\n",
-"    \"\"\"Owns the vector DB lifecycle & retrieval.\"\"\"\n",
-"    def __init__(self, db_path: str = \"vector_db\", collection: str = \"bare_acts\",\n",
-"                 embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"):\n",
+"    \"\"\"Owns the vector DB lifecycle & retrieval (token-aware chunking).\"\"\"\n",
+"    def __init__(\n",
+"        self,\n",
+"        db_path: str = \"vector_db\",\n",
+"        collection: str = \"bare_acts\",\n",
+"        embed_model: str = EMBED_MODEL_NAME,\n",
+"        max_tokens: int = MAX_TOKENS,\n",
+"        overlap_ratio: float = OVERLAP_RATIO,\n",
+"    ):\n",
 "        self.db_path = db_path\n",
 "        self.collection_name = collection\n",
-"        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=embed_model)\n",
+"        self.embed_model = embed_model\n",
+"        self.max_tokens = max_tokens\n",
+"        self.overlap_ratio = overlap_ratio\n",
+"\n",
+"        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)\n",
+"        self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)\n",
+"\n",
 "        self.client: PersistentClient = PersistentClient(path=db_path)\n",
 "        self.col = self.client.get_or_create_collection(\n",
 "            name=self.collection_name,\n",
-"            embedding_function=self.embed_fn\n",
+"            embedding_function=self.embed_fn,\n",
 "        )\n",
 "\n",
 "    def rebuild(self, docs: List[Tuple[str, str]]):\n",
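
The MAX_TOKENS=256 default is meant to match the truncation window Sentence-Transformers applies for this model. If in doubt, it can be checked directly; a sketch, assuming sentence-transformers is installed (the SentenceTransformerEmbeddingFunction above already depends on it; the model downloads on first use):

    from sentence_transformers import SentenceTransformer

    # max_seq_length is the window the library truncates inputs to at
    # encoding time; all-MiniLM-L6-v2 reports 256.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    print(model.max_seq_length)
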
@@ -156,19 +206,33 @@
 "            self.client.delete_collection(self.collection_name)\n",
 "        except Exception:\n",
 "            pass\n",
 "\n",
 "        self.col = self.client.get_or_create_collection(\n",
 "            name=self.collection_name,\n",
-"            embedding_function=self.embed_fn\n",
+"            embedding_function=self.embed_fn,\n",
 "        )\n",
 "\n",
 "        ids, texts, metas = [], [], []\n",
 "        for src, text in docs:\n",
-"            for idx, ch in enumerate(chunk_text(text)):\n",
+"            for idx, ch in enumerate(\n",
+"                chunk_text(\n",
+"                    text,\n",
+"                    tokenizer=self.tokenizer,\n",
+"                    max_tokens=self.max_tokens,\n",
+"                    overlap_ratio=self.overlap_ratio,\n",
+"                )\n",
+"            ):\n",
 "                ids.append(f\"{src}-{idx}\")\n",
 "                texts.append(ch)\n",
 "                metas.append({\"source\": src, \"chunk_id\": idx})\n",
-"        self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
-"        print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name}\")\n",
+"\n",
+"        if ids:\n",
+"            self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
+"\n",
+"        print(\n",
+"            f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} \"\n",
+"            f\"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)\"\n",
+"        )\n",
+"\n",
 "    def query(self, q: str, k: int = 6) -> List[Dict]:\n",
 "        res = self.col.query(query_texts=[q], n_results=k)\n",
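
End to end, a rebuild-and-query pass with the refactored class would look roughly like the sketch below. It assumes the notebook cells above have already run (so BareActsIndex and load_bare_acts are in scope); the query string is a made-up placeholder, and the exact dict layout query returns depends on the rest of the method, which this hunk cuts off:

    # Build (or rebuild) the persistent index from the bare-acts folder.
    index = BareActsIndex(db_path="vector_db", collection="bare_acts")
    index.rebuild(load_bare_acts())

    # Retrieve the k most similar chunks for a question.
    hits = index.query("Which court has jurisdiction over this offence?", k=6)
    print(f"retrieved {len(hits)} chunks")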