diff --git a/week5/community-contributions/Week5_day5_Gemini_Semantic_Chunks.ipynb b/week5/community-contributions/Week5_day5_Gemini_Semantic_Chunks.ipynb
new file mode 100644
index 0000000..d4144c2
--- /dev/null
+++ b/week5/community-contributions/Week5_day5_Gemini_Semantic_Chunks.ipynb
@@ -0,0 +1,463 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "2080947c-96d9-447f-8368-cfdc9e5c9960",
+   "metadata": {},
+   "source": [
+    "# Using Semantic Chunks with the Gemini API and Gemini Embeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "53221f1a-a0c1-4506-a3d0-d6626c58e4e0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Standard imports\n",
+    "import os\n",
+    "import glob\n",
+    "import time\n",
+    "from dotenv import load_dotenv\n",
+    "from tqdm.notebook import tqdm\n",
+    "import gradio as gr"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9a2a7171-a7b6-42a6-96d7-c93f360689ec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Visualization imports\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.manifold import TSNE\n",
+    "import numpy as np\n",
+    "import plotly.graph_objects as go"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "51c9d658-65e5-40a1-8680-d0b561f87649",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# LangChain imports\n",
+    "\n",
+    "from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI\n",
+    "from langchain_community.document_loaders import DirectoryLoader, TextLoader\n",
+    "from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
+    "from langchain_core.messages import HumanMessage, AIMessage\n",
+    "from langchain_chroma import Chroma\n",
+    "from langchain_experimental.text_splitter import SemanticChunker\n",
+    "from langchain_core.chat_history import InMemoryChatMessageHistory\n",
+    "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
+    "from langchain.chains.combine_documents import create_stuff_documents_chain\n",
+    "from langchain.chains.history_aware_retriever import create_history_aware_retriever\n",
+    "from langchain.chains import create_retrieval_chain\n",
+    "from langchain_core.prompts import MessagesPlaceholder\n",
+    "from langchain_core.runnables import RunnableLambda"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6e7ed82b-b28a-4094-9f77-3b6432dd0f7a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Constants\n",
+    "\n",
+    "CHAT_MODEL = \"gemini-2.5-flash\"\n",
+    "EMBEDDING_MODEL = \"models/text-embedding-004\"\n",
+    "# EMBEDDING_MODEL_EXP = \"models/gemini-embedding-exp-03-07\"\n",
+    "\n",
+    "folders = glob.glob(\"knowledge-base/*\")\n",
+    "text_loader_kwargs = {'encoding': 'utf-8'}\n",
+    "db_name = \"vector_db\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b83281a2-bcae-41ab-a347-0e7f9688d1ed",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "load_dotenv(override=True)\n",
+    "\n",
+    "api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
+    "\n",
+    "if not api_key:\n",
+    "    print(\"API key not found!\")\n",
+    "else:\n",
+    "    print(\"API key loaded into memory\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4fd6d516-772b-478d-9b28-09d42f2277d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Tag each document with its folder name so chunks can be filtered and colored by type later\n",
+    "def add_metadata(doc, doc_type):\n",
+    "    doc.metadata[\"doc_type\"] = doc_type\n",
+    "    return doc"
+   ]
+  },
"6bc4198b-f989-42c0-95b5-3596448fcaa2", + "metadata": {}, + "outputs": [], + "source": [ + "documents = []\n", + "for folder in tqdm(folders, desc=\"Loading folders\"):\n", + " doc_type = os.path.basename(folder)\n", + " loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n", + " folder_docs = loader.load()\n", + " documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n", + "\n", + "print(f\"Total documents loaded: {len(documents)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bb74241f-e9d5-42e8-9a4b-f31018397d66", + "metadata": {}, + "source": [ + "## Create Semantic Chunks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a3aa17f-f5d0-430a-80da-95c284bd99a8", + "metadata": {}, + "outputs": [], + "source": [ + "chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n", + "\n", + "text_splitter = SemanticChunker(\n", + " chunking_embedding_model,\n", + " breakpoint_threshold_type=\"percentile\", \n", + " breakpoint_threshold_amount=95.0, \n", + " min_chunk_size=3 \n", + ")\n", + "\n", + "start = time.time()\n", + "\n", + "semantic_chunks = []\n", + "pbar = tqdm(documents, desc=\"Semantic chunking documents\")\n", + "\n", + "for i, doc in enumerate(pbar):\n", + " doc_type = doc.metadata.get('doc_type', 'Unknown')\n", + " pbar.set_postfix_str(f\"Processing: {doc_type}\")\n", + " try:\n", + " doc_chunks = text_splitter.split_documents([doc])\n", + " semantic_chunks.extend(doc_chunks)\n", + " except Exception as e:\n", + " tqdm.write(f\"❌ Failed to split doc ({doc.metadata.get('source', 'unknown source')}): {e}\")\n", + "print(f\"⏱️ Took {time.time() - start:.2f} seconds\")\n", + "print(f\"Total semantic chunks: {len(semantic_chunks)}\")\n", + "\n", + "# import time\n", + "# start = time.time()\n", + "\n", + "# try:\n", + "# semantic_chunks = text_splitter.split_documents(documents)\n", + "# print(f\"✅ Chunking completed with {len(semantic_chunks)} chunks\")\n", + "# except Exception as e:\n", + "# print(f\"❌ Failed to split documents: {e}\")\n", + "\n", + "# print(f\"⏱️ Took {time.time() - start:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "675b98d6-5ed0-45d1-8f79-765911e6badf", + "metadata": {}, + "outputs": [], + "source": [ + "# Some Preview of the chunks\n", + "for i, doc in enumerate(semantic_chunks[:15]):\n", + " print(f\"--- Chunk {i+1} ---\")\n", + " print(doc.page_content) \n", + " print(\"\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "c17accff-539a-490b-8a5f-b5ce632a3c71", + "metadata": {}, + "source": [ + "## Embed with Gemini Embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0bd228bd-37d2-4aaf-b0f6-d94943f6f248", + "metadata": {}, + "outputs": [], + "source": [ + "embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL,task_type=\"retrieval_document\")\n", + "\n", + "if os.path.exists(db_name):\n", + " Chroma(persist_directory=db_name, embedding_function=embedding).delete_collection()\n", + "\n", + "vectorstore = Chroma.from_documents(\n", + " documents=semantic_chunks,\n", + " embedding=embedding,\n", + " persist_directory=db_name\n", + ")\n", + "\n", + "print(f\"✅ Vectorstore created with {vectorstore._collection.count()} documents\")" + ] + }, + { + "cell_type": "markdown", + "id": "ce0a3e23-5912-4de2-bf34-3c0936375de1", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Visualzing Vectors" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "id": "6ffdc6f5-ec25-4229-94d4-1fc6bb4d2702", + "metadata": {}, + "outputs": [], + "source": [ + "collection = vectorstore._collection\n", + "result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n", + "vectors = np.array(result['embeddings'])\n", + "documents = result['documents']\n", + "metadatas = result['metadatas']\n", + "doc_types = [metadata['doc_type'] for metadata in metadatas]\n", + "colors = [['blue', 'green', 'red', 'orange'][['products', 'employees', 'contracts', 'company'].index(t)] for t in doc_types]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79", + "metadata": {}, + "outputs": [], + "source": [ + "# We humans find it easier to visalize things in 2D!\n", + "# Reduce the dimensionality of the vectors to 2D using t-SNE\n", + "# (t-distributed stochastic neighbor embedding)\n", + "\n", + "tsne = TSNE(n_components=2, random_state=42)\n", + "reduced_vectors = tsne.fit_transform(vectors)\n", + "\n", + "# Create the 2D scatter plot\n", + "fig = go.Figure(data=[go.Scatter(\n", + " x=reduced_vectors[:, 0],\n", + " y=reduced_vectors[:, 1],\n", + " mode='markers',\n", + " marker=dict(size=5, color=colors, opacity=0.8),\n", + " text=[f\"Type: {t}
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# We humans find it easier to visualize things in 2D!\n",
+    "# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
+    "# (t-distributed stochastic neighbor embedding)\n",
+    "\n",
+    "tsne = TSNE(n_components=2, random_state=42)\n",
+    "reduced_vectors = tsne.fit_transform(vectors)\n",
+    "\n",
+    "# Create the 2D scatter plot\n",
+    "fig = go.Figure(data=[go.Scatter(\n",
+    "    x=reduced_vectors[:, 0],\n",
+    "    y=reduced_vectors[:, 1],\n",
+    "    mode='markers',\n",
+    "    marker=dict(size=5, color=colors, opacity=0.8),\n",
+    "    text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
+    "    hoverinfo='text'\n",
+    ")])\n",
+    "\n",
+    "# Note: for a 2D figure the axis titles go directly on the layout (scene= is only for 3D plots)\n",
+    "fig.update_layout(\n",
+    "    title='2D Chroma Vector Store Visualization',\n",
+    "    xaxis_title='x',\n",
+    "    yaxis_title='y',\n",
+    "    width=800,\n",
+    "    height=600,\n",
+    "    margin=dict(r=20, b=10, l=10, t=40)\n",
+    ")\n",
+    "\n",
+    "fig.show()"
+   ]
+  },
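+  {
+   "cell_type": "markdown",
+   "id": "tsne-3d-sketch-md",
+   "metadata": {},
+   "source": [
+    "Optionally, the same embeddings can be projected to 3D for another view of the clusters. A sketch below, reusing the `doc_types` and `colors` computed above — this is where Plotly's `scene=` layout actually applies."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "tsne-3d-sketch-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional: 3D t-SNE projection of the same vectors\n",
+    "# (t-SNE needs more samples than its default perplexity of 30)\n",
+    "tsne_3d = TSNE(n_components=3, random_state=42)\n",
+    "reduced_3d = tsne_3d.fit_transform(vectors)\n",
+    "\n",
+    "fig_3d = go.Figure(data=[go.Scatter3d(\n",
+    "    x=reduced_3d[:, 0],\n",
+    "    y=reduced_3d[:, 1],\n",
+    "    z=reduced_3d[:, 2],\n",
+    "    mode='markers',\n",
+    "    marker=dict(size=3, color=colors, opacity=0.8),\n",
+    "    text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
+    "    hoverinfo='text'\n",
+    ")])\n",
+    "\n",
+    "fig_3d.update_layout(\n",
+    "    title='3D Chroma Vector Store Visualization',\n",
+    "    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
+    "    width=900,\n",
+    "    height=700,\n",
+    "    margin=dict(r=20, b=10, l=10, t=40)\n",
+    ")\n",
+    "\n",
+    "fig_3d.show()"
+   ]
+  },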
+  {
+   "cell_type": "markdown",
+   "id": "359b8651-a382-4050-8bf8-123e5cdf4d53",
+   "metadata": {},
+   "source": [
+    "## RAG Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08a75313-6c68-42e5-bd37-78254123094c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "retriever = vectorstore.as_retriever(search_kwargs={\"k\": 20})\n",
+    "\n",
+    "chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)\n",
+    "\n",
+    "# Instruction used by the history-aware retriever to rewrite follow-up questions;\n",
+    "# chat history and the latest input are supplied as messages by the placeholders below\n",
+    "question_generator_system = \"\"\"Given the conversation so far and a follow-up question, rephrase the follow-up question to be a standalone question.\n",
+    "If the follow-up question is already a standalone question, return it as is.\"\"\"\n",
+    "\n",
+    "question_generator_prompt = ChatPromptTemplate.from_messages([\n",
+    "    SystemMessagePromptTemplate.from_template(question_generator_system),\n",
+    "    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
+    "    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
+    "])\n",
+    "\n",
+    "history_aware_retriever = create_history_aware_retriever(\n",
+    "    chat_llm, retriever, question_generator_prompt\n",
+    ")\n",
+    "\n",
+    "qa_system_prompt = \"\"\"You are Insurellm’s intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.\n",
+    "Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers. Rephrase source facts in a natural tone, not word-for-word.\n",
+    "When referencing people or company history, prioritize clarity and correctness.\n",
+    "Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.\n",
+    "If you truly don’t have the answer, respond with:\n",
+    "\"I don't have that information.\"\n",
+    "Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.\n",
+    "\n",
+    "{context}\"\"\"\n",
+    "\n",
+    "qa_prompt = ChatPromptTemplate.from_messages([\n",
+    "    SystemMessagePromptTemplate.from_template(qa_system_prompt),\n",
+    "    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
+    "    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
+    "])\n",
+    "\n",
+    "combine_docs_chain = create_stuff_documents_chain(chat_llm, qa_prompt)\n",
+    "\n",
+    "# Debugging helpers: uncomment to inspect the retrieved context or the raw chain inputs\n",
+    "# inspect_context = RunnableLambda(lambda inputs: (\n",
+    "#     print(\"\\nRetrieved Context:\\n\", \"\\n---\\n\".join([doc.page_content for doc in inputs[\"context\"]])),\n",
+    "#     inputs  # pass it through unchanged\n",
+    "# )[1])\n",
+    "\n",
+    "# inspect_inputs = RunnableLambda(lambda inputs: (\n",
+    "#     print(\"\\nInputs received by the chain:\\n\", inputs),\n",
+    "#     inputs\n",
+    "# )[1])\n",
+    "\n",
+    "base_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)\n",
+    "\n",
+    "# RunnableLambda is used here because Gradio needs the response to contain only the answer,\n",
+    "# while base_chain returns a dict with input, context, chat_history, and answer\n",
+    "# base_chain_with_output = base_chain | inspect_context | RunnableLambda(lambda res: res[\"answer\"])\n",
+    "# base_chain_with_output = base_chain | RunnableLambda(lambda res: res[\"answer\"])\n",
+    "\n",
+    "# Session-persistent chat history\n",
+    "# To persist history between sessions, use MongoDB (or any NoSQL DB) to store messages\n",
+    "# via MongoDBChatMessageHistory (the relevant DB wrapper) — see the sketch after the chat function below\n",
+    "\n",
+    "chat_histories = {}\n",
+    "\n",
+    "def get_history(session_id):\n",
+    "    if session_id not in chat_histories:\n",
+    "        chat_histories[session_id] = InMemoryChatMessageHistory()\n",
+    "    return chat_histories[session_id]\n",
+    "\n",
+    "# Currently set up for streaming; if a one-shot response is needed, comment out base_chain\n",
+    "# and output_messages_key, and enable base_chain_with_output instead\n",
+    "conversation_chain = RunnableWithMessageHistory(\n",
+    "    # base_chain_with_output,\n",
+    "    base_chain,\n",
+    "    get_history,\n",
+    "    output_messages_key=\"answer\",\n",
+    "    input_messages_key=\"input\",\n",
+    "    history_messages_key=\"chat_history\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "06b58566-70cb-42eb-8b1c-9fe353fe71f0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def chat(question, history):\n",
+    "    try:\n",
+    "        session_id = \"default-session\"\n",
+    "\n",
+    "        # # Full chat (one-shot) version\n",
+    "        # result = conversation_chain.invoke(\n",
+    "        #     {\"input\": question},\n",
+    "        #     config={\"configurable\": {\"session_id\": session_id}}\n",
+    "        # )\n",
+    "        # return result\n",
+    "\n",
+    "        # Streaming version\n",
+    "        response_buffer = \"\"\n",
+    "\n",
+    "        for chunk in conversation_chain.stream({\"input\": question}, config={\"configurable\": {\"session_id\": session_id}}):\n",
+    "            if \"answer\" in chunk:\n",
+    "                response_buffer += chunk[\"answer\"]\n",
+    "                yield response_buffer\n",
+    "    except Exception as e:\n",
+    "        print(f\"An error occurred during chat: {e}\")\n",
+    "        # chat() is a generator, so the fallback message must be yielded, not returned\n",
+    "        yield \"I apologize, but I encountered an error and cannot answer that right now.\""
+   ]
+  },
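+  {
+   "cell_type": "markdown",
+   "id": "mongo-history-sketch-md",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of the session persistence mentioned in the RAG cell above, assuming a locally running MongoDB and the `langchain-mongodb` package (`pip install langchain-mongodb`). The connection string, database, and collection names are placeholders — this is not wired into the chain above, just a drop-in alternative for `get_history`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "mongo-history-sketch-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch only — requires: pip install langchain-mongodb\n",
+    "try:\n",
+    "    from langchain_mongodb import MongoDBChatMessageHistory\n",
+    "\n",
+    "    def get_persistent_history(session_id):\n",
+    "        # Placeholder connection details — point these at your own MongoDB instance\n",
+    "        return MongoDBChatMessageHistory(\n",
+    "            connection_string=\"mongodb://localhost:27017\",\n",
+    "            session_id=session_id,\n",
+    "            database_name=\"insurellm_chat\",\n",
+    "            collection_name=\"message_histories\",\n",
+    "        )\n",
+    "\n",
+    "    # To use it, build RunnableWithMessageHistory with get_persistent_history instead of get_history\n",
+    "except ImportError:\n",
+    "    print(\"langchain-mongodb not installed — skipping the persistence sketch\")"
+   ]
+  },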
+ "cell_type": "code", + "execution_count": null, + "id": "a577ac66-3952-4821-83d2-8a50bad89971", + "metadata": {}, + "outputs": [], + "source": [ + "view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56b63a17-2522-46e5-b5a3-e2e80e52a723", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}