add Modal+Qwen3 query expansion and self-critique agentic loop
@@ -0,0 +1,49 @@
# Agentic Legal Q&A on Bare Acts (Week 8)

An **agentic RAG** demo that answers legal questions from Indian Bare Acts (IPC/BNS/Constitution).
Pipeline: **Query expansion (Modal+Qwen3) → Multi-retrieval (Chroma) → Neighbor-aware context merge → LLM answer → Self-critique → Optional second pass**.
UI: lightweight **Gradio** chat with live agent logs.

## Features

- **Modal-first expander:** `modal_expander.py` (Qwen3-4B via vLLM, GPU) with local LLM fallback.
- **Vector store:** Chroma + `all-MiniLM-L6-v2`, token-span-aware chunking, ±neighbor merge.
- **Agentic loop:** critic validates citations and triggers follow-up retrievals if needed.
- **Config knobs:** top-k per rewrite, neighbor radius, max merged blocks, model dropdown (see the sketch below).

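The `Settings` object the notebook reads these knobs from (`SET`) is defined in a cell outside this diff; a minimal sketch using the field names the agent code references (the default values below are illustrative, not the notebook's actual ones):

```python
from dataclasses import dataclass

@dataclass
class Settings:
    gen_model: str = "gpt-4o-mini"   # default answer model; overridable from the UI dropdown
    expansions: int = 5              # query rewrites requested from the expander
    topk_per_rewrite: int = 4        # Chroma hits retrieved per rewrite
    neighbor_radius: int = 1         # ± neighboring chunks merged around each hit
    max_blocks_for_llm: int = 8      # cap on merged context blocks passed to the LLM

SET = Settings()
```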
## Setup

```bash
python -m pip install -U openai chromadb transformers gradio python-dotenv modal
```

Create `.env` with your keys:

```bash
OPENAI_API_KEY=...
```

Place Bare Acts as UTF-8 `.txt` files in:

```
knowledge_base/bare_acts/    # e.g., ipc.txt, bns.txt, coi.txt
```

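The notebook's indexing cell performs the token-span-aware chunking and writes the chunks to Chroma. A rough sketch of that step (the collection name, chunk size, and overlap here are assumptions; the notebook's own indexing cell is authoritative):

```python
from pathlib import Path

import chromadb
from transformers import AutoTokenizer

# Chroma's default embedding function is all-MiniLM-L6-v2, matching the README.
client = chromadb.PersistentClient(path="vector_db_w8")
collection = client.get_or_create_collection("bare_acts")  # assumed collection name
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def token_span_chunks(text: str, size: int = 300, overlap: int = 50):
    """Yield (char_start, char_end, chunk_text) spans derived from token offsets,
    so each chunk's span can be stored as metadata and used for neighbor merging."""
    offsets = tokenizer(text, add_special_tokens=False,
                        return_offsets_mapping=True)["offset_mapping"]
    step = size - overlap
    for i in range(0, len(offsets), step):
        window = offsets[i:i + size]
        if not window:
            break
        start, end = window[0][0], window[-1][1]
        yield start, end, text[start:end]

for path in Path("knowledge_base/bare_acts").glob("*.txt"):
    text = path.read_text(encoding="utf-8")
    for start, end, chunk in token_span_chunks(text):
        collection.add(
            ids=[f"{path.stem}:{start}:{end}"],
            documents=[chunk],
            metadatas=[{"source": path.stem, "start": start, "end": end}],
        )
```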
## Deploy the Modal expander

Set a Modal secret named `huggingface-secret` containing `HUGGINGFACE_HUB_TOKEN`, then:

```bash
modal deploy -m modal_expander
```

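To sanity-check the deployment before wiring it into the notebook, you can call the deployed function directly (the lookup below mirrors what the notebook does; the sample question is arbitrary):

```python
import modal

# Look up the function deployed above by app name and function name.
expand = modal.Function.from_name("legal-query-expander-qwen3-v2", "expand")
keyphrases = expand.remote("What is the punishment for theft?", 5)
print(keyphrases)  # a list of up to 5 short retrieval keyphrases
```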
## Run the notebook app

```bash
jupyter notebook agentic_legal_qna_with_rag_on_bare_acts.ipynb
```

Run all cells; a Gradio chat appears. Tune **Top-K**, **Neighbor radius**, and **Max blocks** under *Advanced*.

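The chat streams the agent log while it works: each `(answer_or_none, logs_html)` yield from `run_stream` maps onto two output components. Roughly, the wiring looks like the following (the notebook's own Gradio cell is the source of truth; component labels here are illustrative):

```python
import gradio as gr

def ask(question, model):
    agent = make_agent(model)  # defined in the notebook
    for answer, logs_html in agent.run_stream(question, model=model):
        # The log updates on every yield; the answer arrives on the final yields.
        yield (answer or ""), logs_html

with gr.Blocks() as demo:
    question = gr.Textbox(label="Ask about IPC / BNS / the Constitution")
    model = gr.Dropdown(["gpt-4o-mini"], value="gpt-4o-mini", label="Model")
    answer = gr.Markdown()
    logs = gr.HTML()
    question.submit(ask, inputs=[question, model], outputs=[answer, logs])

demo.launch()
```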
## Notes

* Default OpenAI model: `gpt-4o-mini` (change via UI).
* Vector DB is persisted in `vector_db_w8`; re-run the indexing cell to rebuild after data changes.

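To check what is currently indexed without re-running the notebook (the collection name below is an assumption; use whatever name the indexing cell creates):

```python
from chromadb import PersistentClient

client = PersistentClient(path="vector_db_w8")
collection = client.get_collection("bare_acts")  # assumed collection name
print(collection.count(), "chunks indexed")
```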
@@ -13,9 +13,10 @@
"from dataclasses import dataclass, field\n",
"from typing import List, Dict, Any, Tuple, Protocol\n",
"\n",
"\n",
"from pathlib import Path\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"from openai import OpenAI\n",
"import chromadb\n",
"from chromadb import PersistentClient\n",
@@ -24,6 +25,7 @@
"from transformers import AutoTokenizer\n",
"import gradio as gr\n",
"\n",
"from dotenv import load_dotenv\n",
"load_dotenv(override=True)"
]
},
@@ -735,6 +737,73 @@
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8284580f",
"metadata": {},
"outputs": [],
"source": [
"class Critic:\n",
"    def __init__(self, llm: LLM):\n",
"        self.llm = llm\n",
"\n",
"    def review(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:\n",
"        block_ids = [f\"[{b['source']} {b['start']}:{b['end']}]\" for b in blocks]\n",
"        sys = (\n",
"            \"You are a meticulous legal verifier. \"\n",
"            \"Return ONLY JSON with keys: ok (bool), missing (list of short missing facts), followups (list of short retrieval keyphrases).\"\n",
"        )\n",
"        user = f\"\"\"Question:\n",
"{question}\n",
"\n",
"Proposed answer:\n",
"{answer}\n",
"\n",
"Verify that every factual claim is supported by the context blocks (by their ids).\n",
"If support is weak or missing, set ok=false and propose concise retrieval keyphrases in followups.\n",
"\n",
"Available context block ids:\n",
"{\", \".join(block_ids)}\n",
"\n",
"Return JSON only, e.g.:\n",
"{{\"ok\": true, \"missing\": [], \"followups\": []}}\n",
"\"\"\"\n",
"        raw = self.llm.complete(system=sys, user=user, temperature=0.0, max_tokens=220)\n",
"        try:\n",
"            m = re.search(r\"\\{(.|\\n)*\\}\", raw)\n",
"            return json.loads(m.group(0)) if m else {\"ok\": True, \"missing\": [], \"followups\": []}\n",
"        except Exception:\n",
"            return {\"ok\": True, \"missing\": [], \"followups\": []}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34449337",
"metadata": {},
"outputs": [],
"source": [
"import modal\n",
"_remote_expand = modal.Function.from_name(\"legal-query-expander-qwen3-v2\", \"expand\")\n",
"\n",
"\n",
"class ModalFirstExpander(QueryExpander):\n",
"    def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n",
"        got = []\n",
"        if _remote_expand:\n",
"            try:\n",
"                got = _remote_expand.remote(question, self.n)\n",
"            except Exception:\n",
"                got = []\n",
"        if not got or len([x for x in got if isinstance(x, str) and x.strip()]) < max(1, self.n // 2):\n",
"            return super().expand(question, model_override=model_override)\n",
"        import re\n",
"        got = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in got]\n",
"        return [q for q in got if q][: self.n]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -751,16 +820,33 @@
"        self.set = set\n",
"\n",
"    def run(self, question: str, model: str | None = None) -> str:\n",
"        # Pass 1\n",
"        rewrites = self.expander.expand(question, model_override=model)\n",
"        hits: List[Dict] = []\n",
"        for q in rewrites:\n",
"            hits.extend(self.retriever.topk(q))\n",
"        blocks = self.ctx_expander.expand_and_merge(hits)\n",
"        answer1 = self.answerer.answer(question, blocks, model=model)\n",
"\n",
"        # Self-critique\n",
"        critic = Critic(self.answerer.llm)\n",
"        review = critic.review(question, answer1, blocks)\n",
"\n",
"        if review.get(\"ok\", True) or not review.get(\"followups\"):\n",
"            return answer1  # Good enough\n",
"\n",
"        # Pass 2 — adapt plan using follow-up rewrites\n",
"        extra_hits: List[Dict] = []\n",
"        for q in review[\"followups\"][:3]:  # keep bounded\n",
"            extra_hits.extend(self.retriever.topk(q))\n",
"        blocks2 = self.ctx_expander.expand_and_merge(hits + extra_hits)\n",
"        answer2 = self.answerer.answer(question, blocks2, model=model)\n",
"        return answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
"\n",
"    def run_stream(self, question: str, model: str | None = None):\n",
"        \"\"\"\n",
"        Generator: yields tuples (answer_or_none, logs_html) multiple times.\n",
"        Now includes a self-critique step and (if needed) a second pass, with logs.\n",
"        \"\"\"\n",
"        log = RunLogger()\n",
"        log.add(f\"Question: {question}\")\n",
@@ -797,15 +883,66 @@
"        )\n",
"        for b in blocks:\n",
"            log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
"            log.add(b[\"text\"][:50].replace(\"\\n\", \" \"))\n",
"        yield None, log.html()\n",
"\n",
"        # 4) LLM answer — pass 1\n",
"        log.add(f\"Asking LLM (pass 1): {model or self.set.gen_model}\")\n",
"        yield None, log.html()\n",
"        answer1 = self.answerer.answer(question, blocks, model=model)\n",
"        log.add(\"Answer 1 ready.\")\n",
"        yield answer1, log.html()\n",
"\n",
"        # 5) Self-critique\n",
"        log.add(\"Running self-critique…\")\n",
"        yield None, log.html()\n",
"        critic = Critic(self.answerer.llm)\n",
"        review = critic.review(question, answer1, blocks)\n",
"        ok = review.get(\"ok\", True)\n",
"        missing = review.get(\"missing\") or []\n",
"        followups = review.get(\"followups\") or []\n",
"        # compact display\n",
"        def _preview(lst, n=5):\n",
"            return \", \".join(lst[:n]) + (f\" … (+{len(lst)-n} more)\" if len(lst) > n else \"\") if lst else \"—\"\n",
"        log.add(f\"Self-critique result: ok={ok}; missing={_preview(missing)}; followups={_preview(followups)}\")\n",
"        yield None, log.html()\n",
"\n",
"        if ok or not followups:\n",
"            log.add(\"Critique passed or no follow-ups. Finalizing pass 1.\")\n",
"            yield answer1, log.html()\n",
"            return\n",
"\n",
"        # 6) Pass 2 — follow-up retrievals\n",
"        extra_hits: List[Dict] = []\n",
"        limited = followups[:3]  # keep bounded\n",
"        for i, q in enumerate(limited, 1):\n",
"            hits = self.retriever.topk(q)\n",
"            extra_hits.extend(hits)\n",
"            top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n",
"            log.add(f\"Follow-up retrieval {i}/{len(limited)}: got {len(hits)} hits → {top3}\")\n",
"            yield None, log.html()\n",
"\n",
"        blocks2 = self.ctx_expander.expand_and_merge(all_hits + extra_hits)\n",
"        used2 = self.ctx_expander.last_ids\n",
"        peek2 = \", \".join(used2[:8]) + (\" …\" if len(used2) > 8 else \"\")\n",
"        log.add(f\"Neighbor addition (pass 2): collected {len(used2)} chunk-ids → {peek2}\")\n",
"\n",
"        approx_words2 = int(self.ctx_expander.pad_tokens / 1.4)\n",
"        log.add(\n",
"            f\"Context expansion (pass 2): merged {len(blocks2)} block(s) \"\n",
"            f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words2} words).\"\n",
"        )\n",
"        for b in blocks2[:10]:  # don’t spam the log\n",
"            log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n",
"        yield None, log.html()\n",
"\n",
"        # 7) LLM answer — pass 2\n",
"        log.add(f\"Asking LLM (pass 2): {model or self.set.gen_model}\")\n",
"        yield None, log.html()\n",
"        answer2 = self.answerer.answer(question, blocks2, model=model)\n",
"        final_answer = answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n",
"        log.add(\"Answer 2 ready (refined).\")\n",
"        yield final_answer, log.html()\n"
]
},
{
@@ -817,7 +954,7 @@
"source": [
"def make_agent(gen_model: str) -> LegalAgent:\n",
"    llm = OpenAILLM(gen_model)\n",
"    expander = ModalFirstExpander(llm=llm, n=SET.expansions)  # Modal-first expander swapped in here\n",
"    retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n",
"    ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n",
"    answerer = Answerer(llm=llm, set=SET)\n",
@@ -0,0 +1,120 @@
# week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py
import os, json, re
from typing import List
import modal

# minimal image: vLLM + torch + HF hub
image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .entrypoint([])
    .uv_pip_install(
        "vllm==0.10.2",
        "torch==2.8.0",
        "huggingface_hub[hf_transfer]==0.35.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

app = modal.App("legal-query-expander-qwen3-v2", image=image)

MODEL_NAME = "Qwen/Qwen3-4B-Instruct"  # use instruct, defaults everywhere
_llm = None  # warm-container cache


def _extract_json_array(text: str) -> List[str]:
    try:
        parsed = json.loads(text)
        return [x for x in parsed if isinstance(x, str)]
    except Exception:
        pass
    m = re.search(r"\[(?:.|\n|\r)*\]", text)
    if m:
        try:
            parsed = json.loads(m.group(0))
            return [x for x in parsed if isinstance(x, str)]
        except Exception:
            return []
    return []


def _sanitize_and_dedupe(items: List[str], n: int) -> List[str]:
    out, seen = [], set()
    for q in items:
        q = re.sub(r"[^\w\s\-./]", "", (q or "")).strip()
        k = q.lower()
        if q and k not in seen:
            seen.add(k)
            out.append(q)
        if len(out) >= n:
            break
    return out


@app.function(
    image=image,
    gpu=modal.gpu.L4(),  # pick any available GPU (A100/H100 also fine)
    timeout=600,
    secrets=[modal.Secret.from_name("huggingface-secret")],  # set HF token here
)
def expand(question: str, n: int = 5) -> List[str]:
    """
    Return up to n short, diverse retrieval keyphrases for Bare Acts.
    Uses Qwen3-4B-Instruct with its default chat template.
    """
    global _llm
    from vllm import LLM, SamplingParams

    # ensure HF token is available to vLLM
    tok = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if tok and not os.environ.get("HUGGINGFACE_HUB_TOKEN"):
        os.environ["HUGGINGFACE_HUB_TOKEN"] = tok

    if _llm is None:
        _llm = LLM(
            model=MODEL_NAME,
            trust_remote_code=True,
            dtype="auto",
            tensor_parallel_size=1,
        )

    user = (
        "You are a search-query expander. "
        "Given a legal search query, produce several different variants of it so that a database of legal texts (bare acts) can be searched more effectively. "
        "Respond ONLY as a JSON array of strings; no prose, no section numbers.\n\n"
        f"Question:\n{question}\n\n"
        f"Return {n} distinct keyphrases (4–20 words each) that capture what to search for in the RAG database. Return a JSON array only, with no commentary."
    )

    messages = [
        {"role": "user", "content": user},
    ]

    tokenizer = _llm.get_tokenizer()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    result = _llm.generate(
        [prompt],
        SamplingParams(
            max_tokens=256,
            temperature=0.2,
        ),
    )
    text = result[0].outputs[0].text

    arr = _sanitize_and_dedupe(_extract_json_array(text), n)

    if not arr:
        # deterministic fallback (keeps things non-empty)
        base = re.sub(r"[?]+$", "", (question or "")).strip()
        pool = [
            f"{base} section",
            f"{base} provision bare act",
            f"{base} indian penal code",
            f"{base} bharatiya nyaya sanhita",
            f"{base} punishment section keywords",
        ]
        arr = _sanitize_and_dedupe(pool, n)

    return arr
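

# Hypothetical addition (not in the committed file): a local entrypoint for a quick
# smoke test of the function above. Running `modal run -m modal_expander` should
# execute it and print the expanded keyphrases; the default question is arbitrary.
@app.local_entrypoint()
def main(question: str = "What is the punishment for theft?", n: int = 5):
    print(expand.remote(question, n))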