{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "9f0759f2-5e46-438a-ad8e-b5d5771ec9ec", "metadata": {}, "outputs": [], "source": [ "# RAG based Gradio solution to give information from related documents, using Llama3.2 and nomic-embed-text over OLLAMA\n", "# Took help of Claude and Course material." ] }, { "cell_type": "code", "execution_count": null, "id": "448bd8f4-9181-4039-829f-d3f0a5f14171", "metadata": {}, "outputs": [], "source": [ "import os, glob\n", "import sqlite3\n", "import json\n", "import numpy as np\n", "from typing import List, Dict, Tuple\n", "import requests\n", "import gradio as gr\n", "from datetime import datetime\n", "\n", "embedding_model = 'nomic-embed-text'\n", "llm_model = 'llama3.2'\n", "RagDist_k = 6\n", "folders = glob.glob(\"../../week5/knowledge-base/*\")\n", "folders" ] }, { "cell_type": "code", "execution_count": null, "id": "dc085852-a80f-4f2c-b31a-80ceda10bec6", "metadata": {}, "outputs": [], "source": [ "\n", "class OllamaEmbeddings:\n", " \"\"\"Generate embeddings using Ollama's embedding models.\"\"\"\n", " \n", " def __init__(self, model: str = embedding_model, base_url: str = \"http://localhost:11434\"):\n", " self.model = model\n", " self.base_url = base_url\n", " \n", " def embed_text(self, text: str) -> List[float]:\n", " \"\"\"Generate embedding for a single text.\"\"\"\n", " print('Processing', text[:70].replace('\\n',' | '))\n", " response = requests.post(\n", " f\"{self.base_url}/api/embeddings\",\n", " json={\"model\": self.model, \"prompt\": text}\n", " )\n", " if response.status_code == 200:\n", " return response.json()[\"embedding\"]\n", " else:\n", " raise Exception(f\"Error generating embedding: {response.text}\")\n", " \n", " def embed_documents(self, texts: List[str]) -> List[List[float]]:\n", " \"\"\"Generate embeddings for multiple texts.\"\"\"\n", " return [self.embed_text(text) for text in texts]\n", "\n", "\n", "class SQLiteVectorStore:\n", " \"\"\"Vector store using SQLite for storing and retrieving document embeddings.\"\"\"\n", " \n", " def __init__(self, db_path: str = \"vector_store.db\"):\n", " self.db_path = db_path\n", " self.conn = sqlite3.connect(db_path, check_same_thread=False)\n", " self._create_table()\n", " \n", " def _create_table(self):\n", " \"\"\"Create the documents table if it doesn't exist.\"\"\"\n", " cursor = self.conn.cursor()\n", " cursor.execute(\"\"\"\n", " CREATE TABLE IF NOT EXISTS documents (\n", " id INTEGER PRIMARY KEY AUTOINCREMENT,\n", " content TEXT NOT NULL,\n", " embedding TEXT NOT NULL,\n", " metadata TEXT,\n", " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\n", " )\n", " \"\"\")\n", " self.conn.commit()\n", " \n", " def add_documents(self, texts: List[str], embeddings: List[List[float]], \n", " metadatas: List[Dict] = None):\n", " \"\"\"Add documents with their embeddings to the store.\"\"\"\n", " cursor = self.conn.cursor()\n", " if metadatas is None:\n", " metadatas = [{}] * len(texts)\n", " \n", " for text, embedding, metadata in zip(texts, embeddings, metadatas):\n", " cursor.execute(\"\"\"\n", " INSERT INTO documents (content, embedding, metadata)\n", " VALUES (?, ?, ?)\n", " \"\"\", (text, json.dumps(embedding), json.dumps(metadata)))\n", " \n", " self.conn.commit()\n", " \n", " def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n", " \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n", " return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))\n", " \n", " def similarity_search(self, query_embedding: List[float], k: int = 3) -> List[Tuple[str, float, Dict]]:\n", " \"\"\"Search for the k most similar documents.\"\"\"\n", " cursor = self.conn.cursor()\n", " cursor.execute(\"SELECT content, embedding, metadata FROM documents\")\n", " results = cursor.fetchall()\n", " \n", " query_vec = np.array(query_embedding)\n", " similarities = []\n", " \n", " for content, embedding_json, metadata_json in results:\n", " doc_vec = np.array(json.loads(embedding_json))\n", " similarity = self.cosine_similarity(query_vec, doc_vec)\n", " similarities.append((content, similarity, json.loads(metadata_json)))\n", " \n", " # Sort by similarity (highest first) and return top k\n", " similarities.sort(key=lambda x: x[1], reverse=True)\n", " return similarities[:k]\n", " \n", " def clear_all(self):\n", " \"\"\"Clear all documents from the store.\"\"\"\n", " cursor = self.conn.cursor()\n", " cursor.execute(\"DELETE FROM documents\")\n", " self.conn.commit()\n", " \n", " def get_document_count(self) -> int:\n", " \"\"\"Get the total number of documents in the store.\"\"\"\n", " cursor = self.conn.cursor()\n", " cursor.execute(\"SELECT COUNT(*) FROM documents\")\n", " return cursor.fetchone()[0]\n", "\n", "\n", "class OllamaLLM:\n", " \"\"\"Interact with Ollama LLM for text generation.\"\"\"\n", " \n", " def __init__(self, model: str = llm_model, base_url: str = \"http://localhost:11434\"):\n", " self.model = model\n", " self.base_url = base_url\n", " \n", " def generate(self, prompt: str, stream: bool = False) -> str:\n", " \"\"\"Generate text from the LLM.\"\"\"\n", " response = requests.post(\n", " f\"{self.base_url}/api/generate\",\n", " json={\"model\": self.model, \"prompt\": prompt, \"stream\": stream}\n", " )\n", " \n", " if response.status_code == 200:\n", " return response.json()[\"response\"]\n", " else:\n", " raise Exception(f\"Error generating response: {response.text}\")\n", "\n", "\n", "class RAGSystem:\n", " \"\"\"RAG system combining vector store, embeddings, and LLM.\"\"\"\n", " \n", " def __init__(self, embedding_model: str = embedding_model, \n", " llm_model: str = llm_model,\n", " db_path: str = \"vector_store.db\"):\n", " self.embeddings = OllamaEmbeddings(model=embedding_model)\n", " self.vector_store = SQLiteVectorStore(db_path=db_path)\n", " self.llm = OllamaLLM(model=llm_model)\n", " \n", " def add_documents(self, documents: List[Dict[str, str]]):\n", " \"\"\"\n", " Add documents to the RAG system.\n", " documents: List of dicts with 'content' and optional 'metadata'\n", " \"\"\"\n", " texts = [doc['content'] for doc in documents]\n", " metadatas = [doc.get('metadata', {}) for doc in documents]\n", " \n", " print(f\"Generating embeddings for {len(texts)} documents...\")\n", " embeddings = self.embeddings.embed_documents(texts)\n", " \n", " print(\"Storing documents in vector store...\")\n", " self.vector_store.add_documents(texts, embeddings, metadatas)\n", " print(f\"Successfully added {len(texts)} documents!\")\n", " \n", " def query(self, question: str, k: int = 3) -> str:\n", " \"\"\"Query the RAG system with a question.\"\"\"\n", " # Generate embedding for the query\n", " query_embedding = self.embeddings.embed_text(question)\n", " \n", " # Retrieve relevant documents\n", " results = self.vector_store.similarity_search(query_embedding, k=k)\n", " \n", " if not results:\n", " return \"I don't have any information to answer this question.\"\n", " \n", " # Build context from retrieved documents\n", " context = \"\\n\\n\".join([\n", " f\"Document {i+1} (Relevance: {score:.2f}):\\n{content}\"\n", " for i, (content, score, _) in enumerate(results)\n", " ])\n", " \n", " # Create prompt for LLM\n", " prompt = f\"\"\"You are a helpful assistant answering questions based on the provided context.\n", " Use the following context to answer the question. If you cannot answer the question based on the context, say so.\n", " \n", " Context:\n", " {context}\n", " \n", " Question: {question}\n", " \n", " Answer:\"\"\"\n", " \n", " # Generate response\n", " response = self.llm.generate(prompt)\n", " return response\n", " \n", " def get_stats(self) -> str:\n", " \"\"\"Get statistics about the RAG system.\"\"\"\n", " doc_count = self.vector_store.get_document_count()\n", " return f\"Total documents in database: {doc_count}\"\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "37cbaa24-6e17-4712-8c90-429264b9b82e", "metadata": {}, "outputs": [], "source": [ "def load_documents() -> List[Dict[str, str]]:\n", " \"\"\"\n", " Read all files from specified folders and format them for RAG system. \n", " Args:\n", " folders: List of folder paths to read files from\n", " Returns:\n", " List of dictionaries with 'content' and 'metadata' keys\n", " \"\"\"\n", " from pathlib import Path\n", " \n", " documents = []\n", " supported_extensions = {'.md'}\n", " \n", " for folder in folders:\n", " folder_path = Path(folder)\n", " \n", " if not folder_path.exists():\n", " print(f\"Warning: Folder '{folder}' does not exist. Skipping...\")\n", " continue\n", " \n", " if not folder_path.is_dir():\n", " print(f\"Warning: '{folder}' is not a directory. Skipping...\")\n", " continue\n", " \n", " folder_name = folder_path.name\n", " \n", " # Get all files in the folder\n", " files = [f for f in folder_path.iterdir() if f.is_file()]\n", " \n", " for file_path in files:\n", " # Check if file extension is supported\n", " if file_path.suffix.lower() not in supported_extensions:\n", " print(f\"Skipping unsupported file type: {file_path.name}\")\n", " continue\n", " \n", " try:\n", " # Read file content\n", " with open(file_path, 'r', encoding='utf-8') as f:\n", " content = f.read()\n", " \n", " # Create document dictionary\n", " document = {\n", " 'metadata': {\n", " 'type': folder_name,\n", " 'name': file_path.name,\n", " 'datalen': len(content)\n", " },\n", " 'content': content,\n", " }\n", " \n", " documents.append(document)\n", " print(f\"āœ“ Loaded: {file_path.name} from folder '{folder_name}'\")\n", " \n", " except Exception as e:\n", " print(f\"Error reading file {file_path.name}: {str(e)}\")\n", " continue\n", " \n", " print(f\"\\nTotal documents loaded: {len(documents)}\")\n", " return documents\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d257bd84-fd7b-4a64-bc5b-148b30b00aa3", "metadata": {}, "outputs": [], "source": [ "def create_gradio_interface(rag_system: RAGSystem):\n", " \"\"\"Create Gradio chat interface for the RAG system.\"\"\"\n", " \n", " def chat_fn(message, history):\n", " \"\"\"Process chat messages.\"\"\"\n", " try:\n", " response = rag_system.query(message, k=RagDist_k)\n", " return response\n", " except Exception as e:\n", " return f\"Error: {str(e)}\\n\\nMake sure Ollama is running with the required models installed.\"\n", " \n", " def load_data():\n", " \"\"\"Load sample documents into the system.\"\"\"\n", " try:\n", " documents = load_documents()\n", " rag_system.add_documents(documents)\n", " stats = rag_system.get_stats()\n", " return f\"āœ… Sample documents loaded successfully!\\n{stats}\"\n", " except Exception as e:\n", " return f\"āŒ Error loading documents: {str(e)}\"\n", " \n", " def get_stats():\n", " \"\"\"Get system statistics.\"\"\"\n", " return rag_system.get_stats()\n", " \n", " with gr.Blocks(title=\"RAG System - Company Knowledge Base\", theme=gr.themes.Soft()) as demo:\n", " gr.Markdown(\"# šŸ¤– RAG System - Company Knowledge Base\")\n", " gr.Markdown(\"Ask questions about company information, contracts, employees, and products.\")\n", " \n", " with gr.Row():\n", " with gr.Column(scale=3):\n", " chatbot = gr.ChatInterface(\n", " fn=chat_fn,\n", " examples=[\n", " \"Who is the CTO of the company?\",\n", " \"Who is the CEO of the company?\",\n", " \"What products does the company offer?\",\n", " ],\n", " title=\"\",\n", " description=\"šŸ’¬ Chat with the company knowledge base\"\n", " )\n", " \n", " with gr.Column(scale=1):\n", " gr.Markdown(\"### šŸ“Š System Controls\")\n", " load_btn = gr.Button(\"šŸ“„ Load Documents\", variant=\"primary\")\n", " stats_btn = gr.Button(\"šŸ“ˆ Get Statistics\")\n", " output_box = gr.Textbox(label=\"System Output\", lines=5)\n", " \n", " load_btn.click(fn=load_data, outputs=output_box)\n", " stats_btn.click(fn=get_stats, outputs=output_box)\n", " \n", " gr.Markdown(f\"\"\"\n", " ### šŸ“ Instructions:\n", " 1. Make sure Ollama is running\n", " 2. Click \"Load Sample Documents\" \n", " 3. Start asking questions!\n", " \n", " ### šŸ”§ Required Models:\n", " - `ollama pull {embedding_model}`\n", " - `ollama pull {llm_model}`\n", " \"\"\")\n", " \n", " return demo\n", "\n", "\n", "def main():\n", " \"\"\"Main function to run the RAG system.\"\"\"\n", " print(\"=\" * 60)\n", " print(\"RAG System with Ollama and SQLite\")\n", " print(\"=\" * 60)\n", " \n", " # Initialize RAG system\n", " print(\"\\nInitializing RAG system...\")\n", " rag_system = RAGSystem(\n", " embedding_model=embedding_model,\n", " llm_model=llm_model,\n", " db_path=\"vector_store.db\"\n", " )\n", " \n", " print(\"\\nāš ļø Make sure Ollama is running and you have the required models:\")\n", " print(f\" - ollama pull {embedding_model}\")\n", " print(f\" - ollama pull {llm_model}\")\n", " print(\"\\nStarting Gradio interface...\")\n", " \n", " # Create and launch Gradio interface\n", " demo = create_gradio_interface(rag_system)\n", " demo.launch(share=False)\n", "\n", "\n", "main()" ] }, { "cell_type": "code", "execution_count": null, "id": "01b4ff0e-36a5-43b5-8ecf-59e42a18a908", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }