add Modal+Qwen3 query expansion and self-critique agentic loop

Nik
2025-10-30 21:40:02 +05:30
parent d198b55038
commit 8ed3a97275
3 changed files with 315 additions and 9 deletions

View File

@@ -0,0 +1,49 @@
# Agentic Legal Q&A on Bare Acts (Week 8)
An **agentic RAG** demo that answers legal questions from Indian Bare Acts (IPC/BNS/Constitution).
Pipeline: **Query expansion (Modal+Qwen3) → Multi-retrieval (Chroma) → Neighbor-aware context merge → LLM answer → Self-critique → Optional second pass**.
UI: lightweight **Gradio** chat with live agent logs.
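For orientation, here is the loop for one query as a sketch. The names mirror the notebook's classes (`QueryExpander`/`ModalFirstExpander`, `Retriever`, `ContextExpander`, `Answerer`, `Critic`); construction of the helpers and logging are omitted, so treat this as illustrative rather than the exact implementation.
```python
# one query through the agentic loop (helper construction omitted)
rewrites = expander.expand(question)                        # Modal+Qwen3 query expansion
hits = [h for q in rewrites for h in retriever.topk(q)]     # multi-retrieval from Chroma
blocks = ctx_expander.expand_and_merge(hits)                # neighbor-aware context merge
answer = answerer.answer(question, blocks)                  # pass 1
review = Critic(answerer.llm).review(question, answer, blocks)
if not review.get("ok", True) and review.get("followups"):  # self-critique gate
    for q in review["followups"][:3]:                       # bounded follow-up retrievals
        hits.extend(retriever.topk(q))
    answer = answerer.answer(question, ctx_expander.expand_and_merge(hits))  # pass 2
```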
## Features
- **Modal-first expander:** `modal_expander.py` (Qwen3-4B via vLLM, GPU) with local LLM fallback.
- **Vector store:** Chroma + `all-MiniLM-L6-v2`, token-span aware chunking, ±neighbor merge.
- **Agentic loop:** critic validates citations and triggers follow-up retrievals if needed.
- **Config knobs:** top-k per rewrite, neighbor radius, max merged blocks, model dropdown (see the sketch after this list).
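These knobs map onto a settings object (`SET`) in the notebook. A rough sketch follows; the field names are taken from the code in this commit, while the class name and default values here are illustrative assumptions.
```python
from dataclasses import dataclass

@dataclass
class Settings:  # hypothetical name; the notebook exposes an instance called SET
    gen_model: str = "gpt-4o-mini"   # default OpenAI model (changeable in the UI)
    expansions: int = 4              # query rewrites per question (illustrative default)
    topk_per_rewrite: int = 5        # Top-K Chroma hits per rewrite (illustrative default)
    neighbor_radius: int = 1         # ± neighbor chunks merged around each hit (illustrative default)
    max_blocks_for_llm: int = 8      # cap on merged context blocks sent to the LLM (illustrative default)

SET = Settings()
```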
## Setup
```bash
python -m pip install -U openai chromadb transformers gradio python-dotenv modal
```
Create `.env` with your keys:
```bash
OPENAI_API_KEY=...
```
Place Bare Acts as UTF-8 `.txt` files in:
```
knowledge_base/bare_acts/ # e.g., ipc.txt, bns.txt, coi.txt
```
## Deploy the Modal expander
Set a Modal secret named `huggingface-secret` containing `HUGGINGFACE_HUB_TOKEN`, then:
```bash
modal deploy -m modal_expander
```
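Once deployed, you can sanity-check the expander from a Python shell. The app and function names below match `modal_expander.py`; the sample question is only an example.
```python
import modal

# look up the deployed function by app/function name and call it remotely
expand = modal.Function.from_name("legal-query-expander-qwen3-v2", "expand")
print(expand.remote("What is the punishment for theft?", 5))
```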
## Run the notebook app
```bash
jupyter notebook agentic_legal_qna_with_rag_on_bare_acts.ipynb
```
Run all cells; a Gradio chat appears. Tune **Top-K**, **Neighbor radius**, and **Max blocks** under *Advanced*.
## Notes
* Default OpenAI model: `gpt-4o-mini` (change via UI).
* Vector DB is persisted in `vector_db_w8`; re-run the indexing cell to rebuild after data changes (a quick sanity check is sketched below).
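A minimal way to confirm the persisted store is populated, assuming a collection name of `bare_acts` (use whatever name the indexing cell actually creates):
```python
from chromadb import PersistentClient

# open the persisted store and count indexed chunks
client = PersistentClient(path="vector_db_w8")
collection = client.get_collection("bare_acts")  # assumed collection name
print(collection.count(), "chunks indexed")
```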

View File

@@ -13,9 +13,10 @@
"from dataclasses import dataclass, field\n", "from dataclasses import dataclass, field\n",
"from typing import List, Dict, Any, Tuple, Protocol\n", "from typing import List, Dict, Any, Tuple, Protocol\n",
"\n", "\n",
"\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"\n", "\n",
"from dotenv import load_dotenv\n", "\n",
"from openai import OpenAI\n", "from openai import OpenAI\n",
"import chromadb\n", "import chromadb\n",
"from chromadb import PersistentClient\n", "from chromadb import PersistentClient\n",
@@ -24,6 +25,7 @@
"from transformers import AutoTokenizer\n", "from transformers import AutoTokenizer\n",
"import gradio as gr\n", "import gradio as gr\n",
"\n", "\n",
"from dotenv import load_dotenv\n",
"load_dotenv(override=True)" "load_dotenv(override=True)"
] ]
}, },
@@ -735,6 +737,73 @@
" )\n" " )\n"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "8284580f",
"metadata": {},
"outputs": [],
"source": [
"class Critic:\n",
" def __init__(self, llm: LLM):\n",
" self.llm = llm\n",
"\n",
" def review(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:\n",
" block_ids = [f\"[{b['source']} {b['start']}:{b['end']}]\" for b in blocks]\n",
" sys = (\n",
" \"You are a meticulous legal verifier. \"\n",
" \"Return ONLY JSON with keys: ok (bool), missing (list of short missing facts), followups (list of short retrieval keyphrases).\"\n",
" )\n",
" user = f\"\"\"Question:\n",
"{question}\n",
"\n",
"Proposed answer:\n",
"{answer}\n",
"\n",
"Verify that every factual claim is supported by the context blocks (by their ids).\n",
"If support is weak or missing, set ok=false and propose concise retrieval keyphrases in followups.\n",
"\n",
"Available context block ids:\n",
"{\", \".join(block_ids)}\n",
"\n",
"Return JSON only, e.g.:\n",
"{{\"ok\": true, \"missing\": [], \"followups\": []}}\n",
"\"\"\"\n",
" raw = self.llm.complete(system=sys, user=user, temperature=0.0, max_tokens=220)\n",
" try:\n",
" m = re.search(r\"\\{(.|\\n)*\\}\", raw)\n",
" return json.loads(m.group(0)) if m else {\"ok\": True, \"missing\": [], \"followups\": []}\n",
" except Exception:\n",
" return {\"ok\": True, \"missing\": [], \"followups\": []}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34449337",
"metadata": {},
"outputs": [],
"source": [
"import modal\n",
"_remote_expand = modal.Function.from_name(\"legal-query-expander-qwen3-v2\", \"expand\")\n",
"\n",
"\n",
"class ModalFirstExpander(QueryExpander):\n",
" def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n",
" got = []\n",
" if _remote_expand:\n",
" try:\n",
" got = _remote_expand.remote(question, self.n)\n",
" except Exception:\n",
" got = []\n",
" if not got or len([x for x in got if isinstance(x, str) and x.strip()]) < max(1, self.n // 2):\n",
" return super().expand(question, model_override=model_override)\n",
" import re\n",
" got = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in got]\n",
" return [q for q in got if q][: self.n]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -751,16 +820,33 @@
" self.set = set\n", " self.set = set\n",
"\n", "\n",
" def run(self, question: str, model: str | None = None) -> str:\n", " def run(self, question: str, model: str | None = None) -> str:\n",
" # Pass 1\n",
" rewrites = self.expander.expand(question, model_override=model)\n", " rewrites = self.expander.expand(question, model_override=model)\n",
" hits: List[Dict] = []\n", " hits: List[Dict] = []\n",
" for q in rewrites:\n", " for q in rewrites:\n",
" hits.extend(self.retriever.topk(q))\n", " hits.extend(self.retriever.topk(q))\n",
" blocks = self.ctx_expander.expand_and_merge(hits)\n", " blocks = self.ctx_expander.expand_and_merge(hits)\n",
" return self.answerer.answer(question, blocks, model=model)\n", " answer1 = self.answerer.answer(question, blocks, model=model)\n",
"\n",
" # Self-critique\n",
" critic = Critic(self.answerer.llm)\n",
" review = critic.review(question, answer1, blocks)\n",
"\n",
" if review.get(\"ok\", True) or not review.get(\"followups\"):\n",
" return answer1 # Good enough\n",
"\n",
" # Pass 2 — adapt plan using follow-up rewrites\n",
" extra_hits: List[Dict] = []\n",
" for q in review[\"followups\"][:3]: # keep bounded\n",
" extra_hits.extend(self.retriever.topk(q))\n",
" blocks2 = self.ctx_expander.expand_and_merge(hits + extra_hits)\n",
" answer2 = self.answerer.answer(question, blocks2, model=model)\n",
" return answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
"\n", "\n",
" def run_stream(self, question: str, model: str | None = None):\n", " def run_stream(self, question: str, model: str | None = None):\n",
" \"\"\"\n", " \"\"\"\n",
" Generator: yields tuples (answer_or_none, logs_html) multiple times.\n", " Generator: yields tuples (answer_or_none, logs_html) multiple times.\n",
" Now includes a self-critique step and (if needed) a second pass, with logs.\n",
" \"\"\"\n", " \"\"\"\n",
" log = RunLogger()\n", " log = RunLogger()\n",
" log.add(f\"Question: {question}\")\n", " log.add(f\"Question: {question}\")\n",
@@ -797,15 +883,66 @@
" )\n", " )\n",
" for b in blocks:\n", " for b in blocks:\n",
" log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n", " log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
" log.add(b[\"text\"])\n", " log.add(b[\"text\"][:50].replace(\"\\n\", \" \"))\n",
" yield None, log.html()\n", " yield None, log.html()\n",
"\n", "\n",
" # 4) LLM answer\n", " # 4) LLM answer — pass 1\n",
" log.add(f\"Asking LLM: {model or self.set.gen_model}\")\n", " log.add(f\"Asking LLM (pass 1): {model or self.set.gen_model}\")\n",
" yield None, log.html()\n", " yield None, log.html()\n",
" answer = self.answerer.answer(question, blocks, model=model)\n", " answer1 = self.answerer.answer(question, blocks, model=model)\n",
" log.add(\"Answer ready.\")\n", " log.add(\"Answer 1 ready.\")\n",
" yield answer, log.html()\n" " yield answer1, log.html()\n",
"\n",
" # 5) Self-critique\n",
" log.add(\"Running self-critique…\")\n",
" yield None, log.html()\n",
" critic = Critic(self.answerer.llm)\n",
" review = critic.review(question, answer1, blocks)\n",
" ok = review.get(\"ok\", True)\n",
" missing = review.get(\"missing\") or []\n",
" followups = review.get(\"followups\") or []\n",
" # compact display\n",
" def _preview(lst, n=5):\n",
" return \", \".join(lst[:n]) + (f\" … (+{len(lst)-n} more)\" if len(lst) > n else \"\") if lst else \"—\"\n",
" log.add(f\"Self-critique result: ok={ok}; missing={_preview(missing)}; followups={_preview(followups)}\")\n",
" yield None, log.html()\n",
"\n",
" if ok or not followups:\n",
" log.add(\"Critique passed or no follow-ups. Finalizing pass 1.\")\n",
" yield answer1, log.html()\n",
" return\n",
"\n",
" # 6) Pass 2 — follow-up retrievals\n",
" extra_hits: List[Dict] = []\n",
" limited = followups[:3] # keep bounded\n",
" for i, q in enumerate(limited, 1):\n",
" hits = self.retriever.topk(q)\n",
" extra_hits.extend(hits)\n",
" top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n",
" log.add(f\"Follow-up retrieval {i}/{len(limited)}: got {len(hits)} hits → {top3}\")\n",
" yield None, log.html()\n",
"\n",
" blocks2 = self.ctx_expander.expand_and_merge(all_hits + extra_hits)\n",
" used2 = self.ctx_expander.last_ids\n",
" peek2 = \", \".join(used2[:8]) + (\" …\" if len(used2) > 8 else \"\")\n",
" log.add(f\"Neighbor addition (pass 2): collected {len(used2)} chunk-ids → {peek2}\")\n",
"\n",
" approx_words2 = int(self.ctx_expander.pad_tokens / 1.4)\n",
" log.add(\n",
" f\"Context expansion (pass 2): merged {len(blocks2)} block(s) \"\n",
" f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words2} words).\"\n",
" )\n",
" for b in blocks2[:10]: # dont spam the log\n",
" log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
" yield None, log.html()\n",
"\n",
" # 7) LLM answer — pass 2\n",
" log.add(f\"Asking LLM (pass 2): {model or self.set.gen_model}\")\n",
" yield None, log.html()\n",
" answer2 = self.answerer.answer(question, blocks2, model=model)\n",
" final_answer = answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
" log.add(\"Answer 2 ready (refined).\")\n",
" yield final_answer, log.html()\n"
]
},
{
@@ -817,7 +954,7 @@
"source": [ "source": [
"def make_agent(gen_model: str) -> LegalAgent:\n", "def make_agent(gen_model: str) -> LegalAgent:\n",
" llm = OpenAILLM(gen_model)\n", " llm = OpenAILLM(gen_model)\n",
" expander = QueryExpander(llm=llm, n=SET.expansions)\n", " expander = ModalFirstExpander(llm=llm, n=SET.expansions) # <— here\n",
" retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n", " retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n",
" ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n", " ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n",
" answerer = Answerer(llm=llm, set=SET)\n", " answerer = Answerer(llm=llm, set=SET)\n",

View File

@@ -0,0 +1,120 @@
# week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py
import os, json, re
from typing import List
import modal
# minimal image: vLLM + torch + HF hub
image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .entrypoint([])
    .uv_pip_install(
        "vllm==0.10.2",
        "torch==2.8.0",
        "huggingface_hub[hf_transfer]==0.35.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

app = modal.App("legal-query-expander-qwen3-v2", image=image)

MODEL_NAME = "Qwen/Qwen3-4B-Instruct"  # use instruct, defaults everywhere
_llm = None  # warm-container cache


def _extract_json_array(text: str) -> List[str]:
    try:
        parsed = json.loads(text)
        return [x for x in parsed if isinstance(x, str)]
    except Exception:
        pass
    m = re.search(r"\[(?:.|\n|\r)*\]", text)
    if m:
        try:
            parsed = json.loads(m.group(0))
            return [x for x in parsed if isinstance(x, str)]
        except Exception:
            return []
    return []


def _sanitize_and_dedupe(items: List[str], n: int) -> List[str]:
    out, seen = [], set()
    for q in items:
        q = re.sub(r"[^\w\s\-./]", "", (q or "")).strip()
        k = q.lower()
        if q and k not in seen:
            seen.add(k)
            out.append(q)
        if len(out) >= n:
            break
    return out


@app.function(
    image=image,
    gpu=modal.gpu.L4(),  # pick any available GPU (A100/H100 also fine)
    timeout=600,
    secrets=[modal.Secret.from_name("huggingface-secret")],  # set HF token here
)
def expand(question: str, n: int = 5) -> List[str]:
    """
    Return up to n short, diverse retrieval keyphrases for Bare Acts.
    Uses Qwen3-4B-Instruct with its default chat template.
    """
    global _llm
    from vllm import LLM, SamplingParams

    # ensure HF token is available to vLLM
    tok = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if tok and not os.environ.get("HUGGINGFACE_HUB_TOKEN"):
        os.environ["HUGGINGFACE_HUB_TOKEN"] = tok
    if _llm is None:
        _llm = LLM(
            model=MODEL_NAME,
            trust_remote_code=True,
            dtype="auto",
            tensor_parallel_size=1,
        )
    user = (
        "You are a search query expander. "
        "For a given search query, you produce 4-5 different variants of it to search the database better. "
        "It is a legal search query and our database contains legal data such as bare acts. "
        "Respond ONLY as a JSON array of strings; no prose, no section numbers.\n\n"
        f"Question:\n{question}\n\n"
        f"Return {n} distinct keyphrases (4-20 words each) that capture what to search for inside the RAG database. "
        "Return them as a JSON array. No commentary."
    )
    messages = [
        {"role": "user", "content": user},
    ]
    tokenizer = _llm.get_tokenizer()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    result = _llm.generate(
        [prompt],
        SamplingParams(
            max_tokens=256,
            temperature=0.2,
        ),
    )
    text = result[0].outputs[0].text
    arr = _sanitize_and_dedupe(_extract_json_array(text), n)
    if not arr:
        # deterministic fallback (keeps things non-empty)
        base = re.sub(r"[?]+$", "", (question or "")).strip()
        pool = [
            f"{base} section",
            f"{base} provision bare act",
            f"{base} indian penal code",
            f"{base} bharatiya nyaya sanhita",
            f"{base} punishment section keywords",
        ]
        arr = _sanitize_and_dedupe(pool, n)
    return arr
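

# Optional smoke test: a minimal local-entrypoint sketch (not required for `modal deploy`).
# Run it with `modal run modal_expander.py`; the default question is illustrative only.
@app.local_entrypoint()
def main(question: str = "What is the punishment for theft?", n: int = 5):
    for phrase in expand.remote(question, n):
        print(phrase)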