Agentic RAG with Query and Context Expansion
@@ -0,0 +1,940 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27544d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Week 8 • Legal QnA RAG (Agentic)\n",
    "\n",
    "import os\n",
    "from dataclasses import dataclass, field\n",
    "from typing import List, Dict, Any, Tuple, Protocol\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import chromadb\n",
    "from chromadb import PersistentClient\n",
    "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "import gradio as gr\n",
    "\n",
    "load_dotenv(override=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe4e4db",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class Settings:\n",
    "    # Data\n",
    "    data_root: str = \"knowledge_base/bare_acts\"  # same as W5\n",
    "\n",
    "    # Vector store\n",
    "    db_path: str = \"vector_db_w8\"\n",
    "    collection: str = \"bare_acts\"\n",
    "    embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
    "    chunk_max_tokens: int = 256\n",
    "    chunk_overlap: float = 0.50\n",
    "\n",
    "    # Retrieval\n",
    "    expansions: int = 5\n",
    "    topk_per_rewrite: int = 10\n",
    "    neighbor_radius: int = 2\n",
    "    max_blocks_for_llm: int = 20  # cap merged blocks for token control\n",
    "\n",
    "    # LLM (default + selectable)\n",
    "    gen_model: str = \"gpt-4o-mini\"  # default\n",
    "    # You can list any OpenAI-compatible IDs you have access to\n",
    "    selectable_models: tuple = (\"gpt-4o-mini\",)\n",
    "\n",
    "    temperature: float = 0.2\n",
    "    max_tokens: int = 220\n",
    "\n",
    "SET = Settings()\n"
   ]
  },
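  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-chunk-stride",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check (an illustrative sketch, not part of the pipeline): with\n",
    "# chunk_max_tokens=256 and chunk_overlap=0.50, consecutive chunks start 128\n",
    "# tokens apart, matching the step computed in chunk_text_with_spans below.\n",
    "step = max(1, int(SET.chunk_max_tokens * (1.0 - SET.chunk_overlap)))\n",
    "print(f\"chunk size = {SET.chunk_max_tokens} tokens, stride = {step} tokens\")\n"
   ]
  },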
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e37dda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Trace:\n",
    "    rewrites: List[str] = field(default_factory=list)\n",
    "    retrievals: List[Dict[str, Any]] = field(default_factory=list)  # [{query, hits: [{source, chunk_id, start, end, preview}]}]\n",
    "    merged_blocks: List[Dict[str, Any]] = field(default_factory=list)  # [{source, start, end, preview}]\n",
    "    notes: List[str] = field(default_factory=list)\n",
    "\n",
    "    def as_markdown(self) -> str:\n",
    "        lines = []\n",
    "        if self.rewrites:\n",
    "            lines.append(\"#### 🔁 Query expansions\")\n",
    "            for i, q in enumerate(self.rewrites, 1):\n",
    "                lines.append(f\"{i}. `{q}`\")\n",
    "            lines.append(\"\")\n",
    "\n",
    "        if self.retrievals:\n",
    "            lines.append(\"#### 🔎 Retrieval per rewrite (top-k each)\")\n",
    "            for r in self.retrievals:\n",
    "                lines.append(f\"- **Rewrite:** `{r['query']}`\")\n",
    "                for h in r.get(\"hits\", []):\n",
    "                    lines.append(\n",
    "                        f\"  - [{h['source']} #{h['chunk_id']}] {h['start']}:{h['end']} — {h['preview']}\"\n",
    "                    )\n",
    "            lines.append(\"\")\n",
    "\n",
    "        if self.merged_blocks:\n",
    "            lines.append(\"#### 🧩 Context expansion (± neighbors, merged)\")\n",
    "            for b in self.merged_blocks:\n",
    "                lines.append(f\"- [{b['source']}] {b['start']}:{b['end']} — {b['preview']}\")\n",
    "            lines.append(f\"\\n**Total merged blocks:** {len(self.merged_blocks)}\")\n",
    "            lines.append(\"\")\n",
    "\n",
    "        if self.notes:\n",
    "            lines.append(\"#### 📝 Notes\")\n",
    "            for n in self.notes:\n",
    "                lines.append(f\"- {n}\")\n",
    "            lines.append(\"\")\n",
    "\n",
    "        return \"\\n\".join(lines) if lines else \"_No logs._\"\n"
   ]
  },
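  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-trace-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at Trace rendering on fabricated entries (illustrative only; the\n",
    "# streaming UI below uses RunLogger, while Trace offers a Markdown view).\n",
    "_t = Trace(rewrites=[\"murder punishment\", \"punishment for murder\"], notes=[\"demo entry\"])\n",
    "print(_t.as_markdown())\n"
   ]
  },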
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7e8b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LLM(Protocol):\n",
    "    def complete(\n",
    "        self,\n",
    "        system: str,\n",
    "        user: str,\n",
    "        temperature: float,\n",
    "        max_tokens: int,\n",
    "        model: str | None = None,  # <- allow per-call model override\n",
    "    ) -> str: ...\n",
    "\n",
    "class Index(Protocol):\n",
    "    def rebuild(self, docs: List[Tuple[str, str]]) -> None: ...\n",
    "    def query(self, q: str, k: int) -> List[Dict]: ...\n",
    "    def get_by_ids(self, ids: List[str]) -> Tuple[List[str], List[Dict]]: ...\n",
    "    def raw_doc(self, src: str) -> str: ...\n",
    "    @property\n",
    "    def tokenizer(self) -> AutoTokenizer: ...\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30636b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "class OpenAILLM(LLM):\n",
    "    def __init__(self, default_model: str):\n",
    "        self._client = OpenAI()\n",
    "        self._default_model = default_model\n",
    "\n",
    "    def complete(\n",
    "        self,\n",
    "        system: str,\n",
    "        user: str,\n",
    "        temperature: float = 0.2,\n",
    "        max_tokens: int = 220,\n",
    "        model: str | None = None,\n",
    "    ) -> str:\n",
    "        model_id = model or self._default_model\n",
    "        resp = self._client.chat.completions.create(\n",
    "            model=model_id,\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system},\n",
    "                {\"role\": \"user\", \"content\": user},\n",
    "            ],\n",
    "            # pass the caller's sampling settings through to the API\n",
    "            temperature=temperature,\n",
    "            max_tokens=max_tokens,\n",
    "            seed=42,\n",
    "        )\n",
    "        return (resp.choices[0].message.content or \"\").strip()\n"
   ]
  },
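  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-llm-smoke",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional smoke test for OpenAILLM: a minimal sketch that assumes\n",
    "# OPENAI_API_KEY is set via .env; guarded so the notebook still runs offline.\n",
    "if os.getenv(\"OPENAI_API_KEY\"):\n",
    "    _llm = OpenAILLM(SET.gen_model)\n",
    "    print(_llm.complete(system=\"You are terse.\", user=\"Reply with OK.\", temperature=0.0, max_tokens=5))\n",
    "else:\n",
    "    print(\"OPENAI_API_KEY not set; skipping smoke test.\")\n"
   ]
  },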
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86c83552",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Streaming run logger ---\n",
    "from dataclasses import dataclass, field\n",
    "from time import perf_counter\n",
    "from html import escape\n",
    "\n",
    "@dataclass\n",
    "class RunLogger:\n",
    "    lines: List[str] = field(default_factory=list)\n",
    "    t0: float = field(default_factory=perf_counter)\n",
    "\n",
    "    def add(self, msg: str):\n",
    "        dt = perf_counter() - self.t0\n",
    "        self.lines.append(f\"[{dt:05.2f}s] {msg}\")\n",
    "\n",
    "    def html(self) -> str:\n",
    "        # Simple monospace log panel\n",
    "        safe = [escape(x) for x in self.lines[-300:]]  # cap at last 300 lines\n",
    "        return \"<div style='font-family:ui-monospace,Menlo,Consolas;white-space:pre-wrap'>\" + \"<br/>\".join(safe) + \"</div>\"\n"
   ]
  },
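  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-runlogger-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tiny illustration of RunLogger output (fabricated messages).\n",
    "_demo_log = RunLogger()\n",
    "_demo_log.add(\"expanding query\")\n",
    "_demo_log.add(\"retrieved 10 hits\")\n",
    "print(_demo_log.html())\n"
   ]
  },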
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "545699c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _one_line(text: str, limit: int = 120) -> str:\n",
    "    s = \" \".join(text.split())\n",
    "    return s[:limit] + \"…\" if len(s) > limit else s\n",
    "\n",
    "def _hit_view(h: Dict) -> Dict:\n",
    "    m = h[\"meta\"]\n",
    "    return {\n",
    "        \"source\": m[\"source\"],\n",
    "        \"chunk_id\": int(m[\"chunk_id\"]),\n",
    "        \"start\": int(m[\"start_tok\"]),\n",
    "        \"end\": int(m[\"end_tok\"]),\n",
    "        \"preview\": _one_line(h[\"text\"]),\n",
    "    }\n",
    "\n",
    "def _hit_id(h: Dict) -> str:\n",
    "    m = h[\"meta\"]\n",
    "    # Example: [bns 4720:4860 (#38)]\n",
    "    return f\"[{m['source']} {int(m['start_tok'])}:{int(m['end_tok'])} (#{int(m['chunk_id'])})]\"\n"
   ]
  },
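  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-hit-helpers",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration of the hit helpers on a hand-made hit (hypothetical metadata values).\n",
    "_fake_hit = {\n",
    "    \"text\": \"Whoever commits murder shall be punished with imprisonment for life …\",\n",
    "    \"meta\": {\"source\": \"bns\", \"chunk_id\": 38, \"start_tok\": 4720, \"end_tok\": 4860},\n",
    "}\n",
    "print(_hit_view(_fake_hit))\n",
    "print(_hit_id(_fake_hit))\n"
   ]
  },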
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af537e05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_bare_acts(root: str) -> List[Tuple[str, str]]:\n",
    "    \"\"\"Return (source_id, text). Each .txt is one bare act file.\"\"\"\n",
    "    base = Path(root)\n",
    "    if not base.exists():\n",
    "        raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
    "    pairs: List[Tuple[str, str]] = []\n",
    "    for p in sorted(base.glob(\"*.txt\")):\n",
    "        pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
    "    if not pairs:\n",
    "        raise RuntimeError(f\"No .txt files found under {base}\")\n",
    "    return pairs\n",
    "\n",
    "acts_raw = load_bare_acts(SET.data_root)\n",
    "print(\"Bare Acts loaded:\", [s for s, _ in acts_raw][:5], \"…\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eec89e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_text_with_spans(\n",
    "    text: str,\n",
    "    tokenizer: AutoTokenizer,\n",
    "    max_tokens: int,\n",
    "    overlap_ratio: float,\n",
    ") -> List[Tuple[int, int, int, int, str]]:\n",
    "    \"\"\"\n",
    "    Returns a list of chunks with both token and character spans:\n",
    "    (start_tok, end_tok, start_char, end_char, chunk_text)\n",
    "    \"\"\"\n",
    "    enc = tokenizer(\n",
    "        text,\n",
    "        add_special_tokens=False,\n",
    "        return_offsets_mapping=True,\n",
    "    )\n",
    "    ids = enc[\"input_ids\"]\n",
    "    offs = enc[\"offset_mapping\"]  # list[(char_start, char_end)]\n",
    "    if not ids:\n",
    "        return []\n",
    "\n",
    "    step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
    "    out = []\n",
    "    for start in range(0, len(ids), step):\n",
    "        end = min(start + max_tokens, len(ids))\n",
    "        if start >= end:\n",
    "            break\n",
    "        # token -> char\n",
    "        start_char = offs[start][0]\n",
    "        end_char = offs[end - 1][1]\n",
    "        chunk = text[start_char:end_char].strip()\n",
    "        if chunk:\n",
    "            out.append((start, end, start_char, end_char, chunk))\n",
    "        if end >= len(ids):\n",
    "            break\n",
    "    return out\n"
   ]
  },
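  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-chunker-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small demo of the chunker on an illustrative sentence, using the same\n",
    "# tokenizer as the index (a sketch; downloads the tokenizer on first run).\n",
    "_tok = AutoTokenizer.from_pretrained(SET.embed_model)\n",
    "_demo_text = (\n",
    "    \"Whoever commits murder shall be punished with death or imprisonment \"\n",
    "    \"for life, and shall also be liable to fine.\"\n",
    ")\n",
    "for _s, _e, _sc, _ec, _chunk in chunk_text_with_spans(_demo_text, _tok, max_tokens=12, overlap_ratio=0.5):\n",
    "    print(f\"tok {_s}:{_e} char {_sc}:{_ec} -> {_chunk!r}\")\n"
   ]
  },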
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4862732b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChromaBareActsIndex(Index):\n",
    "    def __init__(self, settings: Settings):\n",
    "        self.set = settings\n",
    "        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.set.embed_model)\n",
    "        self._tokenizer = AutoTokenizer.from_pretrained(self.set.embed_model)\n",
    "        self._client: PersistentClient = PersistentClient(path=self.set.db_path)\n",
    "        self._col = self._client.get_or_create_collection(\n",
    "            name=self.set.collection,\n",
    "            embedding_function=self.embed_fn\n",
    "        )\n",
    "        self._doc_cache: Dict[str, str] = {}\n",
    "\n",
    "    @property\n",
    "    def tokenizer(self) -> AutoTokenizer:\n",
    "        return self._tokenizer\n",
    "\n",
    "    def rebuild(self, docs: List[Tuple[str, str]]) -> None:\n",
    "        self._doc_cache = {src: txt for src, txt in docs}\n",
    "        try:\n",
    "            self._client.delete_collection(self.set.collection)\n",
    "        except Exception:\n",
    "            pass\n",
    "        self._col = self._client.get_or_create_collection(\n",
    "            name=self.set.collection,\n",
    "            embedding_function=self.embed_fn\n",
    "        )\n",
    "\n",
    "        ids, texts, metas = [], [], []\n",
    "        for src, text in docs:\n",
    "            spans = chunk_text_with_spans(\n",
    "                text, tokenizer=self._tokenizer,\n",
    "                max_tokens=self.set.chunk_max_tokens,\n",
    "                overlap_ratio=self.set.chunk_overlap\n",
    "            )\n",
    "            for idx, (start_tok, end_tok, start_char, end_char, ch) in enumerate(spans):\n",
    "                ids.append(f\"{src}-{idx}\")\n",
    "                texts.append(ch)\n",
    "                metas.append({\n",
    "                    \"source\": src,\n",
    "                    \"chunk_id\": idx,\n",
    "                    \"start_tok\": int(start_tok),\n",
    "                    \"end_tok\": int(end_tok),\n",
    "                    \"start_char\": int(start_char),\n",
    "                    \"end_char\": int(end_char),\n",
    "                })\n",
    "        if ids:\n",
    "            self._col.add(ids=ids, documents=texts, metadatas=metas)\n",
    "        print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.set.collection}\")\n",
    "\n",
    "    def query(self, q: str, k: int) -> List[Dict]:\n",
    "        res = self._col.query(query_texts=[q], n_results=k)\n",
    "        docs = res.get(\"documents\", [[]])[0]\n",
    "        metas = res.get(\"metadatas\", [[]])[0]\n",
    "        ids = res.get(\"ids\", [[]])[0]\n",
    "        out = []\n",
    "        for _id, d, m in zip(ids, docs, metas):\n",
    "            if not m:\n",
    "                continue\n",
    "            m = dict(m)\n",
    "            m[\"id\"] = _id\n",
    "            out.append({\"text\": d, \"meta\": m})\n",
    "        return out\n",
    "\n",
    "    def get_by_ids(self, ids: List[str]) -> Tuple[List[str], List[Dict]]:\n",
    "        \"\"\"\n",
    "        Return documents & metadatas in the SAME ORDER as the requested ids.\n",
    "        Compatible with Chroma clients that don't allow 'ids' in `include`.\n",
    "        \"\"\"\n",
    "        if not ids:\n",
    "            return [], []\n",
    "\n",
    "        # 'ids' is NOT a valid item in `include`; ask only for docs+metas\n",
    "        got = self._col.get(ids=ids, include=[\"documents\", \"metadatas\"])\n",
    "\n",
    "        ret_ids = (got.get(\"ids\") or [])\n",
    "        docs = (got.get(\"documents\") or [])\n",
    "        metas = (got.get(\"metadatas\") or [])\n",
    "\n",
    "        # If the client returns 'ids', reorder using it\n",
    "        if ret_ids:\n",
    "            table = {i: (d, m) for i, d, m in zip(ret_ids, docs, metas)}\n",
    "            ordered_docs, ordered_metas = [], []\n",
    "            for want in ids:\n",
    "                if want in table:\n",
    "                    d, m = table[want]\n",
    "                    ordered_docs.append(d)\n",
    "                    ordered_metas.append(m)\n",
    "            return ordered_docs, ordered_metas\n",
    "\n",
    "        # Fallback: some clients omit 'ids' in the response; do deterministic 1-by-1\n",
    "        ordered_docs, ordered_metas = [], []\n",
    "        for _id in ids:\n",
    "            sub = self._col.get(ids=[_id], include=[\"documents\", \"metadatas\"])\n",
    "            d = (sub.get(\"documents\") or [None])[0]\n",
    "            m = (sub.get(\"metadatas\") or [None])[0]\n",
    "            if d is not None and m is not None:\n",
    "                ordered_docs.append(d)\n",
    "                ordered_metas.append(m)\n",
    "        return ordered_docs, ordered_metas\n",
    "\n",
    "    def raw_doc(self, src: str) -> str:\n",
    "        return self._doc_cache[src]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0b1512b",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = ChromaBareActsIndex(SET)\n",
    "index.rebuild(acts_raw)  # run once; safe to re-run if content changed\n"
   ]
  },
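  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-retrieval-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick retrieval sanity check; illustrative, since exact hits depend on your corpus.\n",
    "for h in index.query(\"punishment for murder\", k=3):\n",
    "    print(_hit_id(h), \"→\", _one_line(h[\"text\"], 80))\n"
   ]
  },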
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e001124",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, re\n",
    "\n",
    "class QueryExpander:\n",
    "    \"\"\"\n",
    "    Expands a user question into N retrieval-friendly keyphrase queries.\n",
    "    - Tries the LLM first (strict JSON array output).\n",
    "    - If parsing fails or returns too few items, falls back to deterministic domain-aware rewrites.\n",
    "    - Exposes .last_info so the UI can log whether the fallback was used.\n",
    "    \"\"\"\n",
    "    def __init__(self, llm: LLM, n: int):\n",
    "        self.llm = llm\n",
    "        self.n = n\n",
    "        self.last_info: Dict[str, Any] = {}\n",
    "\n",
    "    @staticmethod\n",
    "    def _extract_json_array(text: str) -> List[str]:\n",
    "        \"\"\"Safely extract a JSON array of strings from messy model output.\"\"\"\n",
    "        try:\n",
    "            # Fast path: the whole output is valid JSON\n",
    "            parsed = json.loads(text)\n",
    "            return [x for x in parsed if isinstance(x, str)]\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "        # Fallback: regex the first [...] block\n",
    "        m = re.search(r\"\\[(.|\\n|\\r)*\\]\", text)\n",
    "        if m:\n",
    "            try:\n",
    "                parsed = json.loads(m.group(0))\n",
    "                return [x for x in parsed if isinstance(x, str)]\n",
    "            except Exception:\n",
    "                return []\n",
    "        return []\n",
    "\n",
    "    @staticmethod\n",
    "    def _dedupe_keep_order(items: List[str]) -> List[str]:\n",
    "        seen = set()\n",
    "        out = []\n",
    "        for it in items:\n",
    "            k = it.strip().lower()\n",
    "            if k and k not in seen:\n",
    "                seen.add(k)\n",
    "                out.append(it.strip())\n",
    "        return out\n",
    "\n",
    "    def _deterministic_fallback(self, question: str) -> List[str]:\n",
    "        \"\"\"\n",
    "        Domain-aware, retrieval-ready keyphrases for Bare Acts.\n",
    "        Keeps them short (good for vector/BM25) and mixes in act names & synonyms.\n",
    "        \"\"\"\n",
    "        q = re.sub(r\"[?]+$\", \"\", question).strip()\n",
    "        # Generic variants: the question plus act-name qualifiers\n",
    "        variants = [\n",
    "            f\"{q} section\",\n",
    "            f\"{q} provision bare act\",\n",
    "            f\"{q} indian penal code\",\n",
    "            f\"{q} bharatiya nyaya sanhita\",\n",
    "            f\"{q} punishment section key words\",\n",
    "        ]\n",
    "        # Hand-picked murder-related synonyms as an example pool for this domain\n",
    "        synonyms = [\n",
    "            \"murder punishment section\",\n",
    "            \"culpable homicide punishment\",\n",
    "            \"offence of murder penalty\",\n",
    "            \"ipc section for murder\",\n",
    "            \"bns murder punishment\"\n",
    "        ]\n",
    "        pool = self._dedupe_keep_order(variants + synonyms)\n",
    "        return pool[: self.n] if len(pool) >= self.n else (pool + [q])[: self.n]\n",
    "\n",
    "    def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n",
    "        sys = (\n",
    "            \"You generate EXACTLY the requested number of short, retrieval-friendly queries \"\n",
    "            \"for Indian Bare Acts (IPC/BNS/Constitution). Keep them concise (4–20 words), \"\n",
    "            \"keyphrase-style (no punctuation, no quotes), and diversify wording and act names. \"\n",
    "            \"Do NOT add commentary or mention any bare act or section number. Respond ONLY as a JSON array of strings.\\n\\n\"\n",
    "            \"Good examples:\\n\"\n",
    "            '[\"murder punishment\", \"punishment for murder\", '\n",
    "            '\"attempt to murder\", \"culpable homicide amounting to murder\", \"provision murder penalty\"]'\n",
    "        )\n",
    "        user = f\"Question:\\n{question}\\n\\nReturn {self.n} diverse keyphrase queries as a JSON array.\"\n",
    "\n",
    "        raw = self.llm.complete(system=sys, user=user, temperature=0.2, max_tokens=300, model=model_override)\n",
    "        queries = self._extract_json_array(raw)\n",
    "        queries = self._dedupe_keep_order(queries)\n",
    "\n",
    "        used_fallback = False\n",
    "        if len(queries) < self.n:\n",
    "            used_fallback = True\n",
    "            queries = self._deterministic_fallback(question)\n",
    "\n",
    "        # final safety: trim and cap\n",
    "        queries = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in queries]\n",
    "        queries = [q for q in queries if q]\n",
    "        queries = queries[: self.n]\n",
    "\n",
    "        self.last_info = {\n",
    "            \"used_fallback\": used_fallback,\n",
    "            \"raw\": raw,\n",
    "            \"final\": queries,\n",
    "        }\n",
    "        return queries\n"
   ]
  },
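  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-json-salvage",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration of the JSON-salvage path on made-up model outputs.\n",
    "print(QueryExpander._extract_json_array('[\"murder punishment\", \"ipc murder section\"]'))\n",
    "print(QueryExpander._extract_json_array('Sure, here you go:\\n[\"attempt to murder\"]'))\n",
    "print(QueryExpander._extract_json_array(\"no json array here\"))  # -> []\n"
   ]
  },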
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c611e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Retriever:\n",
    "    def __init__(self, index: Index, k: int):\n",
    "        self.index = index\n",
    "        self.k = k\n",
    "\n",
    "    def topk(self, query: str) -> List[Dict]:\n",
    "        return self.index.query(query, k=self.k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de155734",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ContextExpander:\n",
    "    def __init__(\n",
    "        self,\n",
    "        index: Index,\n",
    "        radius: int,\n",
    "        max_blocks: int,\n",
    "        pad_words: int = 100,\n",
    "        words_to_tokens: float = 1.4,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Expand context by:\n",
    "        1) Adding ±neighbor_radius chunks around every hit (by chunk_id)\n",
    "        2) Converting those chunk spans to padded token ranges\n",
    "        3) Merging overlapping token ranges per source\n",
    "        \"\"\"\n",
    "        self.index = index\n",
    "        self.radius = max(0, int(radius))\n",
    "        self.max_blocks = int(max_blocks)\n",
    "        # approx tokens to pad on both sides of each span\n",
    "        self.pad_tokens = int(pad_words * words_to_tokens)\n",
    "        # for logging/inspection\n",
    "        self.last_ids: List[str] = []\n",
    "\n",
    "    @staticmethod\n",
    "    def _merge_spans(spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]:\n",
    "        if not spans:\n",
    "            return []\n",
    "        spans = sorted(spans, key=lambda x: x[0])\n",
    "        merged = [spans[0]]\n",
    "        for s, e in spans[1:]:\n",
    "            ls, le = merged[-1]\n",
    "            if s <= le:\n",
    "                merged[-1] = (ls, max(le, e))\n",
    "            else:\n",
    "                merged.append((s, e))\n",
    "        return merged\n",
    "\n",
    "    @staticmethod\n",
    "    def _dedupe_keep_order(items: List[str]) -> List[str]:\n",
    "        seen = set()\n",
    "        out: List[str] = []\n",
    "        for it in items:\n",
    "            if it not in seen:\n",
    "                seen.add(it)\n",
    "                out.append(it)\n",
    "        return out\n",
    "\n",
    "    def _neighbor_ids_for_hit(self, meta: Dict[str, Any]) -> List[str]:\n",
    "        \"\"\"\n",
    "        Build vector-store IDs for the hit's chunk and its ±radius neighbors.\n",
    "        We rely on the indexing convention: id == f\"{source}-{chunk_id}\".\n",
    "        \"\"\"\n",
    "        src = meta[\"source\"]\n",
    "        cid = int(meta[\"chunk_id\"])\n",
    "        ids: List[str] = []\n",
    "        for d in range(-self.radius, self.radius + 1):\n",
    "            n = cid + d\n",
    "            if n < 0:\n",
    "                continue\n",
    "            ids.append(f\"{src}-{n}\")\n",
    "        return ids\n",
    "\n",
    "    def expand_and_merge(self, hits: List[Dict]) -> List[Dict]:\n",
    "        # 1) Collect all candidate ids (hits + ±neighbors), dedupe, keep order\n",
    "        all_ids: List[str] = []\n",
    "        for h in hits:\n",
    "            m = h[\"meta\"]\n",
    "            hit_id = m.get(\"id\") or f\"{m['source']}-{int(m['chunk_id'])}\"\n",
    "            all_ids.append(hit_id)\n",
    "            all_ids.extend(self._neighbor_ids_for_hit(m))\n",
    "\n",
    "        all_ids = self._dedupe_keep_order(all_ids)\n",
    "        self.last_ids = all_ids[:]  # capture for logging/inspection\n",
    "\n",
    "        # 2) Fetch metas for those ids (order-preserving as much as possible)\n",
    "        _, metas = self.index.get_by_ids(all_ids)\n",
    "\n",
    "        # 3) Build padded token spans per source from metas\n",
    "        spans_by_src: Dict[str, List[Tuple[int, int]]] = {}\n",
    "        for m in metas:\n",
    "            if not m:\n",
    "                continue\n",
    "            src = m[\"source\"]\n",
    "            s_tok = int(m[\"start_tok\"])\n",
    "            e_tok = int(m[\"end_tok\"])\n",
    "            ps = max(0, s_tok - self.pad_tokens)\n",
    "            pe = e_tok + self.pad_tokens\n",
    "            spans_by_src.setdefault(src, []).append((ps, pe))\n",
    "\n",
    "        # 4) Merge, slice, and construct blocks (capped by max_blocks)\n",
    "        blocks: List[Dict] = []\n",
    "        for src, spans in spans_by_src.items():\n",
    "            merged_tok_spans = self._merge_spans(spans)\n",
    "\n",
    "            full_doc = self.index.raw_doc(src)\n",
    "            # Tokenize & get offsets once per source\n",
    "            enc = self.index.tokenizer(\n",
    "                full_doc, add_special_tokens=False, return_offsets_mapping=True\n",
    "            )\n",
    "            ids = enc[\"input_ids\"]\n",
    "            offs = enc[\"offset_mapping\"]\n",
    "\n",
    "            for s_tok, e_tok in merged_tok_spans:\n",
    "                if not ids:\n",
    "                    continue\n",
    "                s_tok = max(0, min(s_tok, len(ids)))\n",
    "                e_tok = max(0, min(e_tok, len(ids)))\n",
    "                if e_tok <= s_tok:\n",
    "                    continue\n",
    "\n",
    "                # token span -> char span\n",
    "                s_char = offs[s_tok][0]\n",
    "                e_char = offs[e_tok - 1][1]\n",
    "                text = full_doc[s_char:e_char]\n",
    "\n",
    "                blocks.append(\n",
    "                    {\n",
    "                        \"source\": src,\n",
    "                        \"start\": int(s_tok),\n",
    "                        \"end\": int(e_tok),\n",
    "                        \"text\": text,\n",
    "                    }\n",
    "                )\n",
    "\n",
    "        blocks.sort(key=lambda b: (b[\"source\"], b[\"start\"]))\n",
    "        return blocks[: self.max_blocks]\n"
   ]
  },
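  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-merge-spans",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration of span merging: overlapping padded token ranges collapse into\n",
    "# a single block; disjoint ones stay separate.\n",
    "print(ContextExpander._merge_spans([(100, 300), (250, 500), (900, 1000)]))\n",
    "# -> [(100, 500), (900, 1000)]\n"
   ]
  },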
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae074d00",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PromptBuilder:\n",
    "    SYSTEM = (\n",
    "        \"You are a precise legal assistant for Indian Bare Acts.\\n\"\n",
    "    )\n",
    "\n",
    "    @staticmethod\n",
    "    def build_user(question: str, blocks: List[Dict]) -> str:\n",
    "        ctx = \"\\n\\n---\\n\\n\".join(\n",
    "            f\"[{b['source']} {b['start']}:{b['end']}]\\n{b['text']}\" for b in blocks\n",
    "        )\n",
    "        return (\n",
    "            f\"Question:\\n{question}\\n\\n\"\n",
    "            f\"Context (use only this):\\n{ctx}\\n\\n\"\n",
    "            \"Instructions:\\n\"\n",
    "            \"- First, state whether the context answers the question fully, partially, or not at all.\\n\"\n",
    "            \"- Then answer the original question in a new paragraph labelled 'Answer:'.\\n\"\n",
    "            \"- Quote or paraphrase ONLY from the context above.\\n\"\n",
    "            \"- Inline-cite with [source start:end] when using any snippet.\\n\"\n",
    "            \"- If the answer is not in the context, describe what was retrieved that relates to the question and give the best answer supported by that data.\"\n",
    "        )\n"
   ]
  },
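  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-prompt-preview",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview of the assembled user prompt on a hand-made block (hypothetical span values).\n",
    "print(PromptBuilder.build_user(\n",
    "    \"What is the punishment for murder?\",\n",
    "    [{\"source\": \"bns\", \"start\": 4720, \"end\": 4860, \"text\": \"Whoever commits murder shall be punished …\"}],\n",
    "))\n"
   ]
  },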
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc8ca7b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Answerer:\n",
    "    def __init__(self, llm: LLM, set: Settings):\n",
    "        self.llm = llm\n",
    "        self.set = set\n",
    "\n",
    "    def answer(self, question: str, blocks: List[Dict], model: str | None = None) -> str:\n",
    "        user = PromptBuilder.build_user(question, blocks)\n",
    "        return self.llm.complete(\n",
    "            system=PromptBuilder.SYSTEM,\n",
    "            user=user,\n",
    "            temperature=self.set.temperature,\n",
    "            max_tokens=self.set.max_tokens,\n",
    "            model=model or self.set.gen_model,\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a33d665",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LegalAgent:\n",
    "    def __init__(self, expander: QueryExpander, retriever: Retriever, ctx_expander: ContextExpander, answerer: Answerer, set: Settings):\n",
    "        self.expander = expander\n",
    "        self.retriever = retriever\n",
    "        self.ctx_expander = ctx_expander\n",
    "        self.answerer = answerer\n",
    "        self.set = set\n",
    "\n",
    "    def run(self, question: str, model: str | None = None) -> str:\n",
    "        rewrites = self.expander.expand(question, model_override=model)\n",
    "        hits: List[Dict] = []\n",
    "        for q in rewrites:\n",
    "            hits.extend(self.retriever.topk(q))\n",
    "        blocks = self.ctx_expander.expand_and_merge(hits)\n",
    "        return self.answerer.answer(question, blocks, model=model)\n",
    "\n",
    "    def run_stream(self, question: str, model: str | None = None):\n",
    "        \"\"\"\n",
    "        Generator: yields tuples (answer_or_none, logs_html) multiple times.\n",
    "        \"\"\"\n",
    "        log = RunLogger()\n",
    "        log.add(f\"Question: {question}\")\n",
    "        yield None, log.html()\n",
    "\n",
    "        # 1) Expand queries\n",
    "        rewrites = self.expander.expand(question, model_override=model)\n",
    "        if getattr(self.expander, \"last_info\", {}).get(\"used_fallback\", False):\n",
    "            log.add(\"Query expansion: LLM output unparsable → using deterministic fallback.\")\n",
    "        log.add(f\"Expanded into {len(rewrites)} queries:\")\n",
    "        for i, q in enumerate(rewrites, 1):\n",
    "            log.add(f\"  {i}. {q}\")\n",
    "        yield None, log.html()\n",
    "\n",
    "        # 2) Retrieve top-k per rewrite\n",
    "        all_hits: List[Dict] = []\n",
    "        for i, q in enumerate(rewrites, 1):\n",
    "            hits = self.retriever.topk(q)\n",
    "            all_hits.extend(hits)\n",
    "            top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n",
    "            log.add(f\"Retrieval {i}/{len(rewrites)}: got {len(hits)} hits → {top3}\")\n",
    "            yield None, log.html()\n",
    "\n",
    "        # 3) Context expansion / merging (with neighbor ids logged)\n",
    "        blocks = self.ctx_expander.expand_and_merge(all_hits)\n",
    "        used = self.ctx_expander.last_ids\n",
    "        peek = \", \".join(used[:8]) + (\" …\" if len(used) > 8 else \"\")\n",
    "        log.add(f\"Neighbor addition: collected {len(used)} chunk-ids → {peek}\")\n",
    "\n",
    "        approx_words = int(self.ctx_expander.pad_tokens / 1.4)  # inverse of words_to_tokens ≈ 1.4\n",
    "        log.add(\n",
    "            f\"Context expansion: merged {len(blocks)} block(s) \"\n",
    "            f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words} words).\"\n",
    "        )\n",
    "        for b in blocks:\n",
    "            log.add(f\"  [{b['source']} {b['start']}:{b['end']}]\")\n",
    "            log.add(b[\"text\"])\n",
    "        yield None, log.html()\n",
    "\n",
    "        # 4) LLM answer\n",
    "        log.add(f\"Asking LLM: {model or self.set.gen_model}\")\n",
    "        yield None, log.html()\n",
    "        answer = self.answerer.answer(question, blocks, model=model)\n",
    "        log.add(\"Answer ready.\")\n",
    "        yield answer, log.html()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e57320",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_agent(gen_model: str) -> LegalAgent:\n",
    "    llm = OpenAILLM(gen_model)\n",
    "    expander = QueryExpander(llm=llm, n=SET.expansions)\n",
    "    retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n",
    "    ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n",
    "    answerer = Answerer(llm=llm, set=SET)\n",
    "    return LegalAgent(expander, retriever, ctx_expander, answerer, SET)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "837df704",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = make_agent(SET.gen_model)\n",
    "print(\"Agent ready (global).\")\n"
   ]
  },
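  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-headless-run",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Headless use of the full pipeline (no Gradio); guarded, since it calls the OpenAI API.\n",
    "if os.getenv(\"OPENAI_API_KEY\"):\n",
    "    print(agent.run(\"What is the punishment for murder?\"))\n"
   ]
  },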
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1665aba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "\n",
    "MODEL_CHOICES = [\n",
    "    \"gpt-4o-mini\",  # default\n",
    "    \"gpt-4o\",\n",
    "    \"gpt-5\",\n",
    "    \"gpt-5-nano\",\n",
    "]\n",
    "\n",
    "DEFAULT_Q = \"Does death caused by ignorance, such as giving medicine to an old person, amount to murder?\"\n",
    "\n",
    "def chat_stream(chat_history: List[Tuple[str, str]], message: str, model_choice: str, topk: int, radius: int, max_blocks: int):\n",
    "    # Always create a fresh agent per request (avoids NameError & state bleed)\n",
    "    _agent = make_agent(model_choice)\n",
    "    _agent.retriever.k = int(topk)\n",
    "    _agent.ctx_expander.radius = int(radius)\n",
    "    _agent.ctx_expander.max_blocks = int(max_blocks)\n",
    "\n",
    "    # Initialize history and add a placeholder response\n",
    "    if chat_history is None:\n",
    "        chat_history = []\n",
    "    chat_history = chat_history + [(message, \"🛠️ running pipeline…\")]\n",
    "    yield chat_history, \"<i>starting…</i>\"\n",
    "\n",
    "    # Stream logs + final answer\n",
    "    for ans, logs_html in _agent.run_stream(message, model=model_choice):\n",
    "        assistant_text = ans if ans is not None else \"⏳ working…\"\n",
    "        chat_history[-1] = (message, assistant_text)\n",
    "        # logs_html already contains inline styles; just pass it through\n",
    "        yield chat_history, logs_html\n",
    "\n",
    "with gr.Blocks(title=\"Week 8 • Legal QnA RAG (Agentic)\") as app:\n",
    "    gr.Markdown(\"### 🧑‍⚖️ Legal Q&A on Bare Acts — Agentic RAG (Week 8)\")\n",
    "    gr.Markdown(\"Flow: **query expansion → multi-retrieval → context expansion (±neighbors, merged) → LLM answer**\")\n",
    "\n",
    "    with gr.Row():\n",
    "        # Left: classic chat\n",
    "        with gr.Column(scale=3):\n",
    "            chatbot = gr.Chatbot(label=\"Chat\", height=420)\n",
    "            with gr.Row():\n",
    "                msg = gr.Textbox(value=DEFAULT_Q, label=\"Ask a legal question\", scale=5)\n",
    "                send = gr.Button(\"Send\", variant=\"primary\", scale=1)\n",
    "\n",
    "        # Right: logs panel (simple header + HTML)\n",
    "        with gr.Column(scale=2):\n",
    "            gr.Markdown(\"#### Agent Logs\")\n",
    "            logs_html = gr.HTML(value=\"<div style='font-family:ui-monospace,Menlo,Consolas;white-space:pre-wrap;border:1px solid #ddd;border-radius:8px;padding:10px;height:380px;overflow:auto;background:#fafafa'>Idle</div>\")\n",
    "\n",
    "    with gr.Accordion(\"Advanced\", open=False):\n",
    "        with gr.Row():\n",
    "            model_dd = gr.Dropdown(choices=MODEL_CHOICES, value=SET.gen_model, label=\"LLM Model\", scale=2)\n",
    "            topk = gr.Slider(2, 20, value=SET.topk_per_rewrite, step=1, label=\"Top-K per rewrite\", scale=2)\n",
    "            radius = gr.Slider(0, 4, value=SET.neighbor_radius, step=1, label=\"Neighbor radius (±)\", scale=2)\n",
    "            cap = gr.Slider(4, 60, value=SET.max_blocks_for_llm, step=1, label=\"Max merged blocks for LLM\", scale=2)\n",
    "\n",
    "    # Wire the streaming generator to both the Send click and Enter submit\n",
    "    send.click(\n",
    "        fn=chat_stream,\n",
    "        inputs=[chatbot, msg, model_dd, topk, radius, cap],\n",
    "        outputs=[chatbot, logs_html],\n",
    "        show_progress=False,\n",
    "    ).then(lambda: \"\", None, msg)  # clear textbox\n",
    "\n",
    "    msg.submit(\n",
    "        fn=chat_stream,\n",
    "        inputs=[chatbot, msg, model_dd, topk, radius, cap],\n",
    "        outputs=[chatbot, logs_html],\n",
    "        show_progress=False,\n",
    "    ).then(lambda: \"\", None, msg)  # clear textbox\n",
    "\n",
    "app.launch(inbrowser=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-engineering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}