add Modal+Qwen3 query expansion and self-critique agentic loop

Nik
2025-10-30 21:40:02 +05:30
parent d198b55038
commit 8ed3a97275
3 changed files with 315 additions and 9 deletions


@@ -0,0 +1,49 @@
# Agentic Legal Q&A on Bare Acts (Week 8)
An **agentic RAG** demo that answers legal questions from Indian Bare Acts (IPC/BNS/Constitution).
Pipeline: **Query expansion (Modal+Qwen3) → Multi-retrieval (Chroma) → Neighbor-aware context merge → LLM answer → Self-critique → Optional second pass**.
UI: lightweight **Gradio** chat with live agent logs.
## Features
- **Modal-first expander:** `modal_expander.py` (Qwen3-4B via vLLM, GPU) with local LLM fallback.
- **Vector store:** Chroma + `all-MiniLM-L6-v2`, token-span aware chunking, ±neighbor merge.
- **Agentic loop:** critic validates citations and triggers follow-up retrievals if needed (see the sketch after this list).
- **Config knobs:** top-k per rewrite, neighbor radius, max merged blocks, model dropdown.
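
A condensed sketch of the loop, mirroring the notebook's `LegalAgent.run()` (assumes the notebook's `expander`, `retriever`, `ctx_expander`, `answerer`, and `Critic` are in scope; simplified, not a drop-in API):
```python
# Sketch only — simplified from the notebook; error handling and logging omitted.
rewrites = expander.expand(question)                     # Modal+Qwen3, local fallback
hits = [h for q in rewrites for h in retriever.topk(q)]  # multi-retrieval
blocks = ctx_expander.expand_and_merge(hits)             # ±neighbor merge
answer = answerer.answer(question, blocks)
review = Critic(answerer.llm).review(question, answer, blocks)
if not review.get("ok", True) and review.get("followups"):
    # second pass: bounded follow-up retrievals, then re-answer
    hits += [h for q in review["followups"][:3] for h in retriever.topk(q)]
    answer = answerer.answer(question, ctx_expander.expand_and_merge(hits))
```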
## Setup
```bash
python -m pip install -U openai chromadb transformers gradio python-dotenv modal
```
Create `.env` with your keys:
```bash
OPENAI_API_KEY=...
```
Place Bare Acts as UTF-8 `.txt` files in:
```
knowledge_base/bare_acts/ # e.g., ipc.txt, bns.txt, coi.txt
```
## Deploy the Modal expander
Set a Modal secret named `huggingface-secret` containing `HUGGINGFACE_HUB_TOKEN`, then:
```bash
modal deploy -m modal_expander
```
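If the secret doesn't exist yet, one way to create it from the CLI (assuming your token is in `$HF_TOKEN`):
```bash
modal secret create huggingface-secret HUGGINGFACE_HUB_TOKEN=$HF_TOKEN
```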
## Run the notebook app
```bash
jupyter notebook agentic_legal_qna_with_rag_on_bare_acts.ipynb
```
Run all cells; a Gradio chat appears. Tune **Top-K**, **Neighbor radius**, and **Max blocks** under *Advanced*.
## Notes
* Default OpenAI model: `gpt-4o-mini` (change via UI).
* Vector DB is persisted in `vector_db_w8`; re-run the indexing cell to rebuild after data changes.
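
To sanity-check the persisted index outside the notebook, a minimal sketch (the collection name isn't pinned down here, so it just lists whatever exists):
```python
# Open the persisted Chroma store used by the notebook and list its collections.
import chromadb

client = chromadb.PersistentClient(path="vector_db_w8")
print(client.list_collections())
```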


@@ -13,9 +13,10 @@
"from dataclasses import dataclass, field\n",
"from typing import List, Dict, Any, Tuple, Protocol\n",
"\n",
"\n",
"from pathlib import Path\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"from openai import OpenAI\n",
"import chromadb\n",
"from chromadb import PersistentClient\n",
@@ -24,6 +25,7 @@
"from transformers import AutoTokenizer\n",
"import gradio as gr\n",
"\n",
"from dotenv import load_dotenv\n",
"load_dotenv(override=True)"
]
},
@@ -735,6 +737,73 @@
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8284580f",
"metadata": {},
"outputs": [],
"source": [
"class Critic:\n",
" def __init__(self, llm: LLM):\n",
" self.llm = llm\n",
"\n",
" def review(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:\n",
" block_ids = [f\"[{b['source']} {b['start']}:{b['end']}]\" for b in blocks]\n",
" sys = (\n",
" \"You are a meticulous legal verifier. \"\n",
" \"Return ONLY JSON with keys: ok (bool), missing (list of short missing facts), followups (list of short retrieval keyphrases).\"\n",
" )\n",
" user = f\"\"\"Question:\n",
"{question}\n",
"\n",
"Proposed answer:\n",
"{answer}\n",
"\n",
"Verify that every factual claim is supported by the context blocks (by their ids).\n",
"If support is weak or missing, set ok=false and propose concise retrieval keyphrases in followups.\n",
"\n",
"Available context block ids:\n",
"{\", \".join(block_ids)}\n",
"\n",
"Return JSON only, e.g.:\n",
"{{\"ok\": true, \"missing\": [], \"followups\": []}}\n",
"\"\"\"\n",
" raw = self.llm.complete(system=sys, user=user, temperature=0.0, max_tokens=220)\n",
" try:\n",
" m = re.search(r\"\\{(.|\\n)*\\}\", raw)\n",
" return json.loads(m.group(0)) if m else {\"ok\": True, \"missing\": [], \"followups\": []}\n",
" except Exception:\n",
" return {\"ok\": True, \"missing\": [], \"followups\": []}\n"
]
},
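{
"cell_type": "code",
"execution_count": null,
"id": "critic-contract-demo",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical smoke test (not part of the original commit): exercises the\n",
"# Critic JSON contract. review() fails open — any parse error yields ok=True.\n",
"_blocks = [{\"source\": \"ipc.txt\", \"start\": 100, \"end\": 180, \"text\": \"...\"}]\n",
"_critic = Critic(OpenAILLM(\"gpt-4o-mini\"))\n",
"print(_critic.review(\"What is theft?\", \"Theft is addressed in [ipc.txt 100:180].\", _blocks))"
]
},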
{
"cell_type": "code",
"execution_count": null,
"id": "34449337",
"metadata": {},
"outputs": [],
"source": [
"import modal\n",
"_remote_expand = modal.Function.from_name(\"legal-query-expander-qwen3-v2\", \"expand\")\n",
"\n",
"\n",
"class ModalFirstExpander(QueryExpander):\n",
" def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n",
" got = []\n",
" if _remote_expand:\n",
" try:\n",
" got = _remote_expand.remote(question, self.n)\n",
" except Exception:\n",
" got = []\n",
" if not got or len([x for x in got if isinstance(x, str) and x.strip()]) < max(1, self.n // 2):\n",
" return super().expand(question, model_override=model_override)\n",
" import re\n",
" got = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in got]\n",
" return [q for q in got if q][: self.n]\n",
"\n"
]
},
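{
"cell_type": "code",
"execution_count": null,
"id": "expander-fallback-demo",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical check (not part of the original commit): whether Modal answered\n",
"# or the local fallback ran, the expander should return 1..n clean strings.\n",
"_exp = ModalFirstExpander(llm=OpenAILLM(\"gpt-4o-mini\"), n=5)\n",
"print(_exp.expand(\"punishment for criminal breach of trust\"))"
]
},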
{
"cell_type": "code",
"execution_count": null,
@@ -751,16 +820,33 @@
" self.set = set\n",
"\n",
" def run(self, question: str, model: str | None = None) -> str:\n",
" # Pass 1\n",
" rewrites = self.expander.expand(question, model_override=model)\n",
" hits: List[Dict] = []\n",
" for q in rewrites:\n",
" hits.extend(self.retriever.topk(q))\n",
" blocks = self.ctx_expander.expand_and_merge(hits)\n",
" return self.answerer.answer(question, blocks, model=model)\n",
" answer1 = self.answerer.answer(question, blocks, model=model)\n",
"\n",
" # Self-critique\n",
" critic = Critic(self.answerer.llm)\n",
" review = critic.review(question, answer1, blocks)\n",
"\n",
" if review.get(\"ok\", True) or not review.get(\"followups\"):\n",
" return answer1 # Good enough\n",
"\n",
" # Pass 2 — adapt plan using follow-up rewrites\n",
" extra_hits: List[Dict] = []\n",
" for q in review[\"followups\"][:3]: # keep bounded\n",
" extra_hits.extend(self.retriever.topk(q))\n",
" blocks2 = self.ctx_expander.expand_and_merge(hits + extra_hits)\n",
" answer2 = self.answerer.answer(question, blocks2, model=model)\n",
" return answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
"\n",
" def run_stream(self, question: str, model: str | None = None):\n",
" \"\"\"\n",
" Generator: yields tuples (answer_or_none, logs_html) multiple times.\n",
" Now includes a self-critique step and (if needed) a second pass, with logs.\n",
" \"\"\"\n",
" log = RunLogger()\n",
" log.add(f\"Question: {question}\")\n",
@@ -797,15 +883,66 @@
" )\n",
" for b in blocks:\n",
" log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
" log.add(b[\"text\"])\n",
" log.add(b[\"text\"][:50].replace(\"\\n\", \" \"))\n",
" yield None, log.html()\n",
"\n",
" # 4) LLM answer\n",
" log.add(f\"Asking LLM: {model or self.set.gen_model}\")\n",
" # 4) LLM answer — pass 1\n",
" log.add(f\"Asking LLM (pass 1): {model or self.set.gen_model}\")\n",
" yield None, log.html()\n",
" answer = self.answerer.answer(question, blocks, model=model)\n",
" log.add(\"Answer ready.\")\n",
" yield answer, log.html()\n"
" answer1 = self.answerer.answer(question, blocks, model=model)\n",
" log.add(\"Answer 1 ready.\")\n",
" yield answer1, log.html()\n",
"\n",
" # 5) Self-critique\n",
" log.add(\"Running self-critique…\")\n",
" yield None, log.html()\n",
" critic = Critic(self.answerer.llm)\n",
" review = critic.review(question, answer1, blocks)\n",
" ok = review.get(\"ok\", True)\n",
" missing = review.get(\"missing\") or []\n",
" followups = review.get(\"followups\") or []\n",
" # compact display\n",
" def _preview(lst, n=5):\n",
" return \", \".join(lst[:n]) + (f\" … (+{len(lst)-n} more)\" if len(lst) > n else \"\") if lst else \"—\"\n",
" log.add(f\"Self-critique result: ok={ok}; missing={_preview(missing)}; followups={_preview(followups)}\")\n",
" yield None, log.html()\n",
"\n",
" if ok or not followups:\n",
" log.add(\"Critique passed or no follow-ups. Finalizing pass 1.\")\n",
" yield answer1, log.html()\n",
" return\n",
"\n",
" # 6) Pass 2 — follow-up retrievals\n",
" extra_hits: List[Dict] = []\n",
" limited = followups[:3] # keep bounded\n",
" for i, q in enumerate(limited, 1):\n",
" hits = self.retriever.topk(q)\n",
" extra_hits.extend(hits)\n",
" top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n",
" log.add(f\"Follow-up retrieval {i}/{len(limited)}: got {len(hits)} hits → {top3}\")\n",
" yield None, log.html()\n",
"\n",
" blocks2 = self.ctx_expander.expand_and_merge(all_hits + extra_hits)\n",
" used2 = self.ctx_expander.last_ids\n",
" peek2 = \", \".join(used2[:8]) + (\" …\" if len(used2) > 8 else \"\")\n",
" log.add(f\"Neighbor addition (pass 2): collected {len(used2)} chunk-ids → {peek2}\")\n",
"\n",
" approx_words2 = int(self.ctx_expander.pad_tokens / 1.4)\n",
" log.add(\n",
" f\"Context expansion (pass 2): merged {len(blocks2)} block(s) \"\n",
" f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words2} words).\"\n",
" )\n",
" for b in blocks2[:10]: # dont spam the log\n",
" log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
" yield None, log.html()\n",
"\n",
" # 7) LLM answer — pass 2\n",
" log.add(f\"Asking LLM (pass 2): {model or self.set.gen_model}\")\n",
" yield None, log.html()\n",
" answer2 = self.answerer.answer(question, blocks2, model=model)\n",
" final_answer = answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
" log.add(\"Answer 2 ready (refined).\")\n",
" yield final_answer, log.html()\n"
]
},
{
@@ -817,7 +954,7 @@
"source": [
"def make_agent(gen_model: str) -> LegalAgent:\n",
" llm = OpenAILLM(gen_model)\n",
" expander = QueryExpander(llm=llm, n=SET.expansions)\n",
" expander = ModalFirstExpander(llm=llm, n=SET.expansions) # <— here\n",
" retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n",
" ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n",
" answerer = Answerer(llm=llm, set=SET)\n",


@@ -0,0 +1,120 @@
# week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py
import os, json, re
from typing import List

import modal

# minimal image: vLLM + torch + HF hub
image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .entrypoint([])
    .uv_pip_install(
        "vllm==0.10.2",
        "torch==2.8.0",
        "huggingface_hub[hf_transfer]==0.35.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

app = modal.App("legal-query-expander-qwen3-v2", image=image)

MODEL_NAME = "Qwen/Qwen3-4B-Instruct"  # use instruct, defaults everywhere
_llm = None  # warm-container cache

def _extract_json_array(text: str) -> List[str]:
    # Try the whole output as JSON first; otherwise parse the first [...] span.
    try:
        parsed = json.loads(text)
        return [x for x in parsed if isinstance(x, str)]
    except Exception:
        pass
    m = re.search(r"\[(?:.|\n|\r)*\]", text)
    if m:
        try:
            parsed = json.loads(m.group(0))
            return [x for x in parsed if isinstance(x, str)]
        except Exception:
            return []
    return []


def _sanitize_and_dedupe(items: List[str], n: int) -> List[str]:
    # Strip odd punctuation, drop empties, dedupe case-insensitively, cap at n.
    out, seen = [], set()
    for q in items:
        q = re.sub(r"[^\w\s\-./]", "", (q or "")).strip()
        k = q.lower()
        if q and k not in seen:
            seen.add(k)
            out.append(q)
        if len(out) >= n:
            break
    return out
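
# For example (hypothetical inputs):
#   _sanitize_and_dedupe(["theft punishment!", "Theft punishment", ""], 5)
#   -> ["theft punishment"]
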
@app.function(
    image=image,
    gpu="L4",  # any available GPU works (A100/H100 also fine)
    timeout=600,
    secrets=[modal.Secret.from_name("huggingface-secret")],  # provides the HF token
)
def expand(question: str, n: int = 5) -> List[str]:
    """
    Return up to n short, diverse retrieval keyphrases for Bare Acts.
    Uses Qwen3-4B-Instruct with its default chat template.
    """
    global _llm
    from vllm import LLM, SamplingParams

    # ensure HF token is available to vLLM
    tok = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if tok and not os.environ.get("HUGGINGFACE_HUB_TOKEN"):
        os.environ["HUGGINGFACE_HUB_TOKEN"] = tok

    if _llm is None:
        _llm = LLM(
            model=MODEL_NAME,
            trust_remote_code=True,
            dtype="auto",
            tensor_parallel_size=1,
        )

    user = (
        "You are a search-query expander. "
        "Given a legal search query over a database of legal data such as Bare Acts, produce diverse variants of it that retrieve better. "
        "Respond ONLY as a JSON array of strings; no prose, no section numbers.\n\n"
        f"Question:\n{question}\n\n"
        f"Return {n} distinct keyphrases (4-20 words each) capturing what to search for in the RAG database. "
        "Return a JSON array only, no commentary."
    )
    messages = [
        {"role": "user", "content": user},
    ]
    tokenizer = _llm.get_tokenizer()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    result = _llm.generate(
        [prompt],
        SamplingParams(
            max_tokens=256,
            temperature=0.2,
        ),
    )
    text = result[0].outputs[0].text
    arr = _sanitize_and_dedupe(_extract_json_array(text), n)
    if not arr:
        # deterministic fallback (keeps things non-empty)
        base = re.sub(r"[?]+$", "", (question or "")).strip()
        pool = [
            f"{base} section",
            f"{base} provision bare act",
            f"{base} indian penal code",
            f"{base} bharatiya nyaya sanhita",
            f"{base} punishment section keywords",
        ]
        arr = _sanitize_and_dedupe(pool, n)
    return arr
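
# --- Usage sketch (added; not part of the original commit) ---
# `modal run -m modal_expander` executes this entrypoint locally and invokes the
# GPU function remotely; handy as a smoke test after `modal deploy`.
@app.local_entrypoint()
def main(question: str = "What is the punishment for theft?", n: int = 5):
    print(expand.remote(question, n))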