add Modal+Qwen3 query expansion and self-critique agentic loop
@@ -0,0 +1,49 @@
# Agentic Legal Q&A on Bare Acts (Week 8)

An **agentic RAG** demo that answers legal questions from Indian Bare Acts (IPC/BNS/Constitution).

Pipeline: **Query expansion (Modal+Qwen3) → Multi-retrieval (Chroma) → Neighbor-aware context merge → LLM answer → Self-critique → Optional second pass**.

UI: a lightweight **Gradio** chat with live agent logs.

## Features

- **Modal-first expander:** `modal_expander.py` (Qwen3-4B via vLLM, GPU) with a local LLM fallback.
- **Vector store:** Chroma + `all-MiniLM-L6-v2`, token-span-aware chunking, ±neighbor merge.
- **Agentic loop:** a critic validates citations and triggers follow-up retrievals when support is weak (see the sketch below).
- **Config knobs:** top-k per rewrite, neighbor radius, max merged blocks, model dropdown.
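
The agentic loop, as implemented in the notebook's `LegalAgent.run`, boils down to the sketch below. The components (`expander`, `retriever`, `ctx_expander`, `answerer`) and the `Critic` class are the notebook's own objects; the standalone wrapper function here is illustrative only.

```python
# Sketch of the two-pass agentic loop (mirrors LegalAgent.run in the notebook).
from typing import Any, Dict, List

def answer_with_self_critique(question: str, expander, retriever, ctx_expander, answerer) -> str:
    # Pass 1: expand the query, retrieve per rewrite, merge neighbor-aware context, answer.
    hits: List[Dict[str, Any]] = []
    for q in expander.expand(question):
        hits.extend(retriever.topk(q))
    blocks = ctx_expander.expand_and_merge(hits)
    answer1 = answerer.answer(question, blocks)

    # Self-critique: the critic checks that each claim is backed by the cited blocks.
    review = Critic(answerer.llm).review(question, answer1, blocks)
    if review.get("ok", True) or not review.get("followups"):
        return answer1

    # Pass 2: bounded follow-up retrievals proposed by the critic, then re-answer.
    for q in review["followups"][:3]:
        hits.extend(retriever.topk(q))
    return answerer.answer(question, ctx_expander.expand_and_merge(hits))
```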

## Setup

```bash
python -m pip install -U openai chromadb transformers gradio python-dotenv modal
```
Create `.env` with your keys:

```bash
OPENAI_API_KEY=...
```
Place Bare Acts as UTF-8 `.txt` files in:

```
knowledge_base/bare_acts/    # e.g., ipc.txt, bns.txt, coi.txt
```
## Deploy the Modal expander

Set a Modal secret named `huggingface-secret` containing `HUGGINGFACE_HUB_TOKEN`, then:

```bash
modal deploy -m modal_expander
```
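
Once deployed, the notebook resolves the function by app and function name and calls it remotely (this is what `ModalFirstExpander` does, with a local LLM fallback). A minimal sketch, using a made-up example question:

```python
# Call the deployed Qwen3 expander remotely.
import modal

expand = modal.Function.from_name("legal-query-expander-qwen3-v2", "expand")
rewrites = expand.remote("What is the punishment for theft?", 5)  # up to 5 keyphrases
print(rewrites)
```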

## Run the notebook app

```bash
jupyter notebook agentic_legal_qna_with_rag_on_bare_acts.ipynb
```

Run all cells; a Gradio chat appears. Tune **Top-K**, **Neighbor radius**, and **Max blocks** under *Advanced*.
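
You can also bypass the UI and drive the agent from a cell via the notebook's `make_agent` factory (the question below is just an example):

```python
# Programmatic use from a notebook cell, assuming the earlier cells have been run.
agent = make_agent("gpt-4o-mini")   # wires up expander, retriever, context merger and answerer
print(agent.run("What is the punishment for theft under the IPC?"))
```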

## Notes

- Default OpenAI model: `gpt-4o-mini` (change it via the model dropdown in the UI).
- The vector DB is persisted in `vector_db_w8`; re-run the indexing cell to rebuild it after the data changes.
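
If you want a clean rebuild outside the notebook, you can drop the persisted Chroma collection first. A sketch, where `bare_acts` is a hypothetical name standing in for whatever collection the indexing cell actually creates:

```python
import chromadb

client = chromadb.PersistentClient(path="vector_db_w8")
# "bare_acts" is a placeholder; use the collection name the indexing cell creates.
client.delete_collection("bare_acts")
```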
@@ -13,9 +13,10 @@
 "from dataclasses import dataclass, field\n",
 "from typing import List, Dict, Any, Tuple, Protocol\n",
 "\n",
+"\n",
 "from pathlib import Path\n",
 "\n",
-"from dotenv import load_dotenv\n",
+"\n",
 "from openai import OpenAI\n",
 "import chromadb\n",
 "from chromadb import PersistentClient\n",
@@ -24,6 +25,7 @@
 "from transformers import AutoTokenizer\n",
 "import gradio as gr\n",
 "\n",
+"from dotenv import load_dotenv\n",
 "load_dotenv(override=True)"
 ]
 },
@@ -735,6 +737,73 @@
 " )\n"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "8284580f",
+"metadata": {},
+"outputs": [],
+"source": [
+"class Critic:\n",
+" def __init__(self, llm: LLM):\n",
+" self.llm = llm\n",
+"\n",
+" def review(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:\n",
+" block_ids = [f\"[{b['source']} {b['start']}:{b['end']}]\" for b in blocks]\n",
+" sys = (\n",
+" \"You are a meticulous legal verifier. \"\n",
+" \"Return ONLY JSON with keys: ok (bool), missing (list of short missing facts), followups (list of short retrieval keyphrases).\"\n",
+" )\n",
+" user = f\"\"\"Question:\n",
+"{question}\n",
+"\n",
+"Proposed answer:\n",
+"{answer}\n",
+"\n",
+"Verify that every factual claim is supported by the context blocks (by their ids).\n",
+"If support is weak or missing, set ok=false and propose concise retrieval keyphrases in followups.\n",
+"\n",
+"Available context block ids:\n",
+"{\", \".join(block_ids)}\n",
+"\n",
+"Return JSON only, e.g.:\n",
+"{{\"ok\": true, \"missing\": [], \"followups\": []}}\n",
+"\"\"\"\n",
+" raw = self.llm.complete(system=sys, user=user, temperature=0.0, max_tokens=220)\n",
+" try:\n",
+" m = re.search(r\"\\{(.|\\n)*\\}\", raw)\n",
+" return json.loads(m.group(0)) if m else {\"ok\": True, \"missing\": [], \"followups\": []}\n",
+" except Exception:\n",
+" return {\"ok\": True, \"missing\": [], \"followups\": []}\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "34449337",
+"metadata": {},
+"outputs": [],
+"source": [
+"import modal\n",
+"_remote_expand = modal.Function.from_name(\"legal-query-expander-qwen3-v2\", \"expand\")\n",
+"\n",
+"\n",
+"class ModalFirstExpander(QueryExpander):\n",
+" def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n",
+" got = []\n",
+" if _remote_expand:\n",
+" try:\n",
+" got = _remote_expand.remote(question, self.n)\n",
+" except Exception:\n",
+" got = []\n",
+" if not got or len([x for x in got if isinstance(x, str) and x.strip()]) < max(1, self.n // 2):\n",
+" return super().expand(question, model_override=model_override)\n",
+" import re\n",
+" got = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in got]\n",
+" return [q for q in got if q][: self.n]\n",
+"\n"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -751,16 +820,33 @@
 " self.set = set\n",
 "\n",
 " def run(self, question: str, model: str | None = None) -> str:\n",
+" # Pass 1\n",
 " rewrites = self.expander.expand(question, model_override=model)\n",
 " hits: List[Dict] = []\n",
 " for q in rewrites:\n",
 " hits.extend(self.retriever.topk(q))\n",
 " blocks = self.ctx_expander.expand_and_merge(hits)\n",
-" return self.answerer.answer(question, blocks, model=model)\n",
+" answer1 = self.answerer.answer(question, blocks, model=model)\n",
+"\n",
+" # Self-critique\n",
+" critic = Critic(self.answerer.llm)\n",
+" review = critic.review(question, answer1, blocks)\n",
+"\n",
+" if review.get(\"ok\", True) or not review.get(\"followups\"):\n",
+" return answer1 # Good enough\n",
+"\n",
+" # Pass 2 — adapt plan using follow-up rewrites\n",
+" extra_hits: List[Dict] = []\n",
+" for q in review[\"followups\"][:3]: # keep bounded\n",
+" extra_hits.extend(self.retriever.topk(q))\n",
+" blocks2 = self.ctx_expander.expand_and_merge(hits + extra_hits)\n",
+" answer2 = self.answerer.answer(question, blocks2, model=model)\n",
+" return answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
 "\n",
 " def run_stream(self, question: str, model: str | None = None):\n",
 " \"\"\"\n",
 " Generator: yields tuples (answer_or_none, logs_html) multiple times.\n",
+" Now includes a self-critique step and (if needed) a second pass, with logs.\n",
 " \"\"\"\n",
 " log = RunLogger()\n",
 " log.add(f\"Question: {question}\")\n",
@@ -797,15 +883,66 @@
 " )\n",
 " for b in blocks:\n",
 " log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
-" log.add(b[\"text\"])\n",
+" log.add(b[\"text\"][:50].replace(\"\\n\", \" \"))\n",
 " yield None, log.html()\n",
 "\n",
-" # 4) LLM answer\n",
+" # 4) LLM answer — pass 1\n",
-" log.add(f\"Asking LLM: {model or self.set.gen_model}\")\n",
+" log.add(f\"Asking LLM (pass 1): {model or self.set.gen_model}\")\n",
 " yield None, log.html()\n",
-" answer = self.answerer.answer(question, blocks, model=model)\n",
+" answer1 = self.answerer.answer(question, blocks, model=model)\n",
-" log.add(\"Answer ready.\")\n",
+" log.add(\"Answer 1 ready.\")\n",
-" yield answer, log.html()\n"
+" yield answer1, log.html()\n",
+"\n",
+" # 5) Self-critique\n",
+" log.add(\"Running self-critique…\")\n",
+" yield None, log.html()\n",
+" critic = Critic(self.answerer.llm)\n",
+" review = critic.review(question, answer1, blocks)\n",
+" ok = review.get(\"ok\", True)\n",
+" missing = review.get(\"missing\") or []\n",
+" followups = review.get(\"followups\") or []\n",
+" # compact display\n",
+" def _preview(lst, n=5):\n",
+" return \", \".join(lst[:n]) + (f\" … (+{len(lst)-n} more)\" if len(lst) > n else \"\") if lst else \"—\"\n",
+" log.add(f\"Self-critique result: ok={ok}; missing={_preview(missing)}; followups={_preview(followups)}\")\n",
+" yield None, log.html()\n",
+"\n",
+" if ok or not followups:\n",
+" log.add(\"Critique passed or no follow-ups. Finalizing pass 1.\")\n",
+" yield answer1, log.html()\n",
+" return\n",
+"\n",
+" # 6) Pass 2 — follow-up retrievals\n",
+" extra_hits: List[Dict] = []\n",
+" limited = followups[:3] # keep bounded\n",
+" for i, q in enumerate(limited, 1):\n",
+" hits = self.retriever.topk(q)\n",
+" extra_hits.extend(hits)\n",
+" top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n",
+" log.add(f\"Follow-up retrieval {i}/{len(limited)}: got {len(hits)} hits → {top3}\")\n",
+" yield None, log.html()\n",
+"\n",
+" blocks2 = self.ctx_expander.expand_and_merge(all_hits + extra_hits)\n",
+" used2 = self.ctx_expander.last_ids\n",
+" peek2 = \", \".join(used2[:8]) + (\" …\" if len(used2) > 8 else \"\")\n",
+" log.add(f\"Neighbor addition (pass 2): collected {len(used2)} chunk-ids → {peek2}\")\n",
+"\n",
+" approx_words2 = int(self.ctx_expander.pad_tokens / 1.4)\n",
+" log.add(\n",
+" f\"Context expansion (pass 2): merged {len(blocks2)} block(s) \"\n",
+" f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words2} words).\"\n",
+" )\n",
+" for b in blocks2[:10]: # don’t spam the log\n",
+" log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
+" yield None, log.html()\n",
+"\n",
+" # 7) LLM answer — pass 2\n",
+" log.add(f\"Asking LLM (pass 2): {model or self.set.gen_model}\")\n",
+" yield None, log.html()\n",
+" answer2 = self.answerer.answer(question, blocks2, model=model)\n",
+" final_answer = answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
+" log.add(\"Answer 2 ready (refined).\")\n",
+" yield final_answer, log.html()\n"
 ]
 },
 {
@@ -817,7 +954,7 @@
 "source": [
 "def make_agent(gen_model: str) -> LegalAgent:\n",
 " llm = OpenAILLM(gen_model)\n",
-" expander = QueryExpander(llm=llm, n=SET.expansions)\n",
+" expander = ModalFirstExpander(llm=llm, n=SET.expansions) # <— swap in the Modal-first expander\n",
 " retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n",
 " ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n",
 " answerer = Answerer(llm=llm, set=SET)\n",
@@ -0,0 +1,120 @@
# week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py
import os, json, re
from typing import List
import modal

# minimal image: vLLM + torch + HF hub
image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .entrypoint([])
    .uv_pip_install(
        "vllm==0.10.2",
        "torch==2.8.0",
        "huggingface_hub[hf_transfer]==0.35.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

app = modal.App("legal-query-expander-qwen3-v2", image=image)

MODEL_NAME = "Qwen/Qwen3-4B-Instruct"  # use instruct, defaults everywhere
_llm = None  # warm-container cache


def _extract_json_array(text: str) -> List[str]:
    try:
        parsed = json.loads(text)
        return [x for x in parsed if isinstance(x, str)]
    except Exception:
        pass
    m = re.search(r"\[(?:.|\n|\r)*\]", text)
    if m:
        try:
            parsed = json.loads(m.group(0))
            return [x for x in parsed if isinstance(x, str)]
        except Exception:
            return []
    return []


def _sanitize_and_dedupe(items: List[str], n: int) -> List[str]:
    out, seen = [], set()
    for q in items:
        q = re.sub(r"[^\w\s\-./]", "", (q or "")).strip()
        k = q.lower()
        if q and k not in seen:
            seen.add(k)
            out.append(q)
        if len(out) >= n:
            break
    return out


@app.function(
    image=image,
    gpu=modal.gpu.L4(),  # pick any available GPU (A100/H100 also fine)
    timeout=600,
    secrets=[modal.Secret.from_name("huggingface-secret")],  # set HF token here
)
def expand(question: str, n: int = 5) -> List[str]:
    """
    Return up to n short, diverse retrieval keyphrases for Bare Acts.
    Uses Qwen3-4B-Instruct with its default chat template.
    """
    global _llm
    from vllm import LLM, SamplingParams

    # ensure HF token is available to vLLM
    tok = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if tok and not os.environ.get("HUGGINGFACE_HUB_TOKEN"):
        os.environ["HUGGINGFACE_HUB_TOKEN"] = tok

    if _llm is None:
        _llm = LLM(
            model=MODEL_NAME,
            trust_remote_code=True,
            dtype="auto",
            tensor_parallel_size=1,
        )

    user = (
        "You are a search-query expander. "
        "For a given legal search query, produce several distinct variants that retrieve better from our database of legal texts such as Bare Acts. "
        "Respond ONLY as a JSON array of strings; no prose, no section numbers. "
        f"Question:\n{question}\n\n"
        f"Return {n} distinct keyphrases (4–20 words each) capturing what to search for in the RAG database. Return a JSON array only, no commentary."
    )

    messages = [
        {"role": "user", "content": user},
    ]

    tokenizer = _llm.get_tokenizer()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    result = _llm.generate(
        [prompt],
        SamplingParams(
            max_tokens=256,
            temperature=0.2,
        ),
    )
    text = result[0].outputs[0].text

    arr = _sanitize_and_dedupe(_extract_json_array(text), n)

    if not arr:
        # deterministic fallback (keeps things non-empty)
        base = re.sub(r"[?]+$", "", (question or "")).strip()
        pool = [
            f"{base} section",
            f"{base} provision bare act",
            f"{base} indian penal code",
            f"{base} bharatiya nyaya sanhita",
            f"{base} punishment section keywords",
        ]
        arr = _sanitize_and_dedupe(pool, n)

    return arr