From fd7d9da85213781930db2429423d1238f8adc3de Mon Sep 17 00:00:00 2001
From: Nik
Date: Sun, 26 Oct 2025 20:52:03 +0530
Subject: [PATCH] Update legal_qna_with_rag_on_bare_acts.ipynb

---
 .../legal_qna_with_rag_on_bare_acts.ipynb | 270 ++++++++++++++++++
 1 file changed, 270 insertions(+)

diff --git a/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb b/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb
index e69de29..c64b90a 100644
--- a/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb
+++ b/week5/community-contributions/legal_qna_with_rag_on_bare_acts/legal_qna_with_rag_on_bare_acts.ipynb
@@ -0,0 +1,270 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d27544d4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from dataclasses import dataclass\n",
+    "from pathlib import Path\n",
+    "from typing import Dict, List, Optional, Tuple\n",
+    "\n",
+    "from dotenv import load_dotenv\n",
+    "from openai import OpenAI\n",
+    "import gradio as gr\n",
+    "\n",
+    "# ---- load env ----\n",
+    "load_dotenv(override=True)\n",
+    "\n",
+    "# ---- OpenAI-compatible base URLs (Gemini & Groq) ----\n",
+    "GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
+    "GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
+    "\n",
+    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
+    "GOOGLE_API_KEY = os.getenv(\"GOOGLE_API_KEY\")  # Gemini\n",
+    "GROQ_API_KEY = os.getenv(\"GROQ_API_KEY\")      # Groq\n",
+    "\n",
+    "# ---- create clients only if keys exist ----\n",
+    "openai_client = OpenAI() if OPENAI_API_KEY else None\n",
+    "gemini_client = OpenAI(api_key=GOOGLE_API_KEY, base_url=GEMINI_BASE) if GOOGLE_API_KEY else None\n",
+    "groq_client = OpenAI(api_key=GROQ_API_KEY, base_url=GROQ_BASE) if GROQ_API_KEY else None\n",
+    "\n",
+    "# ---- model registry (label -> client/model) ----\n",
+    "MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
+    "def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
+    "    if client is not None:\n",
+    "        MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
+    "\n",
+    "# OpenAI\n",
+    "_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
+    "\n",
+    "# Gemini\n",
+    "_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
+    "_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
+    "\n",
+    "# Groq\n",
+    "_register(\"Groq • Llama 3.3 70B\", groq_client, \"llama-3.3-70b-versatile\")\n",
+    "_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
+    "\n",
+    "AVAILABLE_MODELS = list(MODEL_REGISTRY.keys())\n",
+    "DEFAULT_MODEL = AVAILABLE_MODELS[0] if AVAILABLE_MODELS else \"OpenAI • GPT-4o-mini\"\n",
+    "\n",
+    "print(\"Providers configured →\",\n",
+    "      f\"OpenAI:{bool(OPENAI_API_KEY)} Gemini:{bool(GOOGLE_API_KEY)} Groq:{bool(GROQ_API_KEY)}\")\n",
+    "print(\"Models available →\", \", \".join(AVAILABLE_MODELS) or \"None (add API keys in .env)\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "efe4e4db",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@dataclass(frozen=True)\n",
+    "class LLMRoute:\n",
+    "    client: OpenAI\n",
+    "    model: str\n",
+    "\n",
+    "class MultiLLM:\n",
+    "    \"\"\"OpenAI-compatible chat across providers (OpenAI, Gemini, Groq).\"\"\"\n",
+    "    def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
+    "        self._routes: Dict[str, LLMRoute] = {\n",
+    "            k: LLMRoute(client=v[\"client\"], model=str(v[\"model\"])) for k, v in registry.items()\n",
+    "        }\n",
+    "        if not self._routes:\n",
+    "            raise RuntimeError(\"No LLM providers configured. Add API keys in .env.\")\n",
+    "\n",
+    "    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
+    "        if model_label not in self._routes:\n",
+    "            raise ValueError(f\"Unknown model: {model_label}\")\n",
+    "        r = self._routes[model_label]\n",
+    "        resp = r.client.chat.completions.create(\n",
+    "            model=r.model,\n",
+    "            messages=[{\"role\": \"system\", \"content\": system},\n",
+    "                      {\"role\": \"user\", \"content\": user}]\n",
+    "        )\n",
+    "        return (resp.choices[0].message.content or \"\").strip()\n"
+   ]
+  },
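+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3f1a2c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check for the MultiLLM wrapper: a minimal sketch that assumes\n",
+    "# at least one provider key is set in .env (_probe is just an illustrative local\n",
+    "# name). Uncomment to spend a few tokens:\n",
+    "# _probe = MultiLLM(MODEL_REGISTRY)\n",
+    "# print(_probe.complete(model_label=DEFAULT_MODEL, system=\"Reply in one word.\", user=\"Say OK.\"))\n"
+   ]
+  },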
{\n", + " k: LLMRoute(client=v[\"client\"], model=str(v[\"model\"])) for k, v in registry.items()\n", + " }\n", + " if not self._routes:\n", + " raise RuntimeError(\"No LLM providers configured. Add API keys in .env.\")\n", + "\n", + " def complete(self, *, model_label: str, system: str, user: str) -> str:\n", + " if model_label not in self._routes:\n", + " raise ValueError(f\"Unknown model: {model_label}\")\n", + " r = self._routes[model_label]\n", + " resp = r.client.chat.completions.create(\n", + " model=r.model,\n", + " messages=[{\"role\":\"system\",\"content\":system},\n", + " {\"role\":\"user\",\"content\":user}]\n", + " )\n", + " return (resp.choices[0].message.content or \"\").strip()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30636b66", + "metadata": {}, + "outputs": [], + "source": [ + "def chunk_text(text: str, size: int = 900, overlap: int = 150) -> List[str]:\n", + " \"\"\"Greedy fixed-size chunking with overlap (simple & fast).\"\"\"\n", + " out, i, n = [], 0, len(text)\n", + " while i < n:\n", + " j = min(i + size, n)\n", + " out.append(text[i:j])\n", + " i = max(j - overlap, j)\n", + " return out\n", + "\n", + "def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n", + " \"\"\"Return list of (source_id, text). source_id is filename stem.\"\"\"\n", + " base = Path(root)\n", + " if not base.exists():\n", + " raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n", + " pairs = []\n", + " for p in sorted(base.glob(\"*.txt\")):\n", + " pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n", + " if not pairs:\n", + " raise RuntimeError(\"No .txt files found under knowledge_base/bare_acts\")\n", + " return pairs\n", + "\n", + "acts_raw = load_bare_acts()\n", + "print(\"Bare Acts loaded:\", [s for s,_ in acts_raw])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af537e05", + "metadata": {}, + "outputs": [], + "source": [ + "import chromadb\n", + "from chromadb import PersistentClient\n", + "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n", + "\n", + "class BareActsIndex:\n", + " \"\"\"Owns the vector DB lifecycle & retrieval.\"\"\"\n", + " def __init__(self, db_path: str = \"vector_db\", collection: str = \"bare_acts\",\n", + " embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"):\n", + " self.db_path = db_path\n", + " self.collection_name = collection\n", + " self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=embed_model)\n", + " self.client: PersistentClient = PersistentClient(path=db_path)\n", + " self.col = self.client.get_or_create_collection(\n", + " name=self.collection_name,\n", + " embedding_function=self.embed_fn\n", + " )\n", + "\n", + " def rebuild(self, docs: List[Tuple[str, str]]):\n", + " \"\"\"Idempotent rebuild: clears and re-adds chunks with metadata.\"\"\"\n", + " try:\n", + " self.client.delete_collection(self.collection_name)\n", + " except Exception:\n", + " pass\n", + " self.col = self.client.get_or_create_collection(\n", + " name=self.collection_name,\n", + " embedding_function=self.embed_fn\n", + " )\n", + "\n", + " ids, texts, metas = [], [], []\n", + " for src, text in docs:\n", + " for idx, ch in enumerate(chunk_text(text)):\n", + " ids.append(f\"{src}-{idx}\")\n", + " texts.append(ch)\n", + " metas.append({\"source\": src, \"chunk_id\": idx})\n", + " self.col.add(ids=ids, documents=texts, metadatas=metas)\n", + " print(f\"Indexed {len(texts)} chunks from {len(docs)} files → 
{self.collection_name}\")\n", + "\n", + " def query(self, q: str, k: int = 6) -> List[Dict]:\n", + " res = self.col.query(query_texts=[q], n_results=k)\n", + " docs = res.get(\"documents\", [[]])[0]\n", + " metas = res.get(\"metadatas\", [[]])[0]\n", + " return [{\"text\": d, \"meta\": m} for d, m in zip(docs, metas)]\n", + "\n", + "# build (or rebuild) the index once\n", + "index = BareActsIndex()\n", + "index.rebuild(acts_raw)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7eec89e4", + "metadata": {}, + "outputs": [], + "source": [ + "class PromptBuilder:\n", + " \"\"\"Small utility to keep prompting consistent and auditable.\"\"\"\n", + " SYSTEM = (\n", + " \"You are a precise legal assistant for Indian Bare Acts. \"\n", + " \"Answer ONLY from the provided context. If the answer is not in context, say you don't know. \"\n", + " \"Cite the sources by file name (e.g., ipc, coi, bns) in brackets.\"\n", + " )\n", + "\n", + " @staticmethod\n", + " def build_user(query: str, contexts: List[Dict]) -> str:\n", + " ctx = \"\\n\\n---\\n\\n\".join(\n", + " f\"[{c['meta']['source']} #{c['meta']['chunk_id']}]\\n{c['text']}\" for c in contexts\n", + " )\n", + " return f\"Question:\\n{query}\\n\\nContext:\\n{ctx}\\n\\nInstructions:\\n- Keep answers concise.\\n- Quote key lines when useful.\\n- Add [source] inline.\"\n", + "\n", + "class RagQAService:\n", + " \"\"\"Coordinates retrieval + generation.\"\"\"\n", + " def __init__(self, index: BareActsIndex, llm: MultiLLM):\n", + " self.index = index\n", + " self.llm = llm\n", + " self.builder = PromptBuilder()\n", + "\n", + " def answer(self, *, question: str, model_label: str, k: int = 6) -> str:\n", + " ctx = self.index.query(question, k=k)\n", + " user = self.builder.build_user(question, ctx)\n", + " reply = self.llm.complete(model_label=model_label, system=self.builder.SYSTEM, user=user)\n", + "\n", + " # Append sources deterministically (post-processing for transparency)\n", + " sources = \", \".join(sorted({c[\"meta\"][\"source\"] for c in ctx}))\n", + " return f\"{reply}\\n\\n— Sources: {sources}\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4862732b", + "metadata": {}, + "outputs": [], + "source": [ + "llm = MultiLLM(MODEL_REGISTRY)\n", + "qa_service = RagQAService(index=index, llm=llm)\n", + "\n", + "# quick smoke test (won't spend tokens if no keys for that provider)\n", + "if AVAILABLE_MODELS:\n", + " print(\"Ready. 
Default model:\", DEFAULT_MODEL)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0b1512b", + "metadata": {}, + "outputs": [], + "source": [ + "def chat_fn(message: str, history: List[Dict], model_label: str, top_k: int) -> str:\n", + " try:\n", + " return qa_service.answer(question=message, model_label=model_label, k=int(top_k))\n", + " except Exception as e:\n", + " return f\"⚠️ {e}\"\n", + "\n", + "with gr.Blocks(title=\"Legal QnA • Bare Acts (RAG + Multi-LLM)\") as app:\n", + " gr.Markdown(\"### 🧑‍⚖️ Legal Q&A on Bare Acts (RAG) — Multi-Provider LLM\")\n", + " with gr.Row():\n", + " model_dd = gr.Dropdown(choices=AVAILABLE_MODELS or [\"OpenAI • GPT-4o-mini\"],\n", + " value=DEFAULT_MODEL if AVAILABLE_MODELS else None,\n", + " label=\"Model\")\n", + " topk = gr.Slider(2, 12, value=6, step=1, label=\"Top-K context\")\n", + "\n", + " chat = gr.ChatInterface(fn=chat_fn,\n", + " type=\"messages\",\n", + " additional_inputs=[model_dd, topk])\n", + "\n", + "app.launch(inbrowser=True)\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}