Merge branch 'main' of github.com:ed-donner/llm_engineering

This commit is contained in:
Edward Donner
2025-07-12 15:26:11 -04:00
51 changed files with 17066 additions and 0 deletions

View File

@@ -0,0 +1,463 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2080947c-96d9-447f-8368-cfdc9e5c9960",
"metadata": {},
"source": [
"# Using Semantic chunks with Gemini API and Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53221f1a-a0c1-4506-a3d0-d6626c58e4e0",
"metadata": {},
"outputs": [],
"source": [
"# Regular Imports\n",
"import os\n",
"import glob\n",
"import time\n",
"from dotenv import load_dotenv\n",
"from tqdm.notebook import tqdm\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a2a7171-a7b6-42a6-96d7-c93f360689ec",
"metadata": {},
"outputs": [],
"source": [
"# Visual Import\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c9d658-65e5-40a1-8680-d0b561f87649",
"metadata": {},
"outputs": [],
"source": [
"# Lang Chain Imports\n",
"\n",
"from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI\n",
"from langchain_community.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
"from langchain_core.messages import HumanMessage, AIMessage\n",
"from langchain_chroma import Chroma\n",
"from langchain_experimental.text_splitter import SemanticChunker\n",
"from langchain_core.chat_history import InMemoryChatMessageHistory\n",
"from langchain_core.runnables.history import RunnableWithMessageHistory\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains.history_aware_retriever import create_history_aware_retriever\n",
"from langchain.chains import create_retrieval_chain\n",
"from langchain_core.prompts import MessagesPlaceholder\n",
"from langchain_core.runnables import RunnableLambda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e7ed82b-b28a-4094-9f77-3b6432dd0f7a",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"CHAT_MODEL = \"gemini-2.5-flash\"\n",
"EMBEDDING_MODEL = \"models/text-embedding-004\"\n",
"# EMBEDDING_MODEL_EXP = \"models/gemini-embedding-exp-03-07\"\n",
"\n",
"folders = glob.glob(\"knowledge-base/*\")\n",
"text_loader_kwargs = {'encoding': 'utf-8'}\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b83281a2-bcae-41ab-a347-0e7f9688d1ed",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"\n",
"api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if not api_key:\n",
" print(\"API Key not found!\")\n",
"else:\n",
" print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fd6d516-772b-478d-9b28-09d42f2277d7",
"metadata": {},
"outputs": [],
"source": [
"def add_metadata(doc, doc_type):\n",
" doc.metadata[\"doc_type\"] = doc_type\n",
" return doc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bc4198b-f989-42c0-95b5-3596448fcaa2",
"metadata": {},
"outputs": [],
"source": [
"documents = []\n",
"for folder in tqdm(folders, desc=\"Loading folders\"):\n",
" doc_type = os.path.basename(folder)\n",
" loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n",
" folder_docs = loader.load()\n",
" documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n",
"\n",
"print(f\"Total documents loaded: {len(documents)}\")"
]
},
{
"cell_type": "markdown",
"id": "bb74241f-e9d5-42e8-9a4b-f31018397d66",
"metadata": {},
"source": [
"## Create Semantic Chunks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a3aa17f-f5d0-430a-80da-95c284bd99a8",
"metadata": {},
"outputs": [],
"source": [
"chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
"\n",
"text_splitter = SemanticChunker(\n",
" chunking_embedding_model,\n",
" breakpoint_threshold_type=\"percentile\", \n",
" breakpoint_threshold_amount=95.0, \n",
" min_chunk_size=3 \n",
")\n",
"\n",
"start = time.time()\n",
"\n",
"semantic_chunks = []\n",
"pbar = tqdm(documents, desc=\"Semantic chunking documents\")\n",
"\n",
"for i, doc in enumerate(pbar):\n",
" doc_type = doc.metadata.get('doc_type', 'Unknown')\n",
" pbar.set_postfix_str(f\"Processing: {doc_type}\")\n",
" try:\n",
" doc_chunks = text_splitter.split_documents([doc])\n",
" semantic_chunks.extend(doc_chunks)\n",
" except Exception as e:\n",
" tqdm.write(f\"❌ Failed to split doc ({doc.metadata.get('source', 'unknown source')}): {e}\")\n",
"print(f\"⏱️ Took {time.time() - start:.2f} seconds\")\n",
"print(f\"Total semantic chunks: {len(semantic_chunks)}\")\n",
"\n",
"# import time\n",
"# start = time.time()\n",
"\n",
"# try:\n",
"# semantic_chunks = text_splitter.split_documents(documents)\n",
"# print(f\"✅ Chunking completed with {len(semantic_chunks)} chunks\")\n",
"# except Exception as e:\n",
"# print(f\"❌ Failed to split documents: {e}\")\n",
"\n",
"# print(f\"⏱️ Took {time.time() - start:.2f} seconds\")"
]
},
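{
"cell_type": "markdown",
"id": "a1f4c2d0-9e8b-4c6a-b5d1-0f2e3a4b5c6d",
"metadata": {},
"source": [
"A quick look at the chunk-size distribution can help when tuning `breakpoint_threshold_amount` and `min_chunk_size` above. This is a minimal sketch that reuses the `semantic_chunks` list from the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e5d3c1-8f7a-4d5b-a6c2-1e3f4b5c6d7e",
"metadata": {},
"outputs": [],
"source": [
"# Histogram of chunk lengths in characters -- a rough proxy for token counts\n",
"chunk_lengths = [len(c.page_content) for c in semantic_chunks]\n",
"plt.hist(chunk_lengths, bins=30)\n",
"plt.xlabel(\"Chunk length (characters)\")\n",
"plt.ylabel(\"Count\")\n",
"plt.title(\"Semantic chunk size distribution\")\n",
"plt.show()"
]
},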
{
"cell_type": "code",
"execution_count": null,
"id": "675b98d6-5ed0-45d1-8f79-765911e6badf",
"metadata": {},
"outputs": [],
"source": [
"# Some Preview of the chunks\n",
"for i, doc in enumerate(semantic_chunks[:15]):\n",
" print(f\"--- Chunk {i+1} ---\")\n",
" print(doc.page_content) \n",
" print(\"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "c17accff-539a-490b-8a5f-b5ce632a3c71",
"metadata": {},
"source": [
"## Embed with Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bd228bd-37d2-4aaf-b0f6-d94943f6f248",
"metadata": {},
"outputs": [],
"source": [
"embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL,task_type=\"retrieval_document\")\n",
"\n",
"if os.path.exists(db_name):\n",
" Chroma(persist_directory=db_name, embedding_function=embedding).delete_collection()\n",
"\n",
"vectorstore = Chroma.from_documents(\n",
" documents=semantic_chunks,\n",
" embedding=embedding,\n",
" persist_directory=db_name\n",
")\n",
"\n",
"print(f\"✅ Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
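{
"cell_type": "markdown",
"id": "c3f6e4d2-7a8b-4e6c-b7d3-2f4a5b6c7d8e",
"metadata": {},
"source": [
"Before wiring up the full chain, a raw `similarity_search` against the vectorstore is a cheap sanity check that the embeddings landed correctly. The query below is purely illustrative -- swap in anything relevant to your knowledge base."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4a7f5e3-6b9c-4f7d-a8e4-3a5b6c7d8e9f",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: retrieve the top matches for a sample query (illustrative only)\n",
"sample_query = \"Who founded the company?\"\n",
"for i, doc in enumerate(vectorstore.similarity_search(sample_query, k=3)):\n",
"    print(f\"--- Match {i+1} ({doc.metadata.get('doc_type', 'unknown')}) ---\")\n",
"    print(doc.page_content[:300], \"\\n\")"
]
},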
{
"cell_type": "markdown",
"id": "ce0a3e23-5912-4de2-bf34-3c0936375de1",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Visualzing Vectors"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ffdc6f5-ec25-4229-94d4-1fc6bb4d2702",
"metadata": {},
"outputs": [],
"source": [
"collection = vectorstore._collection\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"metadatas = result['metadatas']\n",
"doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
"colors = [['blue', 'green', 'red', 'orange'][['products', 'employees', 'contracts', 'company'].index(t)] for t in doc_types]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visalize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
" x=reduced_vectors[:, 0],\n",
" y=reduced_vectors[:, 1],\n",
" mode='markers',\n",
" marker=dict(size=5, color=colors, opacity=0.8),\n",
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
" hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
" title='2D Chroma Vector Store Visualization',\n",
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
" width=800,\n",
" height=600,\n",
" margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "359b8651-a382-4050-8bf8-123e5cdf4d53",
"metadata": {},
"source": [
"## RAG Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08a75313-6c68-42e5-bd37-78254123094c",
"metadata": {},
"outputs": [],
"source": [
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 20 })\n",
"\n",
"# Conversation Memory\n",
"# memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"\n",
"chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)\n",
"\n",
"question_generator_template = \"\"\"Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.\n",
"If the follow up question is already a standalone question, return it as is.\n",
"\n",
"Chat History:\n",
"{chat_history}\n",
"Follow Up Input: {input} \n",
"Standalone question:\"\"\"\n",
"\n",
"question_generator_prompt = ChatPromptTemplate.from_messages([\n",
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"history_aware_retriever = create_history_aware_retriever(\n",
" chat_llm, retriever, question_generator_prompt\n",
")\n",
"\n",
"qa_system_prompt = \"\"\"You are Insurellms intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.\n",
"Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers.Rephrase source facts in a natural tone, not word-for-word.\n",
"When referencing people or company history, prioritize clarity and correctness.\n",
"Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.\n",
"If you truly dont have the answer, respond with:\n",
"\"I don't have that information.\"\n",
"Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.\n",
"\n",
"{context}\"\"\"\n",
"\n",
"\n",
"qa_human_prompt = \"{input}\" \n",
"\n",
"qa_prompt = ChatPromptTemplate.from_messages([\n",
" SystemMessagePromptTemplate.from_template(qa_system_prompt),\n",
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"combine_docs_chain = create_stuff_documents_chain(chat_llm, qa_prompt)\n",
"\n",
"# inspect_context = RunnableLambda(lambda inputs: (\n",
"# print(\"\\n Retrieved Context:\\n\", \"\\n---\\n\".join([doc.page_content for doc in inputs[\"context\"]])),\n",
"# inputs # pass it through unchanged\n",
"# )[1])\n",
"\n",
"# inspect_inputs = RunnableLambda(lambda inputs: (\n",
"# print(\"\\n Inputs received by the chain:\\n\", inputs),\n",
"# inputs\n",
"# )[1])\n",
"\n",
"base_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)\n",
"\n",
"# Using Runnable Lambda as Gradio needs the response to contain only the output (answer) and base_chain would have a dict with input, context, chat_history, answer\n",
"\n",
"# base_chain_with_output = base_chain | inspect_context | RunnableLambda(lambda res: res[\"answer\"])\n",
"# base_chain_with_output = base_chain | RunnableLambda(lambda res: res[\"answer\"])\n",
"\n",
"\n",
"# Session Persistent Chat History \n",
"# If we want to persist history between sessions then use MongoDB (or any non sql DB)to store and use MongoDBChatMessageHistory (relevant DB Wrapper)\n",
"\n",
"chat_histories = {}\n",
"\n",
"def get_history(session_id):\n",
" if session_id not in chat_histories:\n",
" chat_histories[session_id] = InMemoryChatMessageHistory()\n",
" return chat_histories[session_id]\n",
"\n",
"# Currently set to streaming ...if one shot response is needed then comment base_chain and output_message_key and enable base_chain_with_output\n",
"conversation_chain = RunnableWithMessageHistory(\n",
" # base_chain_with_output,\n",
" base_chain,\n",
" get_history,\n",
" output_messages_key=\"answer\", \n",
" input_messages_key=\"input\",\n",
" history_messages_key=\"chat_history\",\n",
")"
]
},
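{
"cell_type": "markdown",
"id": "e5b8a6f4-5c0d-4a8e-b9f5-4b6c7d8e9f0a",
"metadata": {},
"source": [
"An optional one-shot `invoke` makes it easy to verify the chain before launching Gradio. The question and session id below are arbitrary examples -- `get_history` creates a new history for any unseen id."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6c9b7a5-4d1e-4b9f-a0a6-5c7d8e9f0a1b",
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test of the chain (question and session id are just examples)\n",
"result = conversation_chain.invoke(\n",
"    {\"input\": \"What products does Insurellm offer?\"},\n",
"    config={\"configurable\": {\"session_id\": \"smoke-test\"}}\n",
")\n",
"print(result[\"answer\"])"
]
},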
{
"cell_type": "code",
"execution_count": null,
"id": "06b58566-70cb-42eb-8b1c-9fe353fe71f0",
"metadata": {},
"outputs": [],
"source": [
"def chat(question, history):\n",
" try:\n",
" # result = conversation_chain.invoke({\"input\": question, \"chat_history\": memory.buffer_as_messages})\n",
" \n",
" # memory.chat_memory.add_user_message(question)\n",
" # memory.chat_memory.add_ai_message(result[\"answer\"])\n",
"\n",
" # return result[\"answer\"]\n",
"\n",
" \n",
" session_id = \"default-session\"\n",
"\n",
" # # FUll chat version\n",
" # result = conversation_chain.invoke(\n",
" # {\"input\": question},\n",
" # config={\"configurable\": {\"session_id\": session_id}}\n",
" # )\n",
" # # print(result)\n",
" # return result\n",
"\n",
" # Streaming Version\n",
" response_buffer = \"\"\n",
"\n",
" for chunk in conversation_chain.stream({\"input\": question},config={\"configurable\": {\"session_id\": session_id}}):\n",
" if \"answer\" in chunk:\n",
" response_buffer += chunk[\"answer\"]\n",
" yield response_buffer \n",
" except Exception as e:\n",
" print(f\"An error occurred during chat: {e}\")\n",
" return \"I apologize, but I encountered an error and cannot answer that right now.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a577ac66-3952-4821-83d2-8a50bad89971",
"metadata": {},
"outputs": [],
"source": [
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56b63a17-2522-46e5-b5a3-e2e80e52a723",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,552 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "61777022-631c-4db0-afeb-70d8d22bc07b",
"metadata": {},
"source": [
"Summary:\n",
"This is the project from week 5. The intention was to create a vector db of my own files (from an external drive) which can be used in a RAG solution.\n",
"This includes a number of file types (docx, pdf, txt, epub...) and includes the ability to exclude folders.\n",
"With the OpenAI embeddings API limit of 300k tokens, it was also necessary to create a batch embeddings process so that there were multiple requests.\n",
"This was based on estimating the tokens with a text to token rate of 1:4, however it wasn't perfect and one of the batches still exceeded the 300k limit when running.\n",
"I found that the responses from the llm were terrible in the end! I tried playing about with chunk sizes and the minimum # of chunks by llangchain and it did improve but was not fantastic. I also ensured the metadata was sent with each chunk to help.\n",
"This really highlighted the real world challenges of implementing RAG!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d78ef79d-e564-4c56-82f3-0485e4bf6986",
"metadata": {},
"outputs": [],
"source": [
"!pip install docx2txt\n",
"!pip install ebooklib\n",
"!pip install python-pptx\n",
"!pip install pypdf"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ec98119-456f-450c-a9a2-f375d74f5ce5",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"import glob\n",
"import gradio as gr\n",
"import time\n",
"from typing import List"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac14410b-8c3c-4cf5-900e-fd4c33cdf2b2",
"metadata": {},
"outputs": [],
"source": [
"# imports for langchain, plotly and Chroma\n",
"\n",
"from langchain.document_loaders import (\n",
" DirectoryLoader,\n",
" Docx2txtLoader,\n",
" TextLoader,\n",
" PyPDFLoader,\n",
" UnstructuredExcelLoader,\n",
" BSHTMLLoader\n",
")\n",
"from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3be698e7-71e1-4c75-9696-e1651e4bf357",
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"gpt-4o-mini\"\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f850068-c05b-4526-9494-034b0077347e",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c5baad2-2033-40a6-8ebd-5861b5cf4350",
"metadata": {},
"outputs": [],
"source": [
"# handling epubs\n",
"\n",
"from ebooklib import epub\n",
"from bs4 import BeautifulSoup\n",
"from langchain.document_loaders.base import BaseLoader\n",
"\n",
"class EpubLoader(BaseLoader):\n",
" def __init__(self, file_path: str):\n",
" self.file_path = file_path\n",
"\n",
" def load(self) -> list[Document]:\n",
" book = epub.read_epub(self.file_path)\n",
" text = ''\n",
" for item in book.get_items():\n",
" if item.get_type() == epub.EpubHtml:\n",
" soup = BeautifulSoup(item.get_content(), 'html.parser')\n",
" text += soup.get_text() + '\\n'\n",
"\n",
" return [Document(page_content=text, metadata={\"source\": self.file_path})]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd8b0e4e-d698-4484-bc94-d8b753f386cc",
"metadata": {},
"outputs": [],
"source": [
"# handling pptx\n",
"\n",
"from pptx import Presentation\n",
"\n",
"class PptxLoader(BaseLoader):\n",
" def __init__(self, file_path: str):\n",
" self.file_path = file_path\n",
"\n",
" def load(self) -> list[Document]:\n",
" prs = Presentation(self.file_path)\n",
" text = ''\n",
" for slide in prs.slides:\n",
" for shape in slide.shapes:\n",
" if hasattr(shape, \"text\") and shape.text:\n",
" text += shape.text + '\\n'\n",
"\n",
" return [Document(page_content=text, metadata={\"source\": self.file_path})]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b222b01d-6040-4ff3-a0e3-290819cfe94b",
"metadata": {},
"outputs": [],
"source": [
"# Class based version of document loader which can be expanded more easily for other document types. (Currently includes file types: docx, txt (windows encoding), xlsx, pdfs, epubs, pptx)\n",
"\n",
"class DocumentLoader:\n",
" \"\"\"A clean, extensible document loader for multiple file types.\"\"\"\n",
" \n",
" def __init__(self, base_path=\"D:/*\", exclude_folders=None):\n",
" self.base_path = base_path\n",
" self.documents = []\n",
" self.exclude_folders = exclude_folders or []\n",
" \n",
" # Configuration for different file types\n",
" self.loader_config = {\n",
" 'docx': {\n",
" 'loader_cls': Docx2txtLoader,\n",
" 'glob_pattern': \"**/*.docx\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'txt': {\n",
" 'loader_cls': TextLoader,\n",
" 'glob_pattern': \"**/*.txt\",\n",
" 'loader_kwargs': {\"encoding\": \"cp1252\"},\n",
" 'post_process': None\n",
" },\n",
" 'pdf': {\n",
" 'loader_cls': PyPDFLoader,\n",
" 'glob_pattern': \"**/*.pdf\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'xlsx': {\n",
" 'loader_cls': UnstructuredExcelLoader,\n",
" 'glob_pattern': \"**/*.xlsx\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'html': {\n",
" 'loader_cls': BSHTMLLoader,\n",
" 'glob_pattern': \"**/*.html\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'epub': {\n",
" 'loader_cls': EpubLoader,\n",
" 'glob_pattern': \"**/*.epub\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': self._process_epub_metadata\n",
" },\n",
" 'pptx': {\n",
" 'loader_cls': PptxLoader,\n",
" 'glob_pattern': \"**/*.pptx\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" }\n",
" }\n",
" \n",
" def _get_epub_metadata(self, file_path):\n",
" \"\"\"Extract metadata from EPUB files.\"\"\"\n",
" try:\n",
" book = epub.read_epub(file_path)\n",
" title = book.get_metadata('DC', 'title')[0][0] if book.get_metadata('DC', 'title') else None\n",
" author = book.get_metadata('DC', 'creator')[0][0] if book.get_metadata('DC', 'creator') else None\n",
" return title, author\n",
" except Exception as e:\n",
" print(f\"Error extracting EPUB metadata: {e}\")\n",
" return None, None\n",
" \n",
" def _process_epub_metadata(self, doc) -> None:\n",
" \"\"\"Post-process EPUB documents to add metadata.\"\"\"\n",
" title, author = self._get_epub_metadata(doc.metadata['source'])\n",
" doc.metadata[\"author\"] = author\n",
" doc.metadata[\"title\"] = title\n",
" \n",
" def _load_file_type(self, folder, file_type, config):\n",
" \"\"\"Load documents of a specific file type from a folder.\"\"\"\n",
" try:\n",
" loader = DirectoryLoader(\n",
" folder, \n",
" glob=config['glob_pattern'], \n",
" loader_cls=config['loader_cls'],\n",
" loader_kwargs=config['loader_kwargs']\n",
" )\n",
" docs = loader.load()\n",
" print(f\" Found {len(docs)} .{file_type} files\")\n",
" \n",
" # Apply post-processing if defined\n",
" if config['post_process']:\n",
" for doc in docs:\n",
" config['post_process'](doc)\n",
" \n",
" return docs\n",
" \n",
" except Exception as e:\n",
" print(f\" Error loading .{file_type} files: {e}\")\n",
" return []\n",
" \n",
" def load_all(self):\n",
" \"\"\"Load all documents from configured folders.\"\"\"\n",
" all_folders = [f for f in glob.glob(self.base_path) if os.path.isdir(f)]\n",
"\n",
" #filter out excluded folders\n",
" folders = []\n",
" for folder in all_folders:\n",
" folder_name = os.path.basename(folder)\n",
" if folder_name not in self.exclude_folders:\n",
" folders.append(folder)\n",
" else:\n",
" print(f\"Excluded folder: {folder_name}\")\n",
" \n",
" print(\"Scanning folders (directories only):\", folders)\n",
" \n",
" self.documents = []\n",
" \n",
" for folder in folders:\n",
" doc_type = os.path.basename(folder)\n",
" print(f\"\\nProcessing folder: {doc_type}\")\n",
" \n",
" for file_type, config in self.loader_config.items():\n",
" docs = self._load_file_type(folder, file_type, config)\n",
" \n",
" # Add doc_type metadata to all documents\n",
" for doc in docs:\n",
" doc.metadata[\"doc_type\"] = doc_type\n",
" self.documents.append(doc)\n",
" \n",
" print(f\"\\nTotal documents loaded: {len(self.documents)}\")\n",
" return self.documents\n",
" \n",
" def add_file_type(self, extension, loader_cls, glob_pattern=None, \n",
" loader_kwargs=None, post_process=None):\n",
" \"\"\"Add support for a new file type.\"\"\"\n",
" self.loader_config[extension] = {\n",
" 'loader_cls': loader_cls,\n",
" 'glob_pattern': glob_pattern or f\"**/*.{extension}\",\n",
" 'loader_kwargs': loader_kwargs or {},\n",
" 'post_process': post_process\n",
" }\n",
"\n",
"# load\n",
"loader = DocumentLoader(\"D:/*\", exclude_folders=[\"Music\", \"Online Courses\", \"Fitness\"])\n",
"documents = loader.load_all()"
]
},
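{
"cell_type": "markdown",
"id": "e1b4a2f0-9c6d-4a4e-b5f1-0b2c3d4e5f6a",
"metadata": {},
"source": [
"As an example of the extension hook, `add_file_type` can register a new extension without touching the class. The Markdown entry below is hypothetical, reusing the existing `TextLoader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2c5b3a1-8d7e-4b5f-a6a2-1c3d4e5f6a7b",
"metadata": {},
"outputs": [],
"source": [
"# Example (hypothetical): register Markdown files via the existing TextLoader\n",
"loader.add_file_type(\"md\", TextLoader, loader_kwargs={\"encoding\": \"utf-8\"})\n",
"# documents = loader.load_all()  # re-run to pick up the new file type"
]
},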
{
"cell_type": "code",
"execution_count": null,
"id": "3fd43a4f-b623-4b08-89eb-27d3b3ba0f62",
"metadata": {},
"outputs": [],
"source": [
"# create batches (this was required as the # of tokens was exceed the openai request limit)\n",
"\n",
"def estimate_tokens(text, chars_per_token=4):\n",
" \"\"\"Rough estimate of tokens from character count.\"\"\"\n",
" return len(text) // chars_per_token\n",
"\n",
"def create_batches(chunks, max_tokens_per_batch=250000):\n",
" batches = []\n",
" current_batch = []\n",
" current_tokens = 0\n",
" \n",
" for chunk in chunks:\n",
" chunk_tokens = estimate_tokens(chunk.page_content)\n",
" \n",
" # If adding this chunk would exceed the limit, start a new batch\n",
" if current_tokens + chunk_tokens > max_tokens_per_batch and current_batch:\n",
" batches.append(current_batch)\n",
" current_batch = [chunk]\n",
" current_tokens = chunk_tokens\n",
" else:\n",
" current_batch.append(chunk)\n",
" current_tokens += chunk_tokens\n",
" \n",
" # Add the last batch if it has content\n",
" if current_batch:\n",
" batches.append(current_batch)\n",
" \n",
" return batches\n",
"\n",
"def create_vectorstore_with_progress(chunks, embeddings, db_name, batch_size_tokens=250000):\n",
" \n",
" # Delete existing database if it exists\n",
" if os.path.exists(db_name):\n",
" print(f\"Deleting existing database: {db_name}\")\n",
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
" \n",
" # Create batches\n",
" batches = create_batches(chunks, batch_size_tokens)\n",
" print(f\"Created {len(batches)} batches from {len(chunks)} chunks\")\n",
" \n",
" # Show batch sizes\n",
" for i, batch in enumerate(batches):\n",
" total_chars = sum(len(chunk.page_content) for chunk in batch)\n",
" estimated_tokens = estimate_tokens(''.join(chunk.page_content for chunk in batch))\n",
" print(f\" Batch {i+1}: {len(batch)} chunks, ~{estimated_tokens:,} tokens\")\n",
" \n",
" vectorstore = None\n",
" successful_batches = 0\n",
" failed_batches = 0\n",
" \n",
" for i, batch in enumerate(batches):\n",
" print(f\"\\n{'='*50}\")\n",
" print(f\"Processing batch {i+1}/{len(batches)}\")\n",
" print(f\"{'='*50}\")\n",
" \n",
" try:\n",
" start_time = time.time()\n",
" \n",
" if vectorstore is None:\n",
" # Create the initial vectorstore\n",
" vectorstore = Chroma.from_documents(\n",
" documents=batch,\n",
" embedding=embeddings,\n",
" persist_directory=db_name\n",
" )\n",
" print(f\"Created initial vectorstore with {len(batch)} documents\")\n",
" else:\n",
" # Add to existing vectorstore\n",
" vectorstore.add_documents(batch)\n",
" print(f\"Added {len(batch)} documents to vectorstore\")\n",
" \n",
" successful_batches += 1\n",
" elapsed = time.time() - start_time\n",
" print(f\"Processed in {elapsed:.1f} seconds\")\n",
" print(f\"Total documents in vectorstore: {vectorstore._collection.count()}\")\n",
" \n",
" # Rate limiting delay\n",
" time.sleep(2)\n",
" \n",
" except Exception as e:\n",
" failed_batches += 1\n",
" print(f\"Error processing batch {i+1}: {e}\")\n",
" print(f\"Continuing with next batch...\")\n",
" continue\n",
" \n",
" print(f\"\\n{'='*50}\")\n",
" print(f\"SUMMARY\")\n",
" print(f\"{'='*50}\")\n",
" print(f\"Successful batches: {successful_batches}/{len(batches)}\")\n",
" print(f\"Failed batches: {failed_batches}/{len(batches)}\")\n",
" \n",
" if vectorstore:\n",
" final_count = vectorstore._collection.count()\n",
" print(f\"Final vectorstore contains: {final_count} documents\")\n",
" return vectorstore\n",
" else:\n",
" print(\"Failed to create vectorstore\")\n",
" return None\n",
"\n",
"# include metadata\n",
"def add_metadata_to_content(doc: Document) -> Document:\n",
" metadata_lines = []\n",
" if \"doc_type\" in doc.metadata:\n",
" metadata_lines.append(f\"Document Type: {doc.metadata['doc_type']}\")\n",
" if \"title\" in doc.metadata:\n",
" metadata_lines.append(f\"Title: {doc.metadata['title']}\")\n",
" if \"author\" in doc.metadata:\n",
" metadata_lines.append(f\"Author: {doc.metadata['author']}\")\n",
" metadata_text = \"\\n\".join(metadata_lines)\n",
"\n",
" new_content = f\"{metadata_text}\\n\\n{doc.page_content}\"\n",
" return Document(page_content=new_content, metadata=doc.metadata)\n",
"\n",
"# Apply to all documents before chunking\n",
"documents_with_metadata = [add_metadata_to_content(doc) for doc in documents]\n",
"\n",
"# Chunking\n",
"text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=200)\n",
"chunks = text_splitter.split_documents(documents_with_metadata)\n",
"\n",
"# Embedding\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# Store in vector DB\n",
"print(\"Creating vectorstore in batches...\")\n",
"vectorstore = create_vectorstore_with_progress(\n",
" chunks=chunks,\n",
" embeddings=embeddings, \n",
" db_name=db_name,\n",
" batch_size_tokens=250000\n",
")\n",
"\n",
"if vectorstore:\n",
" print(f\"Successfully created vectorstore with {vectorstore._collection.count()} documents\")\n",
"else:\n",
" print(\"Failed to create vectorstore\")"
]
},
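{
"cell_type": "markdown",
"id": "a7d0c8b6-3e2f-4c0a-b1b7-6d8e9f0a1b2c",
"metadata": {},
"source": [
"Since the 4-characters-per-token heuristic still let one batch exceed the 300k limit, exact token counts are safer. This is a sketch assuming `tiktoken` is installed (`pip install tiktoken`); `cl100k_base` is the encoding used by OpenAI's embedding models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8e1d9c7-2f3a-4d1b-a2c8-7e9f0a1b2c3d",
"metadata": {},
"outputs": [],
"source": [
"# Exact token counting with tiktoken -- a sketch; assumes tiktoken is installed\n",
"import tiktoken\n",
"\n",
"enc = tiktoken.get_encoding(\"cl100k_base\")\n",
"\n",
"def count_tokens_exact(text):\n",
"    return len(enc.encode(text))\n",
"\n",
"# Compare the heuristic against the exact count on the first chunk\n",
"sample = chunks[0].page_content\n",
"print(f\"Heuristic: ~{estimate_tokens(sample):,} tokens, exact: {count_tokens_exact(sample):,} tokens\")"
]
},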
{
"cell_type": "code",
"execution_count": null,
"id": "46c29b11-2ae3-4f6b-901d-5de67a09fd49",
"metadata": {},
"outputs": [],
"source": [
"# create a new Chat with OpenAI\n",
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# set up the conversation memory for the chat\n",
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"\n",
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 200})\n",
"\n",
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be163251-0dfa-4f50-ab05-43c6c0833405",
"metadata": {},
"outputs": [],
"source": [
"# Wrapping that in a function\n",
"\n",
"def chat(question, history):\n",
" result = conversation_chain.invoke({\"question\": question})\n",
" return result[\"answer\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6320402-8213-47ec-8b05-dda234052274",
"metadata": {},
"outputs": [],
"source": [
"# And in Gradio:\n",
"\n",
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "717e010b-8d7e-4a43-8cb1-9688ffdd76b6",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate what gets sent behind the scenes\n",
"\n",
"# from langchain_core.callbacks import StdOutCallbackHandler\n",
"\n",
"# llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"\n",
"# retriever = vectorstore.as_retriever(search_kwargs={\"k\": 200})\n",
"\n",
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory, callbacks=[StdOutCallbackHandler()])\n",
"\n",
"# query = \"Can you name some authors?\"\n",
"# result = conversation_chain.invoke({\"question\": query})\n",
"# answer = result[\"answer\"]\n",
"# print(\"\\nAnswer:\", answer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2333a77e-8d32-4cc2-8ae9-f8e7a979b3ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,472 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "dfe37963-1af6-44fc-a841-8e462443f5e6",
"metadata": {},
"source": [
"## gmail RAG assistant"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba2779af-84ef-4227-9e9e-6eaf0df87e77",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import glob\n",
"from dotenv import load_dotenv\n",
"import gradio as gr\n",
"# NEW IMPORTS FOR GMAIL\n",
"from google.auth.transport.requests import Request\n",
"from google.oauth2.credentials import Credentials\n",
"from google_auth_oauthlib.flow import InstalledAppFlow\n",
"from googleapiclient.discovery import build\n",
"from datetime import datetime\n",
"import base64\n",
"from email.mime.text import MIMEText\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "802137aa-8a74-45e0-a487-d1974927d7ca",
"metadata": {},
"outputs": [],
"source": [
"# imports for langchain, plotly and Chroma\n",
"\n",
"from langchain.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58c85082-e417-4708-9efe-81a5d55d1424",
"metadata": {},
"outputs": [],
"source": [
"# price is a factor for our company, so we're going to use a low cost model\n",
"\n",
"MODEL = \"gpt-4o-mini\"\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee78efcb-60fe-449e-a944-40bab26261af",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"# NEW: Gmail API credentials\n",
"SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']\n",
"CREDENTIALS_FILE = 'credentials.json' # Download from Google Cloud Console\n",
"TOKEN_FILE = 'token.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "730711a9-6ffe-4eee-8f48-d6cfb7314905",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Read in emails using LangChain's loaders\n",
"# IMPORTANT: set the email received date range hard-coded below\n",
"\n",
"def authenticate_gmail():\n",
" \"\"\"Authenticate and return Gmail service object\"\"\"\n",
" creds = None\n",
" if os.path.exists(TOKEN_FILE):\n",
" creds = Credentials.from_authorized_user_file(TOKEN_FILE, SCOPES)\n",
" \n",
" if not creds or not creds.valid:\n",
" if creds and creds.expired and creds.refresh_token:\n",
" creds.refresh(Request())\n",
" else:\n",
" flow = InstalledAppFlow.from_client_secrets_file(CREDENTIALS_FILE, SCOPES)\n",
" creds = flow.run_local_server(port=0)\n",
" \n",
" with open(TOKEN_FILE, 'w') as token:\n",
" token.write(creds.to_json())\n",
" \n",
" return build('gmail', 'v1', credentials=creds)\n",
"\n",
"def get_email_content(service, message_id):\n",
" \"\"\"Extract email content from message\"\"\"\n",
" try:\n",
" message = service.users().messages().get(userId='me', id=message_id, format='full').execute()\n",
" \n",
" # Extract basic info\n",
" headers = message['payload'].get('headers', [])\n",
" subject = next((h['value'] for h in headers if h['name'] == 'Subject'), 'No Subject')\n",
" sender = next((h['value'] for h in headers if h['name'] == 'From'), 'Unknown Sender')\n",
" date = next((h['value'] for h in headers if h['name'] == 'Date'), 'Unknown Date')\n",
" \n",
" # Extract body\n",
" body = \"\"\n",
" if 'parts' in message['payload']:\n",
" for part in message['payload']['parts']:\n",
" if part['mimeType'] == 'text/plain':\n",
" data = part['body']['data']\n",
" body = base64.urlsafe_b64decode(data).decode('utf-8')\n",
" break\n",
" else:\n",
" if message['payload']['body'].get('data'):\n",
" body = base64.urlsafe_b64decode(message['payload']['body']['data']).decode('utf-8')\n",
" \n",
" # Clean up body text\n",
" body = re.sub(r'\\s+', ' ', body).strip()\n",
" \n",
" return {\n",
" 'subject': subject,\n",
" 'sender': sender,\n",
" 'date': date,\n",
" 'body': body,\n",
" 'id': message_id\n",
" }\n",
" except Exception as e:\n",
" print(f\"Error processing message {message_id}: {str(e)}\")\n",
" return None\n",
"\n",
"def load_gmail_documents(start_date, end_date, max_emails=100):\n",
" \"\"\"Load emails from Gmail between specified dates\"\"\"\n",
" service = authenticate_gmail()\n",
" \n",
" # Format dates for Gmail API (YYYY/MM/DD)\n",
" start_date_str = start_date.strftime('%Y/%m/%d')\n",
" end_date_str = end_date.strftime('%Y/%m/%d')\n",
" \n",
" # Build query\n",
" query = f'after:{start_date_str} before:{end_date_str}'\n",
" \n",
" # Get message list\n",
" result = service.users().messages().list(userId='me', q=query, maxResults=max_emails).execute()\n",
" messages = result.get('messages', [])\n",
" \n",
" print(f\"Found {len(messages)} emails between {start_date_str} and {end_date_str}\")\n",
" \n",
" # Convert to LangChain documents\n",
" documents = []\n",
" for i, message in enumerate(messages):\n",
" print(f\"Processing email {i+1}/{len(messages)}\")\n",
" email_data = get_email_content(service, message['id'])\n",
" \n",
" if email_data and email_data['body']:\n",
" # Create document content\n",
" content = f\"\"\"Subject: {email_data['subject']}\n",
"From: {email_data['sender']}\n",
"Date: {email_data['date']}\n",
"\n",
"{email_data['body']}\"\"\"\n",
" \n",
" # Create LangChain document\n",
" doc = Document(\n",
" page_content=content,\n",
" metadata={\n",
" \"doc_type\": \"email\",\n",
" \"subject\": email_data['subject'],\n",
" \"sender\": email_data['sender'],\n",
" \"date\": email_data['date'],\n",
" \"message_id\": email_data['id']\n",
" }\n",
" )\n",
" documents.append(doc)\n",
" \n",
" return documents\n",
"\n",
"# SET YOUR DATE RANGE HERE\n",
"start_date = datetime(2025, 6, 20) # YYYY, MM, DD\n",
"end_date = datetime(2025, 6, 26) # YYYY, MM, DD\n",
"\n",
"# Load Gmail documents \n",
"documents = load_gmail_documents(start_date, end_date, max_emails=200)\n"
]
},
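{
"cell_type": "markdown",
"id": "c9f2e0d8-1a4b-4e2c-b3d9-8f0a1b2c3d4e",
"metadata": {},
"source": [
"The body extraction above only checks top-level parts, but Gmail frequently nests `text/plain` inside `multipart/alternative`. A recursive walk over the payload tree covers those cases -- a sketch you could swap into `get_email_content` if messages come back with empty bodies."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0a3f1e9-0b5c-4f3d-a4e0-9a1b2c3d4e5f",
"metadata": {},
"outputs": [],
"source": [
"# Recursively search nested MIME parts for a text/plain body (sketch)\n",
"def extract_plain_text(payload):\n",
"    if payload.get('mimeType') == 'text/plain' and payload.get('body', {}).get('data'):\n",
"        return base64.urlsafe_b64decode(payload['body']['data']).decode('utf-8')\n",
"    for part in payload.get('parts', []):\n",
"        text = extract_plain_text(part)\n",
"        if text:\n",
"            return text\n",
"    return ''"
]
},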
{
"cell_type": "code",
"execution_count": null,
"id": "c59de72d-f965-44b3-8487-283e4c623b1d",
"metadata": {},
"outputs": [],
"source": [
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
"chunks = text_splitter.split_documents(documents)\n",
"\n",
"print(f\"Total number of chunks: {len(chunks)}\")\n",
"print(f\"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78998399-ac17-4e28-b15f-0b5f51e6ee23",
"metadata": {},
"outputs": [],
"source": [
"# Put the chunks of data into a Vector Store that associates a Vector Embedding with each chunk\n",
"# Chroma is a popular open source Vector Database based on SQLLite\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# If you would rather use the free Vector Embeddings from HuggingFace sentence-transformers\n",
"# Then replace embeddings = OpenAIEmbeddings()\n",
"# with:\n",
"# from langchain.embeddings import HuggingFaceEmbeddings\n",
"# embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"\n",
"# Delete if already exists\n",
"\n",
"if os.path.exists(db_name):\n",
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
"\n",
"# Create vectorstore\n",
"\n",
"vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n",
"print(f\"Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff2e7687-60d4-4920-a1d7-a34b9f70a250",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate the vectors\n",
"\n",
"collection = vectorstore._collection\n",
"count = collection.count()\n",
"\n",
"sample_embedding = collection.get(limit=1, include=[\"embeddings\"])[\"embeddings\"][0]\n",
"dimensions = len(sample_embedding)\n",
"print(f\"There are {count:,} vectors with {dimensions:,} dimensions in the vector store\")"
]
},
{
"cell_type": "markdown",
"id": "b0d45462-a818-441c-b010-b85b32bcf618",
"metadata": {},
"source": [
"## Visualizing the Vector Store\n",
"\n",
"Let's take a minute to look at the documents and their embedding vectors to see what's going on."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b98adf5e-d464-4bd2-9bdf-bc5b6770263b",
"metadata": {},
"outputs": [],
"source": [
"# Prework (with thanks to Jon R for identifying and fixing a bug in this!)\n",
"\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"metadatas = result['metadatas']\n",
"\n",
"# Alternatively, color by sender:\n",
"senders = [metadata.get('sender', 'unknown') for metadata in metadatas]\n",
"unique_senders = list(set(senders))\n",
"sender_colors = ['blue', 'green', 'red', 'orange', 'purple', 'brown', 'pink', 'gray']\n",
"colors = [sender_colors[unique_senders.index(sender) % len(sender_colors)] for sender in senders]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "427149d5-e5d8-4abd-bb6f-7ef0333cca21",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visalize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
" x=reduced_vectors[:, 0],\n",
" y=reduced_vectors[:, 1],\n",
" mode='markers',\n",
" marker=dict(size=5, color=colors, opacity=0.8),\n",
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
" hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
" title='2D Chroma Vector Store Visualization',\n",
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
" width=800,\n",
" height=600,\n",
" margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1418e88-acd5-460a-bf2b-4e6efc88e3dd",
"metadata": {},
"outputs": [],
"source": [
"# Let's try 3D!\n",
"\n",
"tsne = TSNE(n_components=3, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 3D scatter plot\n",
"fig = go.Figure(data=[go.Scatter3d(\n",
" x=reduced_vectors[:, 0],\n",
" y=reduced_vectors[:, 1],\n",
" z=reduced_vectors[:, 2],\n",
" mode='markers',\n",
" marker=dict(size=5, color=colors, opacity=0.8),\n",
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
" hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
" title='3D Chroma Vector Store Visualization',\n",
" scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
" width=900,\n",
" height=700,\n",
" margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "bbbcb659-13ce-47ab-8a5e-01b930494964",
"metadata": {},
"source": [
"## Langchain and Gradio to prototype a chat with the LLM\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d72567e8-f891-4797-944b-4612dc6613b1",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains import create_retrieval_chain\n",
"\n",
"# create a new Chat with OpenAI\n",
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# Alternative - if you'd like to use Ollama locally, uncomment this line instead\n",
"# llm = ChatOpenAI(temperature=0.7, model_name='llama3.2', base_url='http://localhost:11434/v1', api_key='ollama')\n",
"\n",
"# change LLM standard prompt (standard prompt defaults the answer to be 'I don't know' too often, especially when using a small LLM\n",
"\n",
"qa_prompt=PromptTemplate.from_template(\"Use the following pieces of context to answer the user's question. Answer as best you can given the information you have;\\\n",
" if you have a reasonable idea of the answer,/then explain it and mention that you're unsure. \\\n",
" But if you don't know the answer, don't make it up. \\\n",
" {context} \\\n",
" Question: {question} \\\n",
" Helpful Answer:\"\n",
" )\n",
"\n",
"\n",
"# Wrap into a StuffDocumentsChain, matching the variable name 'context'\n",
"combine_docs_chain = create_stuff_documents_chain(\n",
" llm=llm,\n",
" prompt=qa_prompt,\n",
" document_variable_name=\"context\"\n",
")\n",
"\n",
"# set up the conversation memory for the chat\n",
"#memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"memory = ConversationBufferMemory(\n",
" memory_key='chat_history', \n",
" return_messages=True,\n",
" output_key='answer' \n",
")\n",
"\n",
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 10})\n",
"\n",
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)\n",
"\n",
"conversation_chain = ConversationalRetrievalChain.from_llm(\n",
" llm=llm,\n",
" retriever=retriever,\n",
" memory=memory,\n",
" combine_docs_chain_kwargs={\"prompt\": qa_prompt},\n",
" return_source_documents=True\n",
")\n",
"\n",
"def chat(question, history):\n",
" result = conversation_chain.invoke({\"question\": question})\n",
" return result[\"answer\"]\n",
"\n",
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe4229aa-6afe-4592-93a4-71a47ab69846",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}