From 8ed3a97275c65592d4b6df627dd7dca6d2cd09e6 Mon Sep 17 00:00:00 2001 From: Nik Date: Thu, 30 Oct 2025 21:40:02 +0530 Subject: [PATCH] add Modal+Qwen3 query expansion and self-critique agentic loop --- .../README.md | 49 ++++++ ...ntic_legal_qna_with_rag_on_bare_acts.ipynb | 155 +++++++++++++++++- .../modal_expander.py | 120 ++++++++++++++ 3 files changed, 315 insertions(+), 9 deletions(-) create mode 100644 week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/README.md create mode 100644 week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py diff --git a/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/README.md b/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/README.md new file mode 100644 index 0000000..4fb5a40 --- /dev/null +++ b/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/README.md @@ -0,0 +1,49 @@ +# Agentic Legal Q&A on Bare Acts (Week 8) + +An **agentic RAG** demo that answers legal questions from Indian Bare Acts (IPC/BNS/Constitution). +Pipeline: **Query expansion (Modal+Qwen3) → Multi-retrieval (Chroma) → Neighbor-aware context merge → LLM answer → Self-critique → Optional second pass**. +UI: lightweight **Gradio** chat with live agent logs. + +## Features +- **Modal-first expander:** `modal_expander.py` (Qwen3-4B via vLLM, GPU) with local LLM fallback. +- **Vector store:** Chroma + `all-MiniLM-L6-v2`, token-span aware chunking, ±neighbor merge. +- **Agentic loop:** critic validates citations and triggers follow-up retrievals if needed. +- **Config knobs:** top-k per rewrite, neighbor radius, max merged blocks, model dropdown. + +## Setup +```bash +python -m pip install -U openai chromadb transformers gradio python-dotenv modal +```` + +Create `.env` with your keys: + +```bash +OPENAI_API_KEY=... +``` + +Place Bare Acts as UTF-8 `.txt` files in: + +``` +knowledge_base/bare_acts/ # e.g., ipc.txt, bns.txt, coi.txt +``` + +## Deploy the Modal expander + +Set a Modal secret named `huggingface-secret` containing `HUGGINGFACE_HUB_TOKEN`, then: + +```bash +modal deploy -m modal_expander +``` + +## Run the notebook app + +```bash +jupyter notebook agentic_legal_qna_with_rag_on_bare_acts.ipynb +``` + +Run all cells; a Gradio chat appears. Tune **Top-K**, **Neighbor radius**, and **Max blocks** under *Advanced*. + +## Notes + +* Default OpenAI model: `gpt-4o-mini` (change via UI). +* Vector DB is persisted in `vector_db_w8`; re-run the indexing cell to rebuild after data changes. 
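+
+## Smoke-test the deployed expander (optional)
+
+A minimal sketch for checking the Modal deployment outside the notebook; it assumes the app is deployed under the name above and that your Modal token is configured (the question string is only an example):
+
+```python
+import modal
+
+# look up the deployed function and call it remotely
+expand = modal.Function.from_name("legal-query-expander-qwen3-v2", "expand")
+print(expand.remote("punishment for theft under IPC", 5))
+```
+
+If this prints a list of keyphrases, the notebook's `ModalFirstExpander` will take the remote path; otherwise it silently falls back to the local LLM expander.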
\ No newline at end of file diff --git a/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/agentic_legal_qna_with_rag_on_bare_acts.ipynb b/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/agentic_legal_qna_with_rag_on_bare_acts.ipynb index f03d683..090c2d8 100644 --- a/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/agentic_legal_qna_with_rag_on_bare_acts.ipynb +++ b/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/agentic_legal_qna_with_rag_on_bare_acts.ipynb @@ -13,9 +13,10 @@ "from dataclasses import dataclass, field\n", "from typing import List, Dict, Any, Tuple, Protocol\n", "\n", + "\n", "from pathlib import Path\n", "\n", - "from dotenv import load_dotenv\n", + "\n", "from openai import OpenAI\n", "import chromadb\n", "from chromadb import PersistentClient\n", @@ -24,6 +25,7 @@ "from transformers import AutoTokenizer\n", "import gradio as gr\n", "\n", + "from dotenv import load_dotenv\n", "load_dotenv(override=True)" ] }, @@ -735,6 +737,73 @@ " )\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "8284580f", + "metadata": {}, + "outputs": [], + "source": [ + "class Critic:\n", + " def __init__(self, llm: LLM):\n", + " self.llm = llm\n", + "\n", + " def review(self, question: str, answer: str, blocks: List[Dict]) -> Dict[str, Any]:\n", + " block_ids = [f\"[{b['source']} {b['start']}:{b['end']}]\" for b in blocks]\n", + " sys = (\n", + " \"You are a meticulous legal verifier. \"\n", + " \"Return ONLY JSON with keys: ok (bool), missing (list of short missing facts), followups (list of short retrieval keyphrases).\"\n", + " )\n", + " user = f\"\"\"Question:\n", + "{question}\n", + "\n", + "Proposed answer:\n", + "{answer}\n", + "\n", + "Verify that every factual claim is supported by the context blocks (by their ids).\n", + "If support is weak or missing, set ok=false and propose concise retrieval keyphrases in followups.\n", + "\n", + "Available context block ids:\n", + "{\", \".join(block_ids)}\n", + "\n", + "Return JSON only, e.g.:\n", + "{{\"ok\": true, \"missing\": [], \"followups\": []}}\n", + "\"\"\"\n", + " raw = self.llm.complete(system=sys, user=user, temperature=0.0, max_tokens=220)\n", + " try:\n", + " m = re.search(r\"\\{(.|\\n)*\\}\", raw)\n", + " return json.loads(m.group(0)) if m else {\"ok\": True, \"missing\": [], \"followups\": []}\n", + " except Exception:\n", + " return {\"ok\": True, \"missing\": [], \"followups\": []}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34449337", + "metadata": {}, + "outputs": [], + "source": [ + "import modal\n", + "_remote_expand = modal.Function.from_name(\"legal-query-expander-qwen3-v2\", \"expand\")\n", + "\n", + "\n", + "class ModalFirstExpander(QueryExpander):\n", + " def expand(self, question: str, *, model_override: str | None = None) -> List[str]:\n", + " got = []\n", + " if _remote_expand:\n", + " try:\n", + " got = _remote_expand.remote(question, self.n)\n", + " except Exception:\n", + " got = []\n", + " if not got or len([x for x in got if isinstance(x, str) and x.strip()]) < max(1, self.n // 2):\n", + " return super().expand(question, model_override=model_override)\n", + " import re\n", + " got = [re.sub(r\"[^\\w\\s\\-./]\", \"\", q).strip() for q in got]\n", + " return [q for q in got if q][: self.n]\n", + "\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -751,16 +820,33 @@ " self.set = set\n", "\n", " def run(self, question: str, model: str | None = None) -> str:\n", + " # 
Pass 1\n", " rewrites = self.expander.expand(question, model_override=model)\n", " hits: List[Dict] = []\n", " for q in rewrites:\n", " hits.extend(self.retriever.topk(q))\n", " blocks = self.ctx_expander.expand_and_merge(hits)\n", - " return self.answerer.answer(question, blocks, model=model)\n", + " answer1 = self.answerer.answer(question, blocks, model=model)\n", + "\n", + " # Self-critique\n", + " critic = Critic(self.answerer.llm)\n", + " review = critic.review(question, answer1, blocks)\n", + "\n", + " if review.get(\"ok\", True) or not review.get(\"followups\"):\n", + " return answer1 # Good enough\n", + "\n", + " # Pass 2 — adapt plan using follow-up rewrites\n", + " extra_hits: List[Dict] = []\n", + " for q in review[\"followups\"][:3]: # keep bounded\n", + " extra_hits.extend(self.retriever.topk(q))\n", + " blocks2 = self.ctx_expander.expand_and_merge(hits + extra_hits)\n", + " answer2 = self.answerer.answer(question, blocks2, model=model)\n", + " return answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n", "\n", " def run_stream(self, question: str, model: str | None = None):\n", " \"\"\"\n", " Generator: yields tuples (answer_or_none, logs_html) multiple times.\n", + " Now includes a self-critique step and (if needed) a second pass, with logs.\n", " \"\"\"\n", " log = RunLogger()\n", " log.add(f\"Question: {question}\")\n", @@ -797,15 +883,66 @@ " )\n", " for b in blocks:\n", " log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n", - " log.add(b[\"text\"])\n", + " log.add(b[\"text\"][:50].replace(\"\\n\", \" \"))\n", " yield None, log.html()\n", "\n", - " # 4) LLM answer\n", - " log.add(f\"Asking LLM: {model or self.set.gen_model}\")\n", + " # 4) LLM answer — pass 1\n", + " log.add(f\"Asking LLM (pass 1): {model or self.set.gen_model}\")\n", " yield None, log.html()\n", - " answer = self.answerer.answer(question, blocks, model=model)\n", - " log.add(\"Answer ready.\")\n", - " yield answer, log.html()\n" + " answer1 = self.answerer.answer(question, blocks, model=model)\n", + " log.add(\"Answer 1 ready.\")\n", + " yield answer1, log.html()\n", + "\n", + " # 5) Self-critique\n", + " log.add(\"Running self-critique…\")\n", + " yield None, log.html()\n", + " critic = Critic(self.answerer.llm)\n", + " review = critic.review(question, answer1, blocks)\n", + " ok = review.get(\"ok\", True)\n", + " missing = review.get(\"missing\") or []\n", + " followups = review.get(\"followups\") or []\n", + " # compact display\n", + " def _preview(lst, n=5):\n", + " return \", \".join(lst[:n]) + (f\" … (+{len(lst)-n} more)\" if len(lst) > n else \"\") if lst else \"—\"\n", + " log.add(f\"Self-critique result: ok={ok}; missing={_preview(missing)}; followups={_preview(followups)}\")\n", + " yield None, log.html()\n", + "\n", + " if ok or not followups:\n", + " log.add(\"Critique passed or no follow-ups. 
Finalizing pass 1.\")\n", + " yield answer1, log.html()\n", + " return\n", + "\n", + " # 6) Pass 2 — follow-up retrievals\n", + " extra_hits: List[Dict] = []\n", + " limited = followups[:3] # keep bounded\n", + " for i, q in enumerate(limited, 1):\n", + " hits = self.retriever.topk(q)\n", + " extra_hits.extend(hits)\n", + " top3 = \", \".join(_hit_id(h) for h in hits[:3]) or \"—\"\n", + " log.add(f\"Follow-up retrieval {i}/{len(limited)}: got {len(hits)} hits → {top3}\")\n", + " yield None, log.html()\n", + "\n", + " blocks2 = self.ctx_expander.expand_and_merge(all_hits + extra_hits)\n", + " used2 = self.ctx_expander.last_ids\n", + " peek2 = \", \".join(used2[:8]) + (\" …\" if len(used2) > 8 else \"\")\n", + " log.add(f\"Neighbor addition (pass 2): collected {len(used2)} chunk-ids → {peek2}\")\n", + "\n", + " approx_words2 = int(self.ctx_expander.pad_tokens / 1.4)\n", + " log.add(\n", + " f\"Context expansion (pass 2): merged {len(blocks2)} block(s) \"\n", + " f\"(radius ±{self.ctx_expander.radius}, pad ≈{approx_words2} words).\"\n", + " )\n", + " for b in blocks2[:10]: # don’t spam the log\n", + " log.add(f\" [{b['source']} {b['start']}:{b['end']}]\")\n", + " yield None, log.html()\n", + "\n", + " # 7) LLM answer — pass 2\n", + " log.add(f\"Asking LLM (pass 2): {model or self.set.gen_model}\")\n", + " yield None, log.html()\n", + " answer2 = self.answerer.answer(question, blocks2, model=model)\n", + " final_answer = answer2 + \"\\n\\n_Refined after self-critique (agentic step)._\"\n", + " log.add(\"Answer 2 ready (refined).\")\n", + " yield final_answer, log.html()\n" ] }, { @@ -817,7 +954,7 @@ "source": [ "def make_agent(gen_model: str) -> LegalAgent:\n", " llm = OpenAILLM(gen_model)\n", - " expander = QueryExpander(llm=llm, n=SET.expansions)\n", + " expander = ModalFirstExpander(llm=llm, n=SET.expansions) # <— here\n", " retriever = Retriever(index=index, k=SET.topk_per_rewrite)\n", " ctx_expander = ContextExpander(index=index, radius=SET.neighbor_radius, max_blocks=SET.max_blocks_for_llm)\n", " answerer = Answerer(llm=llm, set=SET)\n", diff --git a/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py b/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py new file mode 100644 index 0000000..bb57075 --- /dev/null +++ b/week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py @@ -0,0 +1,120 @@ +# week8/community_contributions/agentic_legal_qna_with_rag_on_bare_acts/modal_expander.py +import os, json, re +from typing import List +import modal + +# minimal image: vLLM + torch + HF hub +image = ( + modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12") + .entrypoint([]) + .uv_pip_install( + "vllm==0.10.2", + "torch==2.8.0", + "huggingface_hub[hf_transfer]==0.35.0", + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) +) + +app = modal.App("legal-query-expander-qwen3-v2", image=image) + +MODEL_NAME = "Qwen/Qwen3-4B-Instruct" # use instruct, defaults everywhere +_llm = None # warm-container cache + + +def _extract_json_array(text: str) -> List[str]: + try: + parsed = json.loads(text) + return [x for x in parsed if isinstance(x, str)] + except Exception: + pass + m = re.search(r"\[(?:.|\n|\r)*\]", text) + if m: + try: + parsed = json.loads(m.group(0)) + return [x for x in parsed if isinstance(x, str)] + except Exception: + return [] + return [] + + +def _sanitize_and_dedupe(items: List[str], n: int) -> List[str]: + out, seen = [], set() + for q in items: + q = 
re.sub(r"[^\w\s\-./]", "", (q or "")).strip() + k = q.lower() + if q and k not in seen: + seen.add(k) + out.append(q) + if len(out) >= n: + break + return out + + +@app.function( + image=image, + gpu=modal.gpu.L4(), # pick any available GPU (A100/H100 also fine) + timeout=600, + secrets=[modal.Secret.from_name("huggingface-secret")], # set HF token here +) +def expand(question: str, n: int = 5) -> List[str]: + """ + Return up to n short, diverse retrieval keyphrases for Bare Acts. + Uses Qwen3-4B-Instruct with its default chat template. + """ + global _llm + from vllm import LLM, SamplingParams + + # ensure HF token is available to vLLM + tok = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") + if tok and not os.environ.get("HUGGINGFACE_HUB_TOKEN"): + os.environ["HUGGINGFACE_HUB_TOKEN"] = tok + + if _llm is None: + _llm = LLM( + model=MODEL_NAME, + trust_remote_code=True, + dtype="auto", + tensor_parallel_size=1, + ) + + user = ( + "You are Search Query Expander." + "For given search query you give 4-5 different variants of it to search the database better. It is a legal search query and our database is of legal data like bare acts." + "Respond ONLY as a JSON array of strings; no prose, no section numbers." + f"Question:\n{question}\n\n" + f"Return {n} distinct keyphrases (4–20 words each), which captures the what to search inside rag database. Return as a JSON array. No commentary." + ) + + messages = [ + {"role": "user", "content": user}, + ] + + tokenizer = _llm.get_tokenizer() + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + result = _llm.generate( + [prompt], + SamplingParams( + max_tokens=256, + temperature=0.2, + ), + ) + text = result[0].outputs[0].text + + arr = _sanitize_and_dedupe(_extract_json_array(text), n) + + if not arr: + # deterministic fallback (keeps things non-empty) + base = re.sub(r"[?]+$", "", (question or "")).strip() + pool = [ + f"{base} section", + f"{base} provision bare act", + f"{base} indian penal code", + f"{base} bharatiya nyaya sanhita", + f"{base} punishment section keywords", + ] + arr = _sanitize_and_dedupe(pool, n) + + return arr