Update legal_qna_with_rag_on_bare_acts.ipynb
@@ -0,0 +1,270 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "d27544d4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Dict, List, Optional, Tuple\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"\n",
"# ---- load env ----\n",
"load_dotenv(override=True)\n",
"\n",
"# ---- OpenAI-compatible base URLs (Gemini & Groq) ----\n",
"GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
"GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
"\n",
"OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
"GOOGLE_API_KEY = os.getenv(\"GOOGLE_API_KEY\") # Gemini\n",
"GROQ_API_KEY = os.getenv(\"GROQ_API_KEY\") # Groq\n",
"\n",
"# ---- create clients only if keys exist ----\n",
"openai_client = OpenAI() if OPENAI_API_KEY else None\n",
"gemini_client = OpenAI(api_key=GOOGLE_API_KEY, base_url=GEMINI_BASE) if GOOGLE_API_KEY else None\n",
"groq_client = OpenAI(api_key=GROQ_API_KEY, base_url=GROQ_BASE) if GROQ_API_KEY else None\n",
"\n",
"# ---- model registry (label -> client/model) ----\n",
"MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
"def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
" if client is not None:\n",
" MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
"\n",
"# OpenAI\n",
"_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
"\n",
"# Gemini\n",
"_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
"_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
"\n",
"# Groq\n",
"_register(\"Groq • Llama 3.3 70B\", groq_client, \"llama-3.3-70b-versatile\")\n",
"_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
"\n",
"AVAILABLE_MODELS = list(MODEL_REGISTRY.keys())\n",
"DEFAULT_MODEL = AVAILABLE_MODELS[0] if AVAILABLE_MODELS else \"OpenAI • GPT-4o-mini\"\n",
"\n",
"print(\"Providers configured →\",\n",
" f\"OpenAI:{bool(OPENAI_API_KEY)} Gemini:{bool(GOOGLE_API_KEY)} Groq:{bool(GROQ_API_KEY)}\")\n",
"print(\"Models available →\", \", \".join(AVAILABLE_MODELS) or \"None (add API keys in .env)\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efe4e4db",
"metadata": {},
"outputs": [],
"source": [
"@dataclass(frozen=True)\n",
"class LLMRoute:\n",
" client: OpenAI\n",
" model: str\n",
"\n",
"class MultiLLM:\n",
" \"\"\"OpenAI-compatible chat across providers (OpenAI, Gemini, Groq).\"\"\"\n",
" def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
" self._routes: Dict[str, LLMRoute] = {\n",
" k: LLMRoute(client=v[\"client\"], model=str(v[\"model\"])) for k, v in registry.items()\n",
" }\n",
" if not self._routes:\n",
" raise RuntimeError(\"No LLM providers configured. Add API keys in .env.\")\n",
"\n",
" def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
" if model_label not in self._routes:\n",
" raise ValueError(f\"Unknown model: {model_label}\")\n",
" r = self._routes[model_label]\n",
" resp = r.client.chat.completions.create(\n",
" model=r.model,\n",
" messages=[{\"role\":\"system\",\"content\":system},\n",
" {\"role\":\"user\",\"content\":user}]\n",
" )\n",
" return (resp.choices[0].message.content or \"\").strip()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30636b66",
"metadata": {},
"outputs": [],
"source": [
"def chunk_text(text: str, size: int = 900, overlap: int = 150) -> List[str]:\n",
" \"\"\"Greedy fixed-size chunking with overlap (simple & fast).\"\"\"\n",
" out, i, n = [], 0, len(text)\n",
" while i < n:\n",
" j = min(i + size, n)\n",
" out.append(text[i:j])\n",
"        if j == n:\n",
"            break  # reached the end of the text; stop to avoid re-emitting the tail\n",
"        i = j - overlap  # step back so consecutive chunks share `overlap` characters\n",
" return out\n",
"\n",
"def load_bare_acts(root: str = \"knowledge_base/bare_acts\") -> List[Tuple[str, str]]:\n",
" \"\"\"Return list of (source_id, text). source_id is filename stem.\"\"\"\n",
" base = Path(root)\n",
" if not base.exists():\n",
" raise FileNotFoundError(f\"Folder not found: {base.resolve()}\")\n",
" pairs = []\n",
" for p in sorted(base.glob(\"*.txt\")):\n",
" pairs.append((p.stem, p.read_text(encoding=\"utf-8\")))\n",
" if not pairs:\n",
"        raise RuntimeError(f\"No .txt files found under {base}\")\n",
" return pairs\n",
"\n",
"acts_raw = load_bare_acts()\n",
"print(\"Bare Acts loaded:\", [s for s,_ in acts_raw])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af537e05",
"metadata": {},
"outputs": [],
"source": [
"import chromadb\n",
"from chromadb import PersistentClient\n",
"from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
"\n",
"class BareActsIndex:\n",
" \"\"\"Owns the vector DB lifecycle & retrieval.\"\"\"\n",
" def __init__(self, db_path: str = \"vector_db\", collection: str = \"bare_acts\",\n",
" embed_model: str = \"sentence-transformers/all-MiniLM-L6-v2\"):\n",
" self.db_path = db_path\n",
" self.collection_name = collection\n",
" self.embed_fn = SentenceTransformerEmbeddingFunction(model_name=embed_model)\n",
" self.client: PersistentClient = PersistentClient(path=db_path)\n",
" self.col = self.client.get_or_create_collection(\n",
" name=self.collection_name,\n",
" embedding_function=self.embed_fn\n",
" )\n",
"\n",
" def rebuild(self, docs: List[Tuple[str, str]]):\n",
" \"\"\"Idempotent rebuild: clears and re-adds chunks with metadata.\"\"\"\n",
" try:\n",
" self.client.delete_collection(self.collection_name)\n",
" except Exception:\n",
" pass\n",
" self.col = self.client.get_or_create_collection(\n",
" name=self.collection_name,\n",
" embedding_function=self.embed_fn\n",
" )\n",
"\n",
" ids, texts, metas = [], [], []\n",
" for src, text in docs:\n",
" for idx, ch in enumerate(chunk_text(text)):\n",
" ids.append(f\"{src}-{idx}\")\n",
" texts.append(ch)\n",
" metas.append({\"source\": src, \"chunk_id\": idx})\n",
" self.col.add(ids=ids, documents=texts, metadatas=metas)\n",
" print(f\"Indexed {len(texts)} chunks from {len(docs)} files → {self.collection_name}\")\n",
"\n",
" def query(self, q: str, k: int = 6) -> List[Dict]:\n",
" res = self.col.query(query_texts=[q], n_results=k)\n",
" docs = res.get(\"documents\", [[]])[0]\n",
" metas = res.get(\"metadatas\", [[]])[0]\n",
" return [{\"text\": d, \"meta\": m} for d, m in zip(docs, metas)]\n",
"\n",
"# build (or rebuild) the index once\n",
"index = BareActsIndex()\n",
"index.rebuild(acts_raw)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7eec89e4",
"metadata": {},
"outputs": [],
"source": [
"class PromptBuilder:\n",
" \"\"\"Small utility to keep prompting consistent and auditable.\"\"\"\n",
" SYSTEM = (\n",
" \"You are a precise legal assistant for Indian Bare Acts. \"\n",
" \"Answer ONLY from the provided context. If the answer is not in context, say you don't know. \"\n",
" \"Cite the sources by file name (e.g., ipc, coi, bns) in brackets.\"\n",
" )\n",
"\n",
" @staticmethod\n",
" def build_user(query: str, contexts: List[Dict]) -> str:\n",
" ctx = \"\\n\\n---\\n\\n\".join(\n",
" f\"[{c['meta']['source']} #{c['meta']['chunk_id']}]\\n{c['text']}\" for c in contexts\n",
" )\n",
" return f\"Question:\\n{query}\\n\\nContext:\\n{ctx}\\n\\nInstructions:\\n- Keep answers concise.\\n- Quote key lines when useful.\\n- Add [source] inline.\"\n",
"\n",
"class RagQAService:\n",
" \"\"\"Coordinates retrieval + generation.\"\"\"\n",
" def __init__(self, index: BareActsIndex, llm: MultiLLM):\n",
" self.index = index\n",
" self.llm = llm\n",
" self.builder = PromptBuilder()\n",
"\n",
" def answer(self, *, question: str, model_label: str, k: int = 6) -> str:\n",
" ctx = self.index.query(question, k=k)\n",
" user = self.builder.build_user(question, ctx)\n",
" reply = self.llm.complete(model_label=model_label, system=self.builder.SYSTEM, user=user)\n",
"\n",
" # Append sources deterministically (post-processing for transparency)\n",
" sources = \", \".join(sorted({c[\"meta\"][\"source\"] for c in ctx}))\n",
" return f\"{reply}\\n\\n— Sources: {sources}\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4862732b",
"metadata": {},
"outputs": [],
"source": [
"llm = MultiLLM(MODEL_REGISTRY)\n",
"qa_service = RagQAService(index=index, llm=llm)\n",
"\n",
"# readiness check: prints the configured default model (no API calls are made here)\n",
"if AVAILABLE_MODELS:\n",
" print(\"Ready. Default model:\", DEFAULT_MODEL)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0b1512b",
"metadata": {},
"outputs": [],
"source": [
"def chat_fn(message: str, history: List[Dict], model_label: str, top_k: int) -> str:\n",
" try:\n",
" return qa_service.answer(question=message, model_label=model_label, k=int(top_k))\n",
" except Exception as e:\n",
" return f\"⚠️ {e}\"\n",
"\n",
"with gr.Blocks(title=\"Legal QnA • Bare Acts (RAG + Multi-LLM)\") as app:\n",
" gr.Markdown(\"### 🧑⚖️ Legal Q&A on Bare Acts (RAG) — Multi-Provider LLM\")\n",
" with gr.Row():\n",
" model_dd = gr.Dropdown(choices=AVAILABLE_MODELS or [\"OpenAI • GPT-4o-mini\"],\n",
" value=DEFAULT_MODEL if AVAILABLE_MODELS else None,\n",
" label=\"Model\")\n",
" topk = gr.Slider(2, 12, value=6, step=1, label=\"Top-K context\")\n",
"\n",
" chat = gr.ChatInterface(fn=chat_fn,\n",
" type=\"messages\",\n",
" additional_inputs=[model_dd, topk])\n",
"\n",
"app.launch(inbrowser=True)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}