Merge pull request #843 from rnik12/rnik12-week5
[Bootcamp] - Nikhil - Week 5 Exercise - RAG on Legal Bare Acts Txt Copied from PDF
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "d27544d4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Dict, List, Optional, Tuple\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"\n",
"from pathlib import Path\n",
"from typing import List, Tuple\n",
"from transformers import AutoTokenizer\n",
"\n",
"\n",
"# ---- load env ----\n",
"load_dotenv(override=True)\n",
"\n",
"# ---- OpenAI-compatible base URLs (Gemini & Groq) ----\n",
"GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
"GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
"\n",
"OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
"GOOGLE_API_KEY = os.getenv(\"GOOGLE_API_KEY\")  # Gemini\n",
"GROQ_API_KEY = os.getenv(\"GROQ_API_KEY\")  # Groq\n",
"\n",
"# ---- create clients only if keys exist ----\n",
"openai_client = OpenAI() if OPENAI_API_KEY else None\n",
"gemini_client = OpenAI(api_key=GOOGLE_API_KEY, base_url=GEMINI_BASE) if GOOGLE_API_KEY else None\n",
"groq_client = OpenAI(api_key=GROQ_API_KEY, base_url=GROQ_BASE) if GROQ_API_KEY else None\n",
"\n",
"# ---- model registry (label -> client/model) ----\n",
"MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
"def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
"    if client is not None:\n",
"        MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
"\n",
"# OpenAI\n",
"_register(\"OpenAI • GPT-5\", openai_client, \"gpt-5\")\n",
"_register(\"OpenAI • GPT-5 Nano\", openai_client, \"gpt-5-nano\")\n",
"_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
"\n",
"# Gemini (Google)\n",
"_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
"_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
"\n",
"# Groq\n",
"_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
"_register(\"Groq • Llama 3.3 70B\", groq_client, \"llama-3.3-70b-versatile\")\n",
"_register(\"Groq • GPT-OSS 20B\", groq_client, \"openai/gpt-oss-20b\")\n",
"_register(\"Groq • GPT-OSS 120B\", groq_client, \"openai/gpt-oss-120b\")\n",
"\n",
"AVAILABLE_MODELS = list(MODEL_REGISTRY.keys())\n",
"DEFAULT_MODEL = AVAILABLE_MODELS[0] if AVAILABLE_MODELS else \"OpenAI • GPT-4o-mini\"\n",
"\n",
"print(\"Providers configured →\",\n",
"      f\"OpenAI:{bool(OPENAI_API_KEY)} Gemini:{bool(GOOGLE_API_KEY)} Groq:{bool(GROQ_API_KEY)}\")\n",
"print(\"Models available →\", \", \".join(AVAILABLE_MODELS) or \"None (add API keys in .env)\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efe4e4db",
"metadata": {},
"outputs": [],
"source": [
"@dataclass(frozen=True)\n",
"class LLMRoute:\n",
"    client: OpenAI\n",
"    model: str\n",
"\n",
"class MultiLLM:\n",
"    \"\"\"OpenAI-compatible chat across providers (OpenAI, Gemini, Groq).\"\"\"\n",
"    def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
"        self._routes: Dict[str, LLMRoute] = {\n",
"            k: LLMRoute(client=v[\"client\"], model=str(v[\"model\"])) for k, v in registry.items()\n",
"        }\n",
"        if not self._routes:\n",
"            raise RuntimeError(\"No LLM providers configured. Add API keys in .env.\")\n",
"\n",
"    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
"        if model_label not in self._routes:\n",
"            raise ValueError(f\"Unknown model: {model_label}\")\n",
"        r = self._routes[model_label]\n",
"        resp = r.client.chat.completions.create(\n",
"            model=r.model,\n",
"            messages=[{\"role\":\"system\",\"content\":system},\n",
"                      {\"role\":\"user\",\"content\":user}]\n",
"        )\n",
"        return (resp.choices[0].message.content or \"\").strip()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30636b66",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# MiniLM embedding model & tokenizer (BERT WordPiece)\n",
"EMBED_MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
"\n",
"# Use the model's practical window with 50% overlap\n",
"MAX_TOKENS = 256      # all-MiniLM-L6-v2 effective limit used by Sentence-Transformers\n",
"OVERLAP_RATIO = 0.50  # 50% sliding window overlap\n",
"\n",
"TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)\n",
"\n",
"def chunk_text(\n",
"    text: str,\n",
"    tokenizer: AutoTokenizer = TOKENIZER,\n",
"    max_tokens: int = MAX_TOKENS,\n",
"    overlap_ratio: float = OVERLAP_RATIO,\n",
") -> List[str]:\n",
"    \"\"\"\n",
"    Token-aware sliding window chunking for MiniLM.\n",
"    - Windows of `max_tokens`\n",
"    - Step = max_tokens * (1 - overlap_ratio) -> 50% overlap by default\n",
"    \"\"\"\n",
"    ids = tokenizer.encode(text, add_special_tokens=False)\n",
"    if not ids:\n",
"        return []\n",
"\n",
"    step = max(1, int(max_tokens * (1.0 - overlap_ratio)))\n",
"    out: List[str] = []\n",
"    for start in range(0, len(ids), step):\n",
"        window = ids[start : start + max_tokens]\n",
"        if not window:\n",
"            break\n",
"        toks = tokenizer.convert_ids_to_tokens(window)\n",
"        chunk = tokenizer.convert_tokens_to_string(toks).strip()\n",
"        if chunk:\n",
"            out.append(chunk)\n",
"        if start + max_tokens >= len(ids):\n",
"            break\n",
"    return out\n",
"\n",
"def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
"    \"\"\"Return list of (source_id, text). `source_id` is filename stem.\"\"\"\n",
"    base = Path(root)\n",
"    if not base.exists():\n",
"        raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
"    pairs: List[Tuple[str, str]] = []\n",
"    for p in sorted(base.glob(\"*.txt\")):\n",
"        pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
"    if not pairs:\n",
"        raise RuntimeError(\"No .txt files found under knowledge_base/bare_acts\")\n",
"    return pairs\n",
"\n",
"acts_raw = load_bare_acts()\n",
"print(\"Bare Acts loaded:\", [s for s, _ in acts_raw])\n",
"print(f\"Chunking → max_tokens={MAX_TOKENS}, overlap={int(OVERLAP_RATIO*100)}%\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af537e05",
"metadata": {},
"outputs": [],
"source": [
"import chromadb\n",
"from chromadb import PersistentClient\n",
"from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
"from transformers import AutoTokenizer\n",
"from typing import Dict, List, Tuple\n",
"\n",
"class BareActsIndex:\n",
"    \"\"\"Owns the vector DB lifecycle & retrieval (token-aware chunking).\"\"\"\n",
"    def __init__(\n",
"        self,\n",
"        db_path: str = \"vector_db\",\n",
"        collection: str = \"bare_acts\",\n",
"        embed_model: str = EMBED_MODEL_NAME,\n",
"        max_tokens: int = MAX_TOKENS,\n",
"        overlap_ratio: float = OVERLAP_RATIO,\n",
"    ):\n",
"        self.db_path = db_path\n",
"        self.collection_name = collection\n",
"        self.embed_model = embed_model\n",
"        self.max_tokens = max_tokens\n",
"        self.overlap_ratio = overlap_ratio\n",
"\n",
"        self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=self.embed_model)\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)\n",
"\n",
"        self.client: PersistentClient = PersistentClient(path=db_path)\n",
"        self.col = self.client.get_or_create_collection(\n",
"            name=self.collection_name,\n",
"            embedding_function=self.embed_fn,\n",
"        )\n",
"\n",
"    def rebuild(self, docs: List[Tuple[str, str]]):\n",
"        \"\"\"Idempotent rebuild: clears and re-adds chunks with metadata.\"\"\"\n",
"        try:\n",
"            self.client.delete_collection(self.collection_name)\n",
"        except Exception:\n",
"            pass\n",
"\n",
"        self.col = self.client.get_or_create_collection(\n",
"            name=self.collection_name,\n",
"            embedding_function=self.embed_fn,\n",
"        )\n",
"\n",
"        ids, texts, metas = [], [], []\n",
"        for src, text in docs:\n",
"            for idx, ch in enumerate(\n",
"                chunk_text(\n",
"                    text,\n",
"                    tokenizer=self.tokenizer,\n",
"                    max_tokens=self.max_tokens,\n",
"                    overlap_ratio=self.overlap_ratio,\n",
"                )\n",
"            ):\n",
"                ids.append(f\"{src}-{idx}\")\n",
"                texts.append(ch)\n",
"                metas.append({\"source\": src, \"chunk_id\": idx})\n",
"\n",
"        if ids:\n",
"            self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
"\n",
"        print(\n",
"            f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name} \"\n",
"            f\"(tokens/chunk={self.max_tokens}, overlap={int(self.overlap_ratio*100)}%)\"\n",
"        )\n",
"\n",
"    def query(self, q: str, k: int = 6) -> List[Dict]:\n",
"        res = self.col.query(query_texts=[q], n_results=k)\n",
"        docs = res.get(\"documents\", [[]])[0]\n",
"        metas = res.get(\"metadatas\", [[]])[0]\n",
"        return [{\"text\": d, \"meta\": m} for d, m in zip(docs, metas)]\n",
"\n",
"# build (or rebuild) the index once\n",
"index = BareActsIndex()\n",
"index.rebuild(acts_raw)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7eec89e4",
"metadata": {},
"outputs": [],
"source": [
"class PromptBuilder:\n",
"    \"\"\"Small utility to keep prompting consistent and auditable.\"\"\"\n",
"    SYSTEM = (\n",
"        \"You are a precise legal assistant for Indian Bare Acts. \"\n",
"        \"Answer ONLY from the provided context. If the answer is not in context, say you don't know. \"\n",
"        \"Cite sources inline in square brackets as [file #chunk] (e.g., [bns #12]). \"\n",
"        \"Prefer exact quotes for critical provisions/sections.\"\n",
"    )\n",
"\n",
"    @staticmethod\n",
"    def build_user(query: str, contexts: List[Dict]) -> str:\n",
"        ctx = \"\\n\\n---\\n\\n\".join(\n",
"            f\"[{c['meta']['source']} #{c['meta']['chunk_id']}]\\n{c['text']}\" for c in contexts\n",
"        )\n",
"        return (\n",
"            f\"Question:\\n{query}\\n\\n\"\n",
"            f\"Context (do not use outside this):\\n{ctx}\\n\\n\"\n",
"            \"Instructions:\\n- Keep answers concise and faithful to the text.\\n\"\n",
"            \"- Use [file #chunk] inline where relevant.\"\n",
"        )\n",
"\n",
"def _snippet(txt: str, n: int = 220) -> str:\n",
"    s = \" \".join(txt.strip().split())\n",
"    return (s[:n] + \"…\") if len(s) > n else s\n",
"\n",
"class RagQAService:\n",
"    \"\"\"Coordinates retrieval + generation, and returns a rich reference block.\"\"\"\n",
"    def __init__(self, index: BareActsIndex, llm: MultiLLM):\n",
"        self.index = index\n",
"        self.llm = llm\n",
"        self.builder = PromptBuilder()\n",
"\n",
"    def answer(self, *, question: str, model_label: str, k: int = 6) -> str:\n",
"        ctx = self.index.query(question, k=k)\n",
"        user = self.builder.build_user(question, ctx)\n",
"        reply = self.llm.complete(model_label=model_label, system=self.builder.SYSTEM, user=user)\n",
"\n",
"        # Rich references: file, chunk index, snippet\n",
"        references = \"\\n\".join(\n",
"            f\"- [{c['meta']['source']} #{c['meta']['chunk_id']}] {_snippet(c['text'])}\"\n",
"            for c in ctx\n",
"        )\n",
"        return f\"{reply}\\n\\n**References**\\n{references}\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4862732b",
"metadata": {},
"outputs": [],
"source": [
"llm = MultiLLM(MODEL_REGISTRY)\n",
"qa_service = RagQAService(index=index, llm=llm)\n",
"\n",
"# quick smoke test (won't spend tokens if no keys for that provider)\n",
"if AVAILABLE_MODELS:\n",
"    print(\"Ready. Default model:\", DEFAULT_MODEL)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0b1512b",
"metadata": {},
"outputs": [],
"source": [
"def chat_fn(message: str, history: List[Dict], model_label: str, top_k: int) -> str:\n",
"    try:\n",
"        return qa_service.answer(question=message, model_label=model_label, k=int(top_k))\n",
"    except Exception as e:\n",
"        return f\"⚠️ {e}\"\n",
"\n",
"DEFAULT_QUESTION = \"Which section deals with punishment for murder?\"\n",
"\n",
"with gr.Blocks(title=\"Legal QnA • Bare Acts (RAG + Multi-LLM)\") as app:\n",
"    gr.Markdown(\"### 🧑⚖️ Legal Q&A on Bare Acts (RAG) — Multi-Provider LLM\")\n",
"    with gr.Row():\n",
"        model_dd = gr.Dropdown(\n",
"            choices=AVAILABLE_MODELS or [\"OpenAI • GPT-4o-mini\"],\n",
"            value=DEFAULT_MODEL if AVAILABLE_MODELS else None,\n",
"            label=\"Model\"\n",
"        )\n",
"        topk = gr.Slider(2, 12, value=6, step=1, label=\"Top-K context\")\n",
"\n",
"    chat = gr.ChatInterface(\n",
"        fn=chat_fn,\n",
"        type=\"messages\",\n",
"        additional_inputs=[model_dd, topk],\n",
"        textbox=gr.Textbox(\n",
"            value=DEFAULT_QUESTION,\n",
"            label=\"Ask a legal question\",\n",
"            placeholder=\"Type your question about BNS/IPC/Constitution…\"\n",
"        ),\n",
"    )\n",
"\n",
"app.launch(inbrowser=True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llm-engineering",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
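
For reference, the sliding-window arithmetic in the chunking cell (id 30636b66) can be checked with a small standalone sketch. The 600-token document length is a made-up example, and plain integers stand in for real MiniLM token ids, so it runs without downloading the tokenizer:

# Sketch of the sliding-window logic used by chunk_text above (cell 30636b66).
# Plain integer "token ids" stand in for tokenizer.encode() output; 600 is a made-up length.
MAX_TOKENS = 256
OVERLAP_RATIO = 0.50

ids = list(range(600))                                     # pretend the document is 600 tokens long
step = max(1, int(MAX_TOKENS * (1.0 - OVERLAP_RATIO)))     # 128-token stride -> 50% overlap

windows = []
for start in range(0, len(ids), step):
    window = ids[start : start + MAX_TOKENS]
    if not window:
        break
    windows.append((start, start + len(window)))
    if start + MAX_TOKENS >= len(ids):
        break

print(step)      # 128
print(windows)   # [(0, 256), (128, 384), (256, 512), (384, 600)]

So every token (except those near the edges) appears in two chunks, which is the intent of the 50% overlap.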
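Similarly, what PromptBuilder (cell 7eec89e4) actually sends to the model is easy to inspect with invented placeholder chunks; the snippets and metadata below are illustrative only, not real Bare Act text:

# What PromptBuilder.build_user produces, shown with two invented placeholder chunks;
# real runs would pass the dicts returned by index.query().
toy_ctx = [
    {"text": "Illustrative chunk text for the first retrieved passage ...", "meta": {"source": "bns", "chunk_id": 12}},
    {"text": "Illustrative chunk text for the second retrieved passage ...", "meta": {"source": "ipc", "chunk_id": 7}},
]
print(PromptBuilder.SYSTEM)  # system prompt: answer only from context, cite [file #chunk]
print(PromptBuilder.build_user("Which section deals with punishment for murder?", toy_ctx))
# The second print shows the question, then each chunk prefixed with its [source #chunk_id]
# tag and separated by "---", then the two instruction lines on concision and citations.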
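Finally, the retrieval + generation path can be exercised without launching Gradio. A minimal sketch, assuming the cells above have already run and at least one API key is configured in .env (the question string is just an example, not from the notebook):

# Programmatic use of the notebook's objects (no Gradio UI).
question = "What is the punishment for theft?"   # example question

hits = index.query(question, k=4)                # top-4 chunks from the Chroma collection
for h in hits:
    print(h["meta"]["source"], h["meta"]["chunk_id"])

if AVAILABLE_MODELS:
    answer = qa_service.answer(question=question, model_label=DEFAULT_MODEL, k=4)
    print(answer)                                # model reply followed by a **References** block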