Merge pull request #477 from Indraveen-Technologies/week-5-contributions
Week 5 Contribution: RAG AI Bot using Semantic Chunking and the Gemini API
@@ -0,0 +1,463 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2080947c-96d9-447f-8368-cfdc9e5c9960",
"metadata": {},
"source": [
"# Using Semantic Chunks with the Gemini API and Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53221f1a-a0c1-4506-a3d0-d6626c58e4e0",
"metadata": {},
"outputs": [],
"source": [
"# Regular Imports\n",
"import os\n",
"import glob\n",
"import time\n",
"from dotenv import load_dotenv\n",
"from tqdm.notebook import tqdm\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a2a7171-a7b6-42a6-96d7-c93f360689ec",
"metadata": {},
"outputs": [],
"source": [
"# Visualization Imports\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c9d658-65e5-40a1-8680-d0b561f87649",
"metadata": {},
"outputs": [],
"source": [
"# LangChain Imports\n",
"\n",
"from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI\n",
"from langchain_community.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder\n",
"from langchain_core.messages import HumanMessage, AIMessage\n",
"from langchain_chroma import Chroma\n",
"from langchain_experimental.text_splitter import SemanticChunker\n",
"from langchain_core.chat_history import InMemoryChatMessageHistory\n",
"from langchain_core.runnables.history import RunnableWithMessageHistory\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains.history_aware_retriever import create_history_aware_retriever\n",
"from langchain.chains import create_retrieval_chain\n",
"from langchain_core.runnables import RunnableLambda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e7ed82b-b28a-4094-9f77-3b6432dd0f7a",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"CHAT_MODEL = \"gemini-2.5-flash\"\n",
"EMBEDDING_MODEL = \"models/text-embedding-004\"\n",
"# EMBEDDING_MODEL_EXP = \"models/gemini-embedding-exp-03-07\"\n",
"\n",
"folders = glob.glob(\"knowledge-base/*\")\n",
"text_loader_kwargs = {'encoding': 'utf-8'}\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b83281a2-bcae-41ab-a347-0e7f9688d1ed",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"\n",
"api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if not api_key:\n",
"    print(\"API Key not found!\")\n",
"else:\n",
"    print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fd6d516-772b-478d-9b28-09d42f2277d7",
"metadata": {},
"outputs": [],
"source": [
"def add_metadata(doc, doc_type):\n",
"    doc.metadata[\"doc_type\"] = doc_type\n",
"    return doc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bc4198b-f989-42c0-95b5-3596448fcaa2",
"metadata": {},
"outputs": [],
"source": [
"documents = []\n",
"for folder in tqdm(folders, desc=\"Loading folders\"):\n",
"    doc_type = os.path.basename(folder)\n",
"    loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n",
"    folder_docs = loader.load()\n",
"    documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n",
"\n",
"print(f\"Total documents loaded: {len(documents)}\")"
]
},
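{
"cell_type": "markdown",
"id": "3f9d1a20-5b7c-4e1d-9c2a-7d4f8a1b2c3d",
"metadata": {},
"source": [
"Optional sanity check (an added illustration, not part of the original pipeline): count the loaded documents per `doc_type` to confirm the metadata assigned above looks right."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c2e4f6a-1b3d-4c5e-8f7a-9b0c1d2e3f4a",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: documents per doc_type (the folder name)\n",
"from collections import Counter\n",
"\n",
"print(Counter(doc.metadata[\"doc_type\"] for doc in documents))"
]
},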
{
"cell_type": "markdown",
"id": "bb74241f-e9d5-42e8-9a4b-f31018397d66",
"metadata": {},
"source": [
"## Create Semantic Chunks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a3aa17f-f5d0-430a-80da-95c284bd99a8",
"metadata": {},
"outputs": [],
"source": [
"chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
"\n",
"text_splitter = SemanticChunker(\n",
"    chunking_embedding_model,\n",
"    breakpoint_threshold_type=\"percentile\",\n",
"    breakpoint_threshold_amount=95.0,\n",
"    min_chunk_size=3\n",
")\n",
"\n",
"start = time.time()\n",
"\n",
"semantic_chunks = []\n",
"pbar = tqdm(documents, desc=\"Semantic chunking documents\")\n",
"\n",
"for doc in pbar:\n",
"    doc_type = doc.metadata.get('doc_type', 'Unknown')\n",
"    pbar.set_postfix_str(f\"Processing: {doc_type}\")\n",
"    try:\n",
"        doc_chunks = text_splitter.split_documents([doc])\n",
"        semantic_chunks.extend(doc_chunks)\n",
"    except Exception as e:\n",
"        tqdm.write(f\"❌ Failed to split doc ({doc.metadata.get('source', 'unknown source')}): {e}\")\n",
"\n",
"print(f\"⏱️ Took {time.time() - start:.2f} seconds\")\n",
"print(f\"Total semantic chunks: {len(semantic_chunks)}\")\n",
"\n",
"# Alternative: chunk all documents in a single call (no per-document progress):\n",
"# start = time.time()\n",
"# try:\n",
"#     semantic_chunks = text_splitter.split_documents(documents)\n",
"#     print(f\"✅ Chunking completed with {len(semantic_chunks)} chunks\")\n",
"# except Exception as e:\n",
"#     print(f\"❌ Failed to split documents: {e}\")\n",
"# print(f\"⏱️ Took {time.time() - start:.2f} seconds\")"
]
},
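{
"cell_type": "markdown",
"id": "b7a1c3e5-2d4f-4a6b-8c0d-1e3f5a7b9c2d",
"metadata": {},
"source": [
"How the `percentile` breakpoint works: `SemanticChunker` embeds consecutive sentences, measures the cosine distance between each sentence and the next, and splits wherever that distance exceeds the chosen percentile (95.0 above). The next cell is a minimal sketch of that idea using synthetic vectors; it is an illustration only, not part of the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4f6a8b0-3c5e-4d7f-9a1b-2c4d6e8f0a3b",
"metadata": {},
"outputs": [],
"source": [
"# Illustration with synthetic embeddings: where a 95th-percentile breakpoint would split\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"sentence_embeddings = rng.normal(size=(12, 8))  # stand-in for 12 embedded sentences\n",
"sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)\n",
"\n",
"# cosine distance between each sentence and the one that follows it\n",
"distances = 1 - np.sum(sentence_embeddings[:-1] * sentence_embeddings[1:], axis=1)\n",
"threshold = np.percentile(distances, 95.0)  # mirrors breakpoint_threshold_amount=95.0\n",
"breakpoints = np.where(distances > threshold)[0] + 1\n",
"print(f\"Threshold: {threshold:.3f}; split before sentences: {breakpoints.tolist()}\")"
]
},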
{
"cell_type": "code",
"execution_count": null,
"id": "675b98d6-5ed0-45d1-8f79-765911e6badf",
"metadata": {},
"outputs": [],
"source": [
"# Preview the first few chunks\n",
"for i, doc in enumerate(semantic_chunks[:15]):\n",
"    print(f\"--- Chunk {i+1} ---\")\n",
"    print(doc.page_content)\n",
"    print(\"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "c17accff-539a-490b-8a5f-b5ce632a3c71",
"metadata": {},
"source": [
"## Embed with Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bd228bd-37d2-4aaf-b0f6-d94943f6f248",
"metadata": {},
"outputs": [],
"source": [
"embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
"\n",
"# Delete any existing collection so we start from a clean vector store\n",
"if os.path.exists(db_name):\n",
"    Chroma(persist_directory=db_name, embedding_function=embedding).delete_collection()\n",
"\n",
"vectorstore = Chroma.from_documents(\n",
"    documents=semantic_chunks,\n",
"    embedding=embedding,\n",
"    persist_directory=db_name\n",
")\n",
"\n",
"print(f\"✅ Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
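{
"cell_type": "markdown",
"id": "e5c7a9b1-4d6f-4b8a-9c2e-3f5a7b9d1c4e",
"metadata": {},
"source": [
"Optional retrieval sanity check (a sketch, not part of the original pipeline). The documents above were embedded with `task_type=\"retrieval_document\"`; Gemini's embeddings use `task_type=\"retrieval_query\"` for queries, so this embeds the query separately and searches by vector. The question is a hypothetical example; any knowledge-base question works."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a3b5c7-9d0e-4f2a-8b4c-6d8e0f2a4b6c",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: embed the query with task_type=\"retrieval_query\" and search by vector\n",
"query_embedder = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_query\")\n",
"query_vector = query_embedder.embed_query(\"What products does Insurellm offer?\")  # hypothetical query\n",
"\n",
"for doc in vectorstore.similarity_search_by_vector(query_vector, k=3):\n",
"    print(doc.metadata.get(\"doc_type\"), \"->\", doc.page_content[:120])"
]
},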
{
"cell_type": "markdown",
"id": "ce0a3e23-5912-4de2-bf34-3c0936375de1",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Visualizing Vectors"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ffdc6f5-ec25-4229-94d4-1fc6bb4d2702",
"metadata": {},
"outputs": [],
"source": [
"collection = vectorstore._collection\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"chunk_texts = result['documents']  # new name, so we don't overwrite the documents list loaded earlier\n",
"metadatas = result['metadatas']\n",
"doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
"colors = [['blue', 'green', 'red', 'orange'][['products', 'employees', 'contracts', 'company'].index(t)] for t in doc_types]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visualize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
"    x=reduced_vectors[:, 0],\n",
"    y=reduced_vectors[:, 1],\n",
"    mode='markers',\n",
"    marker=dict(size=5, color=colors, opacity=0.8),\n",
"    text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, chunk_texts)],\n",
"    hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
"    title='2D Chroma Vector Store Visualization',\n",
"    xaxis_title='x',\n",
"    yaxis_title='y',\n",
"    width=800,\n",
"    height=600,\n",
"    margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "359b8651-a382-4050-8bf8-123e5cdf4d53",
"metadata": {},
"source": [
"## RAG Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08a75313-6c68-42e5-bd37-78254123094c",
"metadata": {},
"outputs": [],
"source": [
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 20})\n",
"\n",
"# Conversation Memory\n",
"# memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"\n",
"chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)\n",
"\n",
"# The history-aware retriever needs a prompt that rewrites a follow-up into a standalone question\n",
"question_generator_system = \"\"\"Given the conversation so far and a follow-up question, rephrase the follow-up question to be a standalone question.\n",
"If the follow-up question is already a standalone question, return it as is.\"\"\"\n",
"\n",
"question_generator_prompt = ChatPromptTemplate.from_messages([\n",
"    SystemMessagePromptTemplate.from_template(question_generator_system),\n",
"    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
"    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"history_aware_retriever = create_history_aware_retriever(\n",
"    chat_llm, retriever, question_generator_prompt\n",
")\n",
"\n",
"qa_system_prompt = \"\"\"You are Insurellm’s intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.\n",
"Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers. Rephrase source facts in a natural tone, not word-for-word.\n",
"When referencing people or company history, prioritize clarity and correctness.\n",
"Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.\n",
"If you truly don’t have the answer, respond with:\n",
"\"I don't have that information.\"\n",
"Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.\n",
"\n",
"{context}\"\"\"\n",
"\n",
"qa_human_prompt = \"{input}\"\n",
"\n",
"qa_prompt = ChatPromptTemplate.from_messages([\n",
"    SystemMessagePromptTemplate.from_template(qa_system_prompt),\n",
"    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
"    HumanMessagePromptTemplate.from_template(qa_human_prompt)\n",
"])\n",
"\n",
"combine_docs_chain = create_stuff_documents_chain(chat_llm, qa_prompt)\n",
"\n",
"# Debugging helpers (uncomment to inspect intermediate values):\n",
"# inspect_context = RunnableLambda(lambda inputs: (\n",
"#     print(\"\\n Retrieved Context:\\n\", \"\\n---\\n\".join([doc.page_content for doc in inputs[\"context\"]])),\n",
"#     inputs  # pass it through unchanged\n",
"# )[1])\n",
"\n",
"# inspect_inputs = RunnableLambda(lambda inputs: (\n",
"#     print(\"\\n Inputs received by the chain:\\n\", inputs),\n",
"#     inputs\n",
"# )[1])\n",
"\n",
"base_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)\n",
"\n",
"# Gradio needs the response to contain only the answer, while base_chain returns a dict with\n",
"# input, context, chat_history, and answer, so a RunnableLambda can extract just the answer:\n",
"# base_chain_with_output = base_chain | inspect_context | RunnableLambda(lambda res: res[\"answer\"])\n",
"# base_chain_with_output = base_chain | RunnableLambda(lambda res: res[\"answer\"])\n",
"\n",
"# Session-Persistent Chat History\n",
"# To persist history between sessions, store it in MongoDB (or another NoSQL store) and use\n",
"# MongoDBChatMessageHistory (or the wrapper for whichever database you choose).\n",
"\n",
"chat_histories = {}\n",
"\n",
"def get_history(session_id):\n",
"    if session_id not in chat_histories:\n",
"        chat_histories[session_id] = InMemoryChatMessageHistory()\n",
"    return chat_histories[session_id]\n",
"\n",
"# Currently set up for streaming; for a one-shot response, swap base_chain for\n",
"# base_chain_with_output above and drop output_messages_key.\n",
"conversation_chain = RunnableWithMessageHistory(\n",
"    # base_chain_with_output,\n",
"    base_chain,\n",
"    get_history,\n",
"    output_messages_key=\"answer\",\n",
"    input_messages_key=\"input\",\n",
"    history_messages_key=\"chat_history\",\n",
")"
]
},
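{
"cell_type": "markdown",
"id": "a9b1c3d5-6e8f-4a0b-8c2d-4e6f8a0b2c4d",
"metadata": {},
"source": [
"Optional smoke test (a sketch, not part of the original pipeline): one non-streaming pass through `conversation_chain` to confirm retrieval, history, and answer extraction work before wiring up Gradio. The question is a hypothetical example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1d3e5f7-8a0b-4c2d-9e4f-6a8b0c2d4e6f",
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test: one full (non-streaming) pass through the chain\n",
"result = conversation_chain.invoke(\n",
"    {\"input\": \"What products does Insurellm offer?\"},  # hypothetical example question\n",
"    config={\"configurable\": {\"session_id\": \"smoke-test\"}},\n",
")\n",
"print(result[\"answer\"])"
]
},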
{
"cell_type": "code",
"execution_count": null,
"id": "06b58566-70cb-42eb-8b1c-9fe353fe71f0",
"metadata": {},
"outputs": [],
"source": [
"def chat(question, history):\n",
"    try:\n",
"        # One shared session for this demo; Gradio's `history` argument goes unused because\n",
"        # RunnableWithMessageHistory tracks the conversation itself.\n",
"        session_id = \"default-session\"\n",
"\n",
"        # # Full (non-streaming) version:\n",
"        # result = conversation_chain.invoke(\n",
"        #     {\"input\": question},\n",
"        #     config={\"configurable\": {\"session_id\": session_id}}\n",
"        # )\n",
"        # return result[\"answer\"]\n",
"\n",
"        # Streaming version\n",
"        response_buffer = \"\"\n",
"        for chunk in conversation_chain.stream({\"input\": question}, config={\"configurable\": {\"session_id\": session_id}}):\n",
"            if \"answer\" in chunk:\n",
"                response_buffer += chunk[\"answer\"]\n",
"                yield response_buffer\n",
"    except Exception as e:\n",
"        print(f\"An error occurred during chat: {e}\")\n",
"        # yield, not return: a value returned from a generator never reaches Gradio\n",
"        yield \"I apologize, but I encountered an error and cannot answer that right now.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a577ac66-3952-4821-83d2-8a50bad89971",
"metadata": {},
"outputs": [],
"source": [
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56b63a17-2522-46e5-b5a3-e2e80e52a723",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}