Agentic RAG with Query and Context Expansion

Nik
2025-10-27 17:51:15 +05:30
parent 8faff0283b
commit d198b55038
4 changed files with 29365 additions and 0 deletions


@@ -0,0 +1,940 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "d27544d4",
"metadata": {},
"outputs": [],
"source": [
"# Week 8 • Legal QnA RAG (Agentic)\n",
"\n",
"import os\n",
"from dataclasses import dataclass, field\n",
"from typing import List, Dict, Any, Tuple, Protocol\n",
"\n",
"from pathlib import Path\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import chromadb\n",
"from chromadb import PersistentClient\n",
"from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
"\n",
"from transformers import AutoTokenizer\n",
"import gradio as gr\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efe4e4db",
"metadata": {},
"outputs": [],
"source": [
"@dataclass(frozen=True)\n",
"class Settings:\n",
" # Data\n",
" data_root: str = \"knowledge_base/bare_acts\" # same as W5\n",
" \n",
" # Vector store\n",
" db_path: str = \"vector_db_w8\"\n",
" collection: str = \"bare_acts\"\n",
" embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
" chunk_max_tokens: int = 256\n",
" chunk_overlap: float = 0.50\n",
"\n",
" # Retrieval\n",
" expansions: int = 5\n",
" topk_per_rewrite: int = 10\n",
" neighbor_radius: int = 2\n",
" max_blocks_for_llm: int = 20 # cap merged blocks for token control\n",
"\n",
" # LLM (default + selectable)\n",
" gen_model: str = \"gpt-4o-mini\" # default\n",
" # You can list any OpenAI-compatible IDs you have access to\n",
" selectable_models: tuple = (\"gpt-4o-mini\",)\n",
"\n",
" temperature: float = 0.2\n",
" max_tokens: int = 220\n",
"\n",
"SET = Settings()\n"
]
},
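{
"cell_type": "code",
"execution_count": null,
"id": "demo-chunk-stride",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (illustrative, not part of the pipeline): with 50% overlap,\n",
"# consecutive chunks advance by max_tokens * (1 - overlap) tokens, the same\n",
"# stride formula chunk_text_with_spans uses further below.\n",
"step = max(1, int(SET.chunk_max_tokens * (1.0 - SET.chunk_overlap)))\n",
"print(f\"chunk size={SET.chunk_max_tokens} tokens, stride={step} tokens\")  # expect 256 / 128"
]
},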
{
"cell_type": "code",
"execution_count": null,
"id": "0e37dda9",
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Trace:\n",
" rewrites: List[str] = field(default_factory=list)\n",
" retrievals: List[Dict[str, Any]] = field(default_factory=list) # [{query, hits:[{source,chunk_id,start,end,preview}]}]\n",
" merged_blocks: List[Dict[str, Any]] = field(default_factory=list) # [{source,start,end,preview}]\n",
" notes: List[str] = field(default_factory=list)\n",
"\n",
" def as_markdown(self) -> str:\n",
" lines = []\n",
" if self.rewrites:\n",
" lines.append(\"#### 🔁 Query expansions\")\n",
" for i, q in enumerate(self.rewrites, 1):\n",
" lines.append(f\"{i}. `{q}`\")\n",
" lines.append(\"\")\n",
"\n",
" if self.retrievals:\n",
" lines.append(\"#### 🔎 Retrieval per rewrite (top-10 each)\")\n",
" for r in self.retrievals:\n",
" lines.append(f\"- **Rewrite:** `{r['query']}`\")\n",
" for h in r.get(\"hits\", []):\n",
" lines.append(\n",
" f\" - [{h['source']} #{h['chunk_id']}] {h['start']}:{h['end']} — {h['preview']}\"\n",
" )\n",
" lines.append(\"\")\n",
" \n",
" if self.merged_blocks:\n",
" lines.append(\"#### 🧩 Context expansion (± neighbors, merged)\")\n",
" for b in self.merged_blocks:\n",
" lines.append(f\"- [{b['source']}] {b['start']}:{b['end']} — {b['preview']}\")\n",
" lines.append(f\"\\n**Total merged blocks:** {len(self.merged_blocks)}\")\n",
" lines.append(\"\")\n",
" \n",
" if self.notes:\n",
" lines.append(\"#### 📝 Notes\")\n",
" for n in self.notes:\n",
" lines.append(f\"- {n}\")\n",
" lines.append(\"\")\n",
"\n",
" return \"\\n\".join(lines) if lines else \"_No logs._\"\n"
]
},
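{
"cell_type": "code",
"execution_count": null,
"id": "demo-trace",
"metadata": {},
"outputs": [],
"source": [
"# Minimal usage sketch for Trace with fabricated values (not real retrieval\n",
"# output), just to show how the per-stage fields render as markdown.\n",
"_t = Trace()\n",
"_t.rewrites = [\"murder punishment\", \"bns murder section\"]\n",
"_t.merged_blocks = [{\"source\": \"bns\", \"start\": 4720, \"end\": 4860, \"preview\": \"Punishment for murder …\"}]\n",
"_t.notes = [\"demo only\"]\n",
"print(_t.as_markdown())"
]
},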
{
"cell_type": "code",
"execution_count": null,
"id": "cf7e8b41",
"metadata": {},
"outputs": [],
"source": [
"class LLM(Protocol):\n",
" def complete(\n",
" self,\n",
" system: str,\n",
" user: str,\n",
" temperature: float,\n",
" max_tokens: int,\n",
" model: str | None = None, # <- allow per-call model override\n",
" ) -> str: ...\n",
"\n",
"class Index(Protocol):\n",
" def rebuild(self, docs: List[Tuple[str, str]]) -> None: ...\n",
" def query(self, q: str, k: int) -> List[Dict]: ...\n",
" def get_by_ids(self, ids: List[str]) -> Tuple[List[str], List[Dict]]: ...\n",
" def raw_doc(self, src: str) -> str: ...\n",
" @property\n",
" def tokenizer(self) -> AutoTokenizer: ...\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30636b66",
"metadata": {},
"outputs": [],
"source": [
"class OpenAILLM(LLM):\n",
" def __init__(self, default_model: str):\n",
" self._client = OpenAI()\n",
" self._default_model = default_model\n",
"\n",
" def complete(\n",
" self,\n",
" system: str,\n",
" user: str,\n",
" temperature: float = 0.2,\n",
" max_tokens: int = 220,\n",
" model: str | None = None,\n",
" ) -> str:\n",
" model_id = model or self._default_model\n",
" resp = self._client.chat.completions.create(\n",
" model=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": system},\n",
" {\"role\": \"user\", \"content\": user},\n",
" ],\n",
" seed=42,\n",
" )\n",
" return (resp.choices[0].message.content or \"\").strip()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86c83552",
"metadata": {},
"outputs": [],
"source": [
"# --- Streaming run logger (ADD) ---\n",
"from dataclasses import dataclass, field\n",
"from time import perf_counter\n",
"from html import escape\n",
"\n",
"@dataclass\n",
"class RunLogger:\n",
" lines: List[str] = field(default_factory=list)\n",
" t0: float = field(default_factory=perf_counter)\n",
"\n",
" def add(self, msg: str):\n",
" dt = perf_counter() - self.t0\n",
" self.lines.append(f\"[{dt:05.2f}s] {msg}\")\n",
"\n",
" def html(self) -> str:\n",
" # Simple monospace log panel\n",
" safe = [escape(x) for x in self.lines[-300:]] # cap last 300 lines\n",
" return \"<div style='font-family:ui-monospace,Menlo,Consolas;white-space:pre-wrap'>\" + \"<br/>\".join(safe) + \"</div>\"\n"
]
},
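{
"cell_type": "code",
"execution_count": null,
"id": "demo-runlogger",
"metadata": {},
"outputs": [],
"source": [
"# Tiny usage sketch: RunLogger timestamps each message relative to creation;\n",
"# .html() wraps the same lines in the monospace <div> shown in the Gradio panel.\n",
"_log = RunLogger()\n",
"_log.add(\"started\")\n",
"_log.add(\"retrieved 10 hits\")\n",
"print(_log.lines)"
]
},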
{
"cell_type": "code",
"execution_count": null,
"id": "545699c6",
"metadata": {},
"outputs": [],
"source": [
"def _one_line(text: str, limit: int = 120) -> str:\n",
" s = \" \".join(text.split())\n",
" return s[:limit] + \"…\" if len(s) > limit else s\n",
"\n",
"def _hit_view(h: Dict) -> Dict:\n",
" m = h[\"meta\"]\n",
" return {\n",
" \"source\": m[\"source\"],\n",
" \"chunk_id\": int(m[\"chunk_id\"]),\n",
" \"start\": int(m[\"start_tok\"]),\n",
" \"end\": int(m[\"end_tok\"]),\n",
" \"preview\": _one_line(h[\"text\"]),\n",
" }\n",
"\n",
"def _hit_id(h: Dict) -> str:\n",
" m = h[\"meta\"]\n",
" # Example: [bns 4720:4860 (#38)]\n",
" return f\"[{m['source']} {int(m['start_tok'])}:{int(m['end_tok'])} (#{int(m['chunk_id'])})]\"\n"
]
},
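{
"cell_type": "code",
"execution_count": null,
"id": "demo-hit-helpers",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the preview helpers on a hand-made hit dict; the field\n",
"# names mirror the metadata the indexer writes below.\n",
"_fake_hit = {\n",
" \"text\": \"Whoever commits murder shall be punished with death or imprisonment for life.\",\n",
" \"meta\": {\"source\": \"bns\", \"chunk_id\": 38, \"start_tok\": 4720, \"end_tok\": 4860},\n",
"}\n",
"print(_hit_view(_fake_hit))\n",
"print(_hit_id(_fake_hit))  # -> [bns 4720:4860 (#38)]"
]
},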
{
"cell_type": "code",
"execution_count": null,
"id": "af537e05",
"metadata": {},
"outputs": [],
"source": [
"def load_bare_acts(root: str) -> List[Tuple[str, str]]:\n",
" \"\"\"Return (source_id, text). Each .txt is one bare act file.\"\"\"\n",
" base = Path(root)\n",
" if not base.exists():\n",
" raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
" pairs: List[Tuple[str, str]] = []\n",
" for p in sorted(base.glob(\"*.txt\")):\n",
" pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
" if not pairs:\n",
" raise RuntimeError(f\"No .txt files found under {base}\")\n",
" return pairs\n",
"\n",
"acts_raw = load_bare_acts(SET.data_root)\n",
"print(\"Bare Acts loaded:\", [s for s, _ in acts_raw][:5], \"…\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7eec89e4",
"metadata": {},
"outputs": [],
"source": [
"def chunk_text_with_spans(\n",
" text: str,\n",
" tokenizer: AutoTokenizer,\n",
" max_tokens: int,\n",
" overlap_ratio: float,\n",
") -> List[Tuple[int, int, int, int, str]]:\n",
" \"\"\"\n",
" Returns a list of chunks with both token and character spans:\n",
" (start_tok, end_tok, start_char, end_char, chunk_text)\n",
" \"\"\"\n",
" enc = tokenizer(\n",
" text,\n",
" add_special_tokens=False,\n",
" return_offsets_mapping=True,\n",
" )\n",
" ids = enc[\"input_ids\"]\n",
" offs = enc[\"offset_mapping\"] # list[(char_start, char_end)]\n",
" if not ids:\n",
" return []\n",
"\n",
" step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
" out = []\n",
" for start in range(0, len(ids), step):\n",
" end = min(start + max_tokens, len(ids))\n",
" if start >= end:\n",
" break\n",
" # token -> char\n",
" start_char = offs[start][0]\n",
" end_char = offs[end - 1][1]\n",
" chunk = text[start_char:end_char].strip()\n",
" if chunk:\n",
" out.append((start, end, start_char, end_char, chunk))\n",
" if end >= len(ids):\n",
" break\n",
" return out\n"
]
},
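{
"cell_type": "code",
"execution_count": null,
"id": "demo-chunking",
"metadata": {},
"outputs": [],
"source": [
"# Smoke test on a tiny synthetic passage (assumes the MiniLM tokenizer from\n",
"# Settings is downloadable); a small max_tokens makes the 50% overlap visible.\n",
"_tok = AutoTokenizer.from_pretrained(SET.embed_model)\n",
"_demo_text = \"Section 302. Punishment for murder. Whoever commits murder shall be punished with death or imprisonment for life, and shall also be liable to fine.\"\n",
"for s_tok, e_tok, s_ch, e_ch, ch in chunk_text_with_spans(_demo_text, _tok, max_tokens=16, overlap_ratio=0.5):\n",
" print(f\"tok {s_tok}:{e_tok} char {s_ch}:{e_ch} -> {ch[:40]!r}\")"
]
},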
{
"cell_type": "code",
"execution_count": null,
"id": "4862732b",
"metadata": {},
"outputs": [],
"source": [
"class ChromaBareActsIndex(Index):\n",
" def __init__(self, settings: Settings):\n",
" self.set = settings\n",
" self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.set.embed_model)\n",
" self._tokenizer = AutoTokenizer.from_pretrained(self.set.embed_model)\n",
" self._client: PersistentClient = PersistentClient(path=self.set.db_path)\n",
" self._col = self._client.get_or_create_collection(\n",
" name=self.set.collection,\n",
" embedding_function=self.embed_fn\n",
" )\n",
" self._doc_cache: Dict[str, str] = {}\n",
"\n",
" @property\n",
" def tokenizer(self) -> AutoTokenizer:\n",
" return self._tokenizer\n",
"\n",
" def rebuild(self, docs: List[Tuple[str, str]]) -> None:\n",
" self._doc_cache = {src: txt for src, txt in docs}\n",
" try:\n",
" self._client.delete_collection(self.set.collection)\n",
" except Exception:\n",
" pass\n",
" self._col = self._client.get_or_create_collection(\n",
" name=self.set.collection,\n",
" embedding_function=self.embed_fn\n",
" )\n",
"\n",
" ids, texts, metas = [], [], []\n",
" for src, text in docs:\n",
" spans = chunk_text_with_spans(\n",
" text, tokenizer=self._tokenizer,\n",
" max_tokens=self.set.chunk_max_tokens,\n",
" overlap_ratio=self.set.chunk_overlap\n",
" )\n",
" for idx, (start_tok, end_tok, start_char, end_char, ch) in enumerate(spans):\n",
" ids.append(f\"{src}-{idx}\")\n",
" texts.append(ch)\n",
" metas.append({\n",
" \"source\": src,\n",
" \"chunk_id\": idx,\n",
" \"start_tok\": int(start_tok),\n",
" \"end_tok\": int(end_tok),\n",
" \"start_char\": int(start_char),\n",
" \"end_char\": int(end_char),\n",
" })\n",
" if ids:\n",
" self._col.add(ids=ids, documents=texts, metadatas=metas)\n",
" print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.set.collection}\")\n",
"\n",
" def query(self, q: str, k: int) -> List[Dict]:\n",
" res = self._col.query(query_texts=[q], n_results=k)\n",
" docs = res.get(\"documents\", [[]])[0]\n",
" metas = res.get(\"metadatas\", [[]])[0]\n",
" ids = res.get(\"ids\", [[]])[0]\n",
" out = []\n",
" for _id, d, m in zip(ids, docs, metas):\n",
" if not m: \n",
" continue\n",
" m = dict(m)\n",
" m[\"id\"] = _id\n",
" out.append({\"text\": d, \"meta\": m})\n",
" return out\n",
"\n",
" def get_by_ids(self, ids: List[str]) -> Tuple[List[str], List[Dict]]:\n",
" \"\"\"\n",
" Return documents & metadatas in the SAME ORDER as the requested ids.\n",
" Compatible with Chroma clients that don't allow 'ids' in `include`.\n",
" \"\"\"\n",
" if not ids:\n",
" return [], []\n",
"\n",
" # 'ids' is NOT a valid item in `include`; ask only for docs+metas\n",
" got = self._col.get(ids=ids, include=[\"documents\", \"metadatas\"])\n",
"\n",
" ret_ids = (got.get(\"ids\") or [])\n",
" docs = (got.get(\"documents\") or [])\n",
" metas = (got.get(\"metadatas\") or [])\n",
"\n",
" # If the client returns 'ids', reorder using it\n",
" if ret_ids:\n",
" table = {i: (d, m) for i, d, m in zip(ret_ids, docs, metas)}\n",
" ordered_docs, ordered_metas = [], []\n",
" for want in ids:\n",
" if want in table:\n",
" d, m = table[want]\n",
" ordered_docs.append(d)\n",
" ordered_metas.append(m)\n",
" return ordered_docs, ordered_metas\n",
"\n",
" # Fallback: some clients omit 'ids' in the response; do deterministic 1-by-1\n",
" ordered_docs, ordered_metas = [], []\n",
" for _id in ids:\n",
" sub = self._col.get(ids=[_id], include=[\"documents\", \"metadatas\"])\n",
" d = (sub.get(\"documents\") or [None])[0]\n",
" m = (sub.get(\"metadatas\") or [None])[0]\n",
" if d is not None and m is not None:\n",
" ordered_docs.append(d)\n",
" ordered_metas.append(m)\n",
" return ordered_docs, ordered_metas\n",
"\n",
"\n",
"\n",
" def raw_doc(self, src: str) -> str:\n",
" return self._doc_cache[src]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0b1512b",
"metadata": {},
"outputs": [],
"source": [
"index = ChromaBareActsIndex(SET)\n",
"index.rebuild(acts_raw) # run once; safe to re-run if content changed\n"
]
},
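{
"cell_type": "code",
"execution_count": null,
"id": "demo-query-smoke",
"metadata": {},
"outputs": [],
"source": [
"# Optional retrieval smoke test (assumes the rebuild above succeeded): one raw\n",
"# query, printed with the compact hit-id helper from earlier.\n",
"for _h in index.query(\"punishment for murder\", k=3):\n",
" print(_hit_id(_h), \"->\", _one_line(_h[\"text\"], 80))"
]
},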
{
"cell_type": "code",
"execution_count": null,
"id": "1e001124",
"metadata": {},
"outputs": [],
"source": [
"import json, re\n",
"\n",
"class QueryExpander:\n",
" \"\"\"\n",
" Expands a user question into N retrieval-friendly keyphrase queries.\n",
" - Tries LLM (strict JSON array) first.\n",
" - If parsing fails or returns too few items, falls back to deterministic domain-aware rewrites.\n",
" - Exposes .last_info to let the UI log whether fallback was used.\n",
" \"\"\"\n",
" def __init__(self, llm: LLM, n: int):\n",
" self.llm = llm\n",
" self.n = n\n",
" self.last_info: Dict[str, Any] = {}\n",
"\n",
" @staticmethod\n",
" def _extract_json_array(text: str) -> List[str]:\n",
" \"\"\"\n",
" Try to safely extract a JSON array from messy model output.\n",
" \"\"\"\n",
" try:\n",
" # Fast path: direct JSON\n",
" parsed = json.loads(text)\n",
" return [x for x in parsed if isinstance(x, str)]\n",
" except Exception:\n",
" pass\n",
"\n",
" # Fallback: regex the first [...] block\n",
" m = re.search(r\"\\[(.|\\n|\\r)*\\]\", text)\n",
" if m:\n",
" try:\n",
" parsed = json.loads(m.group(0))\n",
" return [x for x in parsed if isinstance(x, str)]\n",
" except Exception:\n",
" return []\n",
" return []\n",
"\n",
" @staticmethod\n",
" def _dedupe_keep_order(items: List[str]) -> List[str]:\n",
" seen = set()\n",
" out = []\n",
" for it in items:\n",
" k = it.strip().lower()\n",
" if k and k not in seen:\n",
" seen.add(k)\n",
" out.append(it.strip())\n",
" return out\n",
"\n",
" def _deterministic_fallback(self, question: str) -> List[str]:\n",
" \"\"\"\n",
" Domain-aware, retrieval-ready keyphrases for Bare Acts.\n",
" Keeps them short (good for vector/BM25), includes act names & synonyms.\n",
" \"\"\"\n",
" q = re.sub(r\"[?]+$\", \"\", question).strip()\n",
" # Heuristic tokens\n",
" base = q.lower()\n",
" # Try to infer a core noun/verb pair for variety\n",
" variants = [\n",
" f\"{q} section\",\n",
" f\"{q} provision bare act\",\n",
" f\"{q} indian penal code\",\n",
" f\"{q} bharatiya nyaya sanhita\",\n",
" f\"{q} punishment section key words\",\n",
" ]\n",
" # Add generic legal synonyms if helpful\n",
" synonyms = [\n",
" \"murder punishment section\",\n",
" \"culpable homicide punishment\",\n",
" \"offence of murder penalty\",\n",
" \"ipc section for murder\",\n",
" \"bns murder punishment\"\n",
" ]\n",
" pool = self._dedupe_keep_order(variants + synonyms)\n",
" return pool[: self.n] if len(pool) >= self.n else (pool + [q])[: self.n]\n",
"\n",
" def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n",
" sys = (\n",
" \"You generate EXACTLY the requested number of short, retrieval-friendly queries \"\n",
" \"for Indian Bare Acts (IPC/BNS/Constitution). Keep them concise (420 words), \"\n",
" \"keyphrase-style (no punctuation, no quotes), and diversify wording and act names. \"\n",
" \"Do NOT add commentary or mention any bare act or section number. Respond ONLY as a JSON array of strings.\\n\\n\"\n",
" \"Good examples:\\n\"\n",
" '[\"murder punishment\", \"punishment for murder\", '\n",
" '\"attempt to murder\", \"culpable homicide amounting to murder\", \"provision murder penalty\"]'\n",
" )\n",
" user = f\"Question:\\n{question}\\n\\nReturn {self.n} diverse keyphrase queries as a JSON array.\"\n",
"\n",
" raw = self.llm.complete(system=sys, user=user, temperature=0.2, max_tokens=300, model=model_override)\n",
" queries = self._extract_json_array(raw)\n",
" queries = self._dedupe_keep_order(queries)\n",
"\n",
" used_fallback = False\n",
" if len(queries) < self.n:\n",
" used_fallback = True\n",
" queries = self._deterministic_fallback(question)\n",
"\n",
" # final safety: trim and cap\n",
" queries = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in queries]\n",
" queries = [q for q in queries if q]\n",
" queries = queries[: self.n]\n",
"\n",
" self.last_info = {\n",
" \"used_fallback\": used_fallback,\n",
" \"raw\": raw,\n",
" \"final\": queries,\n",
" }\n",
" return queries\n"
]
},
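{
"cell_type": "code",
"execution_count": null,
"id": "demo-json-salvage",
"metadata": {},
"outputs": [],
"source": [
"# Offline check of the JSON-salvage path (no API call): even with chatter around\n",
"# the array, _extract_json_array should recover the list of strings.\n",
"_messy = 'Sure, here you go:\\n[\"murder punishment\", \"ipc murder section\"] hope that helps'\n",
"print(QueryExpander._extract_json_array(_messy))"
]
},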
{
"cell_type": "code",
"execution_count": null,
"id": "5c611e42",
"metadata": {},
"outputs": [],
"source": [
"class Retriever:\n",
" def __init__(self, index: Index, k: int):\n",
" self.index = index\n",
" self.k = k\n",
"\n",
" def topk(self, query: str) -> List[Dict]:\n",
" return self.index.query(query, k=self.k)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de155734",
"metadata": {},
"outputs": [],
"source": [
"class ContextExpander:\n",
" def __init__(\n",
" self,\n",
" index: Index,\n",
" radius: int,\n",
" max_blocks: int,\n",
" pad_words: int = 100,\n",
" words_to_tokens: float = 1.4,\n",
" ):\n",
" \"\"\"\n",
" Expand context by:\n",
" 1) Adding ±neighbor_radius chunks around every hit (by chunk_id)\n",
" 2) Converting those chunk spans to padded token ranges\n",
" 3) Merging overlapping token ranges per source\n",
" \"\"\"\n",
" self.index = index\n",
" self.radius = max(0, int(radius))\n",
" self.max_blocks = int(max_blocks)\n",
" # approx tokens to pad on both sides of each span\n",
" self.pad_tokens = int(pad_words * words_to_tokens)\n",
" # for logging/inspection\n",
" self.last_ids: List[str] = []\n",
"\n",
" @staticmethod\n",
" def _merge_spans(spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]:\n",
" if not spans:\n",
" return []\n",
" spans = sorted(spans, key=lambda x: x[0])\n",
" merged = [spans[0]]\n",
" for s, e in spans[1:]:\n",
" ls, le = merged[-1]\n",
" if s <= le:\n",
" merged[-1] = (ls, max(le, e))\n",
" else:\n",
" merged.append((s, e))\n",
" return merged\n",
"\n",
" @staticmethod\n",
" def _dedupe_keep_order(items: List[str]) -> List[str]:\n",
" seen = set()\n",
" out: List[str] = []\n",
" for it in items:\n",
" if it not in seen:\n",
" seen.add(it)\n",
" out.append(it)\n",
" return out\n",
"\n",
" def _neighbor_ids_for_hit(self, meta: Dict[str, Any]) -> List[str]:\n",
" \"\"\"\n",
" Build vector-store IDs for the hit's chunk and its ±radius neighbors.\n",
" We rely on the indexing convention: id == f\"{source}-{chunk_id}\".\n",
" \"\"\"\n",
" src = meta[\"source\"]\n",
" cid = int(meta[\"chunk_id\"])\n",
" ids: List[str] = []\n",
" for d in range(-self.radius, self.radius + 1):\n",
" n = cid + d\n",
" if n < 0:\n",
" continue\n",
" ids.append(f\"{src}-{n}\")\n",
" return ids\n",
"\n",
" def expand_and_merge(self, hits: List[Dict]) -> List[Dict]:\n",
" # 1) Collect all candidate ids (hits + ±neighbors), dedupe, keep order\n",
" all_ids: List[str] = []\n",
" for h in hits:\n",
" m = h[\"meta\"]\n",
" hit_id = m.get(\"id\") or f\"{m['source']}-{int(m['chunk_id'])}\"\n",
" all_ids.append(hit_id)\n",
" all_ids.extend(self._neighbor_ids_for_hit(m))\n",
"\n",
" all_ids = self._dedupe_keep_order(all_ids)\n",
" self.last_ids = all_ids[:] # capture for logging/inspection\n",
"\n",
" # 2) Fetch metas for those ids (order-preserving as much as possible)\n",
" _, metas = self.index.get_by_ids(all_ids)\n",
"\n",
" # 3) Build padded token spans per source from metas\n",
" spans_by_src: Dict[str, List[Tuple[int, int]]] = {}\n",
" for m in metas:\n",
" if not m:\n",
" continue\n",
" src = m[\"source\"]\n",
" s_tok = int(m[\"start_tok\"])\n",
" e_tok = int(m[\"end_tok\"])\n",
" ps = max(0, s_tok - self.pad_tokens)\n",
" pe = e_tok + self.pad_tokens\n",
" spans_by_src.setdefault(src, []).append((ps, pe))\n",
"\n",
" # 4) Merge, slice, and construct blocks (capped by max_blocks)\n",
" blocks: List[Dict] = []\n",
" for src, spans in spans_by_src.items():\n",
" merged_tok_spans = self._merge_spans(spans)\n",
"\n",
" full_doc = self.index.raw_doc(src)\n",
" # Tokenize & get offsets once per source\n",
" enc = self.index.tokenizer(\n",
" full_doc, add_special_tokens=False, return_offsets_mapping=True\n",
" )\n",
" ids = enc[\"input_ids\"]\n",
" offs = enc[\"offset_mapping\"]\n",
"\n",
" for s_tok, e_tok in merged_tok_spans:\n",
" if not ids:\n",
" continue\n",
" s_tok = max(0, min(s_tok, len(ids)))\n",
" e_tok = max(0, min(e_tok, len(ids)))\n",
" if e_tok <= s_tok:\n",
" continue\n",
"\n",
" # token span -> char span\n",
" s_char = offs[s_tok][0]\n",
" e_char = offs[e_tok - 1][1]\n",
" text = full_doc[s_char:e_char]\n",
"\n",
" blocks.append(\n",
" {\n",
" \"source\": src,\n",
" \"start\": int(s_tok),\n",
" \"end\": int(e_tok),\n",
" \"text\": text,\n",
" }\n",
" )\n",
"\n",
" blocks.sort(key=lambda b: (b[\"source\"], b[\"start\"]))\n",
" return blocks[: self.max_blocks]\n"
]
},
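{
"cell_type": "code",
"execution_count": null,
"id": "demo-merge-spans",
"metadata": {},
"outputs": [],
"source": [
"# Offline check of span merging (a pure static method, no index needed):\n",
"# overlapping padded token ranges collapse into one block, disjoint ones stay separate.\n",
"print(ContextExpander._merge_spans([(0, 120), (100, 260), (300, 380)]))\n",
"# expected: [(0, 260), (300, 380)]"
]
},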
{
"cell_type": "code",
"execution_count": null,
"id": "ae074d00",
"metadata": {},
"outputs": [],
"source": [
"class PromptBuilder:\n",
" SYSTEM = (\n",
" \"You are a precise legal assistant for Indian Bare Acts.\\n\"\n",
" )\n",
"\n",
" @staticmethod\n",
" def build_user(question: str, blocks: List[Dict]) -> str:\n",
" ctx = \"\\n\\n---\\n\\n\".join(\n",
" f\"[{b['source']} {b['start']}:{b['end']}]\\n{b['text']}\" for b in blocks\n",
" )\n",
" return (\n",
" f\"Question:\\n{question}\\n\\n\"\n",
" f\"Context (use only this):\\n{ctx}\\n\\n\"\n",
" \"Instructions:\\n\"\n",
" \"- Describe the context whether it has information to answer the question or not or whether any part can be used to answer the question.\\n\"\n",
" \"- Answer the original question in new paragraph with label Answer:\\n\"\n",
" \"- Quote or paraphrase ONLY from the context above.\\n\"\n",
" \"- Inline-cite with [source start:end] when using any snippet.\\n\"\n",
" \"- If the answer is not in context, please describe what the agentic rag found related to original question and try to tell answer based on the data.\"\n",
" )\n",
"\n",
"\n",
"\n"
]
},
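{
"cell_type": "code",
"execution_count": null,
"id": "demo-prompt-preview",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative prompt preview with one fabricated context block, showing the\n",
"# [source start:end] header that the inline-citation instruction refers to.\n",
"_blocks = [{\"source\": \"bns\", \"start\": 4700, \"end\": 4900, \"text\": \"Whoever commits murder shall be punished with death or imprisonment for life.\"}]\n",
"print(PromptBuilder.build_user(\"What is the punishment for murder?\", _blocks)[:400])"
]
},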
{
"cell_type": "code",
"execution_count": null,
"id": "fc8ca7b2",
"metadata": {},
"outputs": [],
"source": [
"class Answerer:\n",
" def __init__(self, llm: LLM, set: Settings):\n",
" self.llm = llm\n",
" self.set = set\n",
"\n",
" def answer(self, question: str, blocks: List[Dict], model: str | None = None) -> str:\n",
" user = PromptBuilder.build_user(question, blocks)\n",
" return self.llm.complete(\n",
" system=PromptBuilder.SYSTEM,\n",
" user=user,\n",
" temperature=self.set.temperature,\n",
" max_tokens=self.set.max_tokens,\n",
" model=model or self.set.gen_model,\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a33d665",
"metadata": {},
"outputs": [],
"source": [
"class LegalAgent:\n",
" def __init__(self, expander: QueryExpander, retriever: Retriever, ctx_expander: ContextExpander, answerer: Answerer, set: Settings):\n",
" self.expander = expander\n",
" self.retriever = retriever\n",
" self.ctx_expander = ctx_expander\n",
" self.answerer = answerer\n",
" self.set = set\n",
"\n",
" def run(self, question: str, model: str | None = None) -> str:\n",
" rewrites = self.expander.expand(question, model_override=model)\n",
" hits: List[Dict] = []\n",
" for q in rewrites:\n",
" hits.extend(self.retriever.topk(q))\n",
" blocks = self.ctx_expander.expand_and_merge(hits)\n",
" return self.answerer.answer(question, blocks, model=model)\n",
"\n",
" def run_stream(self, question: str, model: str | None = None):\n",
" \"\"\"\n",
" Generator: yields tuples (answer_or_none, logs_html) multiple times.\n",
" \"\"\"\n",
" log = RunLogger()\n",
" log.add(f\"Question: {question}\")\n",
" yield None, log.html()\n",
"\n",
" # 1) Expand queries\n",
" rewrites = self.expander.expand(question, model_override=model)\n",
" if getattr(self.expander, \"last_info\", {}).get(\"used_fallback\", False):\n",
" log.add(\"Query expansion: LLM output unparsable → using deterministic fallback.\")\n",
" log.add(f\"Expanded into {len(rewrites)} queries:\")\n",
" for i, q in enumerate(rewrites, 1):\n",
" log.add(f\" {i}. {q}\")\n",
" yield None, log.html()\n",
"\n",
" # 2) Retrieve top-k per rewrite\n",
" all_hits: List[Dict] = []\n",
" for i, q in enumerate(rewrites, 1):\n",
" hits = self.retriever.topk(q)\n",
" all_hits.extend(hits)\n",
" top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n",
" log.add(f\"Retrieval {i}/{len(rewrites)}: got {len(hits)} hits → {top3}\")\n",
" yield None, log.html()\n",
"\n",
" # 3) Context expansion / merging (with neighbor ids logged)\n",
" blocks = self.ctx_expander.expand_and_merge(all_hits)\n",
" used = self.ctx_expander.last_ids\n",
" peek = \", \".join(used[:8]) + (\" …\" if len(used) > 8 else \"\")\n",
" log.add(f\"Neighbor addition: collected {len(used)} chunk-ids → {peek}\")\n",
"\n",
" approx_words = int(self.ctx_expander.pad_tokens / 1.4) # inverse of words_to_tokens≈1.4\n",
" log.add(\n",
" f\"Context expansion: merged {len(blocks)} block(s) \"\n",
" f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words} words).\"\n",
" )\n",
" for b in blocks:\n",
" log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
" log.add(b[\"text\"])\n",
" yield None, log.html()\n",
"\n",
" # 4) LLM answer\n",
" log.add(f\"Asking LLM: {model or self.set.gen_model}\")\n",
" yield None, log.html()\n",
" answer = self.answerer.answer(question, blocks, model=model)\n",
" log.add(\"Answer ready.\")\n",
" yield answer, log.html()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e57320",
"metadata": {},
"outputs": [],
"source": [
"def make_agent(gen_model: str) -> LegalAgent:\n",
" llm = OpenAILLM(gen_model)\n",
" expander = QueryExpander(llm=llm, n=SET.expansions)\n",
" retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n",
" ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n",
" answerer = Answerer(llm=llm, set=SET)\n",
" return LegalAgent(expander, retriever, ctx_expander, answerer, SET)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "837df704",
"metadata": {},
"outputs": [],
"source": [
"agent = make_agent(SET.gen_model)\n",
"print(\"Agent ready (global).\")\n"
]
},
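{
"cell_type": "code",
"execution_count": null,
"id": "demo-run-stream",
"metadata": {},
"outputs": [],
"source": [
"# Headless sketch of the same streaming generator the UI consumes. Guarded so\n",
"# the notebook still runs without an OpenAI key; when enabled it costs one\n",
"# expansion call plus one answer call.\n",
"if os.getenv(\"OPENAI_API_KEY\"):\n",
" _final = None\n",
" for _ans, _logs_html in agent.run_stream(\"What is the punishment for murder?\"):\n",
" _final = _ans  # None until the final yield\n",
" print(_final)"
]
},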
{
"cell_type": "code",
"execution_count": null,
"id": "d1665aba",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"\n",
"MODEL_CHOICES = [\n",
" \"gpt-4o-mini\", # default\n",
" \"gpt-4o\",\n",
" \"gpt-5\",\n",
" \"gpt-5-nano\",\n",
"]\n",
"\n",
"DEFAULT_Q = \"is death by ignorance like giving medicine to old person amounts to a murder ?\"\n",
"\n",
"def chat_stream(chat_history: List[Tuple[str, str]], message: str, model_choice: str, topk: int, radius: int, max_blocks: int):\n",
" # Always create a fresh agent per request (avoids NameError & state bleed)\n",
" _agent = make_agent(model_choice)\n",
" _agent.retriever.k = int(topk)\n",
" _agent.ctx_expander.radius = int(radius)\n",
" _agent.ctx_expander.max_blocks = int(max_blocks)\n",
"\n",
" # Initialize history and add a placeholder response\n",
" if chat_history is None:\n",
" chat_history = []\n",
" chat_history = chat_history + [(message, \"🛠️ running pipeline…\")]\n",
" yield chat_history, \"<i>starting…</i>\"\n",
"\n",
" # Stream logs + final answer\n",
" for ans, logs_html in _agent.run_stream(message, model=model_choice):\n",
" assistant_text = ans if ans is not None else \"⏳ working…\"\n",
" chat_history[-1] = (message, assistant_text)\n",
" # logs_html already contains inline styles; just pass it through\n",
" yield chat_history, logs_html\n",
"\n",
"with gr.Blocks(title=\"Week 8 • Legal QnA RAG (Agentic)\") as app:\n",
" gr.Markdown(\"### 🧑‍⚖️ Legal Q&A on Bare Acts — Agentic RAG (Week 8)\")\n",
" gr.Markdown(\"Flow: **query expansion → multi-retrieval → context expansion (±neighbors, merged) → LLM answer**\")\n",
"\n",
" with gr.Row():\n",
" # Left: classic chat\n",
" with gr.Column(scale=3):\n",
" chatbot = gr.Chatbot(label=\"Chat\", height=420)\n",
" with gr.Row():\n",
" msg = gr.Textbox(value=DEFAULT_Q, label=\"Ask a legal question\", scale=5)\n",
" send = gr.Button(\"Send\", variant=\"primary\", scale=1)\n",
"\n",
" # Right: logs panel (simple header + HTML)\n",
" with gr.Column(scale=2):\n",
" gr.Markdown(\"#### Agent Logs\")\n",
" logs_html = gr.HTML(value=\"<div style='font-family:ui-monospace,Menlo,Consolas;white-space:pre-wrap;border:1px solid #ddd;border-radius:8px;padding:10px;height:380px;overflow:auto;background:#fafafa'>Idle</div>\")\n",
"\n",
" with gr.Accordion(\"Advanced\", open=False):\n",
" with gr.Row():\n",
" model_dd = gr.Dropdown(choices=MODEL_CHOICES, value=SET.gen_model, label=\"LLM Model\", scale=2)\n",
" topk = gr.Slider(2, 20, value=SET.topk_per_rewrite, step=1, label=\"Top-K per rewrite\", scale=2)\n",
" radius = gr.Slider(0, 4, value=SET.neighbor_radius, step=1, label=\"Neighbor radius (±)\", scale=2)\n",
" cap = gr.Slider(4, 60, value=SET.max_blocks_for_llm, step=1, label=\"Max merged blocks for LLM\", scale=2)\n",
"\n",
" # Wire streaming generator to both Send click and Enter submit\n",
" send.click(\n",
" fn=chat_stream,\n",
" inputs=[chatbot, msg, model_dd, topk, radius, cap],\n",
" outputs=[chatbot, logs_html],\n",
" show_progress=False,\n",
" ).then(lambda: \"\", None, msg) # clear textbox\n",
"\n",
" msg.submit(\n",
" fn=chat_stream,\n",
" inputs=[chatbot, msg, model_dd, topk, radius, cap],\n",
" outputs=[chatbot, logs_html],\n",
" show_progress=False,\n",
" ).then(lambda: \"\", None, msg) # clear textbox\n",
"\n",
"app.launch(inbrowser=True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llm-engineering",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}