Merge branch 'main' of github.com:ed-donner/llm_engineering
@@ -0,0 +1,463 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2080947c-96d9-447f-8368-cfdc9e5c9960",
"metadata": {},
"source": [
"# Using Semantic Chunks with the Gemini API and Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53221f1a-a0c1-4506-a3d0-d6626c58e4e0",
"metadata": {},
"outputs": [],
"source": [
"# Regular Imports\n",
"import os\n",
"import glob\n",
"import time\n",
"from dotenv import load_dotenv\n",
"from tqdm.notebook import tqdm\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a2a7171-a7b6-42a6-96d7-c93f360689ec",
"metadata": {},
"outputs": [],
"source": [
"# Visualization Imports\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c9d658-65e5-40a1-8680-d0b561f87649",
"metadata": {},
"outputs": [],
"source": [
"# LangChain Imports\n",
"\n",
"from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI\n",
"from langchain_community.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
"from langchain_core.messages import HumanMessage, AIMessage\n",
"from langchain_chroma import Chroma\n",
"from langchain_experimental.text_splitter import SemanticChunker\n",
"from langchain_core.chat_history import InMemoryChatMessageHistory\n",
"from langchain_core.runnables.history import RunnableWithMessageHistory\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains.history_aware_retriever import create_history_aware_retriever\n",
"from langchain.chains import create_retrieval_chain\n",
"from langchain_core.prompts import MessagesPlaceholder\n",
"from langchain_core.runnables import RunnableLambda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e7ed82b-b28a-4094-9f77-3b6432dd0f7a",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"CHAT_MODEL = \"gemini-2.5-flash\"\n",
"EMBEDDING_MODEL = \"models/text-embedding-004\"\n",
"# EMBEDDING_MODEL_EXP = \"models/gemini-embedding-exp-03-07\"\n",
"\n",
"folders = glob.glob(\"knowledge-base/*\")\n",
"text_loader_kwargs = {'encoding': 'utf-8'}\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b83281a2-bcae-41ab-a347-0e7f9688d1ed",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"\n",
"api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if not api_key:\n",
"    print(\"API Key not found!\")\n",
"else:\n",
"    print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fd6d516-772b-478d-9b28-09d42f2277d7",
"metadata": {},
"outputs": [],
"source": [
"def add_metadata(doc, doc_type):\n",
"    doc.metadata[\"doc_type\"] = doc_type\n",
"    return doc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bc4198b-f989-42c0-95b5-3596448fcaa2",
"metadata": {},
"outputs": [],
"source": [
"documents = []\n",
"for folder in tqdm(folders, desc=\"Loading folders\"):\n",
"    doc_type = os.path.basename(folder)\n",
"    loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n",
"    folder_docs = loader.load()\n",
"    documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n",
"\n",
"print(f\"Total documents loaded: {len(documents)}\")"
]
},
{
"cell_type": "markdown",
"id": "bb74241f-e9d5-42e8-9a4b-f31018397d66",
"metadata": {},
"source": [
"## Create Semantic Chunks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a3aa17f-f5d0-430a-80da-95c284bd99a8",
"metadata": {},
"outputs": [],
"source": [
"chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
"\n",
"text_splitter = SemanticChunker(\n",
"    chunking_embedding_model,\n",
"    breakpoint_threshold_type=\"percentile\",\n",
"    breakpoint_threshold_amount=95.0,\n",
"    min_chunk_size=3\n",
")\n",
"\n",
"start = time.time()\n",
"\n",
"semantic_chunks = []\n",
"pbar = tqdm(documents, desc=\"Semantic chunking documents\")\n",
"\n",
"for i, doc in enumerate(pbar):\n",
"    doc_type = doc.metadata.get('doc_type', 'Unknown')\n",
"    pbar.set_postfix_str(f\"Processing: {doc_type}\")\n",
"    try:\n",
"        doc_chunks = text_splitter.split_documents([doc])\n",
"        semantic_chunks.extend(doc_chunks)\n",
"    except Exception as e:\n",
"        tqdm.write(f\"❌ Failed to split doc ({doc.metadata.get('source', 'unknown source')}): {e}\")\n",
"print(f\"⏱️ Took {time.time() - start:.2f} seconds\")\n",
"print(f\"Total semantic chunks: {len(semantic_chunks)}\")\n",
"\n",
"# One-shot alternative (no per-document progress):\n",
"# start = time.time()\n",
"\n",
"# try:\n",
"#     semantic_chunks = text_splitter.split_documents(documents)\n",
"#     print(f\"✅ Chunking completed with {len(semantic_chunks)} chunks\")\n",
"# except Exception as e:\n",
"#     print(f\"❌ Failed to split documents: {e}\")\n",
"\n",
"# print(f\"⏱️ Took {time.time() - start:.2f} seconds\")"
]
},
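{
"cell_type": "markdown",
"id": "added-semantic-breakpoint-note",
"metadata": {},
"source": [
"To make the `breakpoint_threshold_type=\"percentile\"` setting above concrete, here is a minimal sketch of the idea behind `SemanticChunker`: embed consecutive sentences, measure the cosine distance between neighbours, and split wherever the distance exceeds the 95th percentile. This is an illustration of the technique, not LangChain's actual implementation, and `fake_embed` is a hypothetical stand-in for a real embedding model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-semantic-breakpoint-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of percentile-based semantic breakpoints (illustration only,\n",
"# not LangChain's implementation). fake_embed is a hypothetical stand-in for a\n",
"# real embedding model such as chunking_embedding_model above.\n",
"import numpy as np\n",
"\n",
"def fake_embed(sentences):\n",
"    rng = np.random.default_rng(42)\n",
"    return rng.normal(size=(len(sentences), 8))  # pretend 8-dim embeddings\n",
"\n",
"def semantic_breakpoints(sentences, percentile=95.0):\n",
"    emb = fake_embed(sentences)\n",
"    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)\n",
"    # cosine distance between each sentence and the next\n",
"    distances = 1 - np.sum(emb[:-1] * emb[1:], axis=1)\n",
"    threshold = np.percentile(distances, percentile)\n",
"    # split after any sentence whose distance to the next exceeds the threshold\n",
"    return [i + 1 for i, d in enumerate(distances) if d > threshold]\n",
"\n",
"sentences = [\"One topic sentence.\", \"More on the same topic.\", \"A new topic starts here.\", \"And it continues here.\"]\n",
"print(semantic_breakpoints(sentences))"
]
},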
{
"cell_type": "code",
"execution_count": null,
"id": "675b98d6-5ed0-45d1-8f79-765911e6badf",
"metadata": {},
"outputs": [],
"source": [
"# Preview a few of the chunks\n",
"for i, doc in enumerate(semantic_chunks[:15]):\n",
"    print(f\"--- Chunk {i+1} ---\")\n",
"    print(doc.page_content)\n",
"    print(\"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "c17accff-539a-490b-8a5f-b5ce632a3c71",
"metadata": {},
"source": [
"## Embed with Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bd228bd-37d2-4aaf-b0f6-d94943f6f248",
"metadata": {},
"outputs": [],
"source": [
"embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
"\n",
"if os.path.exists(db_name):\n",
"    Chroma(persist_directory=db_name, embedding_function=embedding).delete_collection()\n",
"\n",
"vectorstore = Chroma.from_documents(\n",
"    documents=semantic_chunks,\n",
"    embedding=embedding,\n",
"    persist_directory=db_name\n",
")\n",
"\n",
"print(f\"✅ Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
{
"cell_type": "markdown",
"id": "ce0a3e23-5912-4de2-bf34-3c0936375de1",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Visualizing Vectors"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ffdc6f5-ec25-4229-94d4-1fc6bb4d2702",
"metadata": {},
"outputs": [],
"source": [
"collection = vectorstore._collection\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"metadatas = result['metadatas']\n",
"doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
"colors = [['blue', 'green', 'red', 'orange'][['products', 'employees', 'contracts', 'company'].index(t)] for t in doc_types]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visualize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
"    x=reduced_vectors[:, 0],\n",
"    y=reduced_vectors[:, 1],\n",
"    mode='markers',\n",
"    marker=dict(size=5, color=colors, opacity=0.8),\n",
"    text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
"    hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
"    title='2D Chroma Vector Store Visualization',\n",
"    scene=dict(xaxis_title='x', yaxis_title='y'),\n",
"    width=800,\n",
"    height=600,\n",
"    margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "359b8651-a382-4050-8bf8-123e5cdf4d53",
"metadata": {},
"source": [
"## RAG Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08a75313-6c68-42e5-bd37-78254123094c",
"metadata": {},
"outputs": [],
"source": [
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 20})\n",
"\n",
"# Conversation Memory\n",
"# memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"\n",
"chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)\n",
"\n",
"question_generator_template = \"\"\"Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.\n",
"If the follow-up question is already a standalone question, return it as is.\n",
"\n",
"Chat History:\n",
"{chat_history}\n",
"Follow Up Input: {input}\n",
"Standalone question:\"\"\"\n",
"\n",
"question_generator_prompt = ChatPromptTemplate.from_messages([\n",
"    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
"    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"history_aware_retriever = create_history_aware_retriever(\n",
"    chat_llm, retriever, question_generator_prompt\n",
")\n",
"\n",
"qa_system_prompt = \"\"\"You are Insurellm’s intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.\n",
"Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers. Rephrase source facts in a natural tone, not word-for-word.\n",
"When referencing people or company history, prioritize clarity and correctness.\n",
"Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.\n",
"If you truly don’t have the answer, respond with:\n",
"\"I don't have that information.\"\n",
"Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.\n",
"\n",
"{context}\"\"\"\n",
"\n",
"\n",
"qa_human_prompt = \"{input}\"\n",
"\n",
"qa_prompt = ChatPromptTemplate.from_messages([\n",
"    SystemMessagePromptTemplate.from_template(qa_system_prompt),\n",
"    MessagesPlaceholder(variable_name=\"chat_history\"),\n",
"    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"combine_docs_chain = create_stuff_documents_chain(chat_llm, qa_prompt)\n",
"\n",
"# inspect_context = RunnableLambda(lambda inputs: (\n",
"#     print(\"\\n Retrieved Context:\\n\", \"\\n---\\n\".join([doc.page_content for doc in inputs[\"context\"]])),\n",
"#     inputs  # pass it through unchanged\n",
"# )[1])\n",
"\n",
"# inspect_inputs = RunnableLambda(lambda inputs: (\n",
"#     print(\"\\n Inputs received by the chain:\\n\", inputs),\n",
"#     inputs\n",
"# )[1])\n",
"\n",
"base_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)\n",
"\n",
"# Using RunnableLambda because Gradio needs the response to contain only the output (answer); base_chain returns a dict with input, context, chat_history and answer\n",
"\n",
"# base_chain_with_output = base_chain | inspect_context | RunnableLambda(lambda res: res[\"answer\"])\n",
"# base_chain_with_output = base_chain | RunnableLambda(lambda res: res[\"answer\"])\n",
"\n",
"\n",
"# Session-Persistent Chat History\n",
"# To persist history between sessions, store it in MongoDB (or any other NoSQL DB) and use MongoDBChatMessageHistory (the relevant DB wrapper)\n",
"\n",
"chat_histories = {}\n",
"\n",
"def get_history(session_id):\n",
"    if session_id not in chat_histories:\n",
"        chat_histories[session_id] = InMemoryChatMessageHistory()\n",
"    return chat_histories[session_id]\n",
"\n",
"# Currently set to streaming. If a one-shot response is needed, comment out base_chain and output_messages_key and enable base_chain_with_output\n",
"conversation_chain = RunnableWithMessageHistory(\n",
"    # base_chain_with_output,\n",
"    base_chain,\n",
"    get_history,\n",
"    output_messages_key=\"answer\",\n",
"    input_messages_key=\"input\",\n",
"    history_messages_key=\"chat_history\",\n",
")"
]
},
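{
"cell_type": "markdown",
"id": "added-persistent-history-note",
"metadata": {},
"source": [
"The cell above keeps history in an in-memory dict. A minimal sketch of the session-persistent swap its comments describe, assuming the `langchain-mongodb` package and a reachable MongoDB instance (both assumptions; the connection string is a placeholder):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-persistent-history-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A sketch of session-persistent history, assuming `pip install langchain-mongodb`\n",
"# and a running MongoDB instance; the connection string below is a placeholder.\n",
"# Kept commented out so the notebook runs without MongoDB.\n",
"\n",
"# from langchain_mongodb import MongoDBChatMessageHistory\n",
"\n",
"# def get_history(session_id):\n",
"#     return MongoDBChatMessageHistory(\n",
"#         connection_string=\"mongodb://localhost:27017\",  # placeholder URI\n",
"#         session_id=session_id,\n",
"#     )"
]
},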
{
"cell_type": "code",
"execution_count": null,
"id": "06b58566-70cb-42eb-8b1c-9fe353fe71f0",
"metadata": {},
"outputs": [],
"source": [
"def chat(question, history):\n",
"    try:\n",
"        # result = conversation_chain.invoke({\"input\": question, \"chat_history\": memory.buffer_as_messages})\n",
"\n",
"        # memory.chat_memory.add_user_message(question)\n",
"        # memory.chat_memory.add_ai_message(result[\"answer\"])\n",
"\n",
"        # return result[\"answer\"]\n",
"\n",
"        session_id = \"default-session\"\n",
"\n",
"        # # Full chat version\n",
"        # result = conversation_chain.invoke(\n",
"        #     {\"input\": question},\n",
"        #     config={\"configurable\": {\"session_id\": session_id}}\n",
"        # )\n",
"        # # print(result)\n",
"        # return result\n",
"\n",
"        # Streaming Version\n",
"        response_buffer = \"\"\n",
"\n",
"        for chunk in conversation_chain.stream({\"input\": question}, config={\"configurable\": {\"session_id\": session_id}}):\n",
"            if \"answer\" in chunk:\n",
"                response_buffer += chunk[\"answer\"]\n",
"                yield response_buffer\n",
"    except Exception as e:\n",
"        print(f\"An error occurred during chat: {e}\")\n",
"        # yield (not return) so the message reaches Gradio from this generator\n",
"        yield \"I apologize, but I encountered an error and cannot answer that right now.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a577ac66-3952-4821-83d2-8a50bad89971",
"metadata": {},
"outputs": [],
"source": [
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56b63a17-2522-46e5-b5a3-e2e80e52a723",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
552
week5/community-contributions/Wk5-final-multi-doc-type-KB.ipynb
Normal file
@@ -0,0 +1,552 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "61777022-631c-4db0-afeb-70d8d22bc07b",
"metadata": {},
"source": [
"Summary:\n",
"This is the project from week 5. The intention was to create a vector DB of my own files (from an external drive) which can be used in a RAG solution.\n",
"It handles a number of file types (docx, pdf, txt, epub, ...) and includes the ability to exclude folders.\n",
"With the OpenAI embeddings API limit of 300k tokens per request, it was also necessary to create a batch embeddings process so that the work was spread over multiple requests.\n",
"The batching was based on estimating tokens at a rough rate of four characters per token; this wasn't perfect, and one of the batches still exceeded the 300k limit when running.\n",
"I found that the responses from the LLM were terrible in the end! I tried playing about with chunk sizes and the minimum number of chunks in LangChain, and it did improve but was not fantastic. I also ensured the metadata was sent with each chunk to help.\n",
"This really highlighted the real-world challenges of implementing RAG!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d78ef79d-e564-4c56-82f3-0485e4bf6986",
"metadata": {},
"outputs": [],
"source": [
"!pip install docx2txt\n",
"!pip install ebooklib\n",
"!pip install python-pptx\n",
"!pip install pypdf"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ec98119-456f-450c-a9a2-f375d74f5ce5",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"import glob\n",
"import gradio as gr\n",
"import time\n",
"from typing import List"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac14410b-8c3c-4cf5-900e-fd4c33cdf2b2",
"metadata": {},
"outputs": [],
"source": [
"# imports for langchain, plotly and Chroma\n",
"\n",
"from langchain.document_loaders import (\n",
"    DirectoryLoader,\n",
"    Docx2txtLoader,\n",
"    TextLoader,\n",
"    PyPDFLoader,\n",
"    UnstructuredExcelLoader,\n",
"    BSHTMLLoader\n",
")\n",
"from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3be698e7-71e1-4c75-9696-e1651e4bf357",
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"gpt-4o-mini\"\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f850068-c05b-4526-9494-034b0077347e",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables from a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c5baad2-2033-40a6-8ebd-5861b5cf4350",
"metadata": {},
"outputs": [],
"source": [
"# handling epubs\n",
"\n",
"from ebooklib import epub\n",
"from bs4 import BeautifulSoup\n",
"from langchain.document_loaders.base import BaseLoader\n",
"\n",
"class EpubLoader(BaseLoader):\n",
"    def __init__(self, file_path: str):\n",
"        self.file_path = file_path\n",
"\n",
"    def load(self) -> list[Document]:\n",
"        book = epub.read_epub(self.file_path)\n",
"        text = ''\n",
"        for item in book.get_items():\n",
"            # get_type() returns an int constant, so check the item's class instead\n",
"            if isinstance(item, epub.EpubHtml):\n",
"                soup = BeautifulSoup(item.get_content(), 'html.parser')\n",
"                text += soup.get_text() + '\\n'\n",
"\n",
"        return [Document(page_content=text, metadata={\"source\": self.file_path})]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd8b0e4e-d698-4484-bc94-d8b753f386cc",
"metadata": {},
"outputs": [],
"source": [
"# handling pptx\n",
"\n",
"from pptx import Presentation\n",
"\n",
"class PptxLoader(BaseLoader):\n",
"    def __init__(self, file_path: str):\n",
"        self.file_path = file_path\n",
"\n",
"    def load(self) -> list[Document]:\n",
"        prs = Presentation(self.file_path)\n",
"        text = ''\n",
"        for slide in prs.slides:\n",
"            for shape in slide.shapes:\n",
"                if hasattr(shape, \"text\") and shape.text:\n",
"                    text += shape.text + '\\n'\n",
"\n",
"        return [Document(page_content=text, metadata={\"source\": self.file_path})]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b222b01d-6040-4ff3-a0e3-290819cfe94b",
"metadata": {},
"outputs": [],
"source": [
"# Class-based version of the document loader, which can be expanded easily for other document types. (Currently supports: docx, txt (Windows encoding), xlsx, pdf, html, epub, pptx)\n",
"\n",
"class DocumentLoader:\n",
"    \"\"\"A clean, extensible document loader for multiple file types.\"\"\"\n",
"\n",
"    def __init__(self, base_path=\"D:/*\", exclude_folders=None):\n",
"        self.base_path = base_path\n",
"        self.documents = []\n",
"        self.exclude_folders = exclude_folders or []\n",
"\n",
"        # Configuration for different file types\n",
"        self.loader_config = {\n",
"            'docx': {\n",
"                'loader_cls': Docx2txtLoader,\n",
"                'glob_pattern': \"**/*.docx\",\n",
"                'loader_kwargs': {},\n",
"                'post_process': None\n",
"            },\n",
"            'txt': {\n",
"                'loader_cls': TextLoader,\n",
"                'glob_pattern': \"**/*.txt\",\n",
"                'loader_kwargs': {\"encoding\": \"cp1252\"},\n",
"                'post_process': None\n",
"            },\n",
"            'pdf': {\n",
"                'loader_cls': PyPDFLoader,\n",
"                'glob_pattern': \"**/*.pdf\",\n",
"                'loader_kwargs': {},\n",
"                'post_process': None\n",
"            },\n",
"            'xlsx': {\n",
"                'loader_cls': UnstructuredExcelLoader,\n",
"                'glob_pattern': \"**/*.xlsx\",\n",
"                'loader_kwargs': {},\n",
"                'post_process': None\n",
"            },\n",
"            'html': {\n",
"                'loader_cls': BSHTMLLoader,\n",
"                'glob_pattern': \"**/*.html\",\n",
"                'loader_kwargs': {},\n",
"                'post_process': None\n",
"            },\n",
"            'epub': {\n",
"                'loader_cls': EpubLoader,\n",
"                'glob_pattern': \"**/*.epub\",\n",
"                'loader_kwargs': {},\n",
"                'post_process': self._process_epub_metadata\n",
"            },\n",
"            'pptx': {\n",
"                'loader_cls': PptxLoader,\n",
"                'glob_pattern': \"**/*.pptx\",\n",
"                'loader_kwargs': {},\n",
"                'post_process': None\n",
"            }\n",
"        }\n",
"\n",
"    def _get_epub_metadata(self, file_path):\n",
"        \"\"\"Extract metadata from EPUB files.\"\"\"\n",
"        try:\n",
"            book = epub.read_epub(file_path)\n",
"            title = book.get_metadata('DC', 'title')[0][0] if book.get_metadata('DC', 'title') else None\n",
"            author = book.get_metadata('DC', 'creator')[0][0] if book.get_metadata('DC', 'creator') else None\n",
"            return title, author\n",
"        except Exception as e:\n",
"            print(f\"Error extracting EPUB metadata: {e}\")\n",
"            return None, None\n",
"\n",
"    def _process_epub_metadata(self, doc) -> None:\n",
"        \"\"\"Post-process EPUB documents to add metadata.\"\"\"\n",
"        title, author = self._get_epub_metadata(doc.metadata['source'])\n",
"        doc.metadata[\"author\"] = author\n",
"        doc.metadata[\"title\"] = title\n",
"\n",
"    def _load_file_type(self, folder, file_type, config):\n",
"        \"\"\"Load documents of a specific file type from a folder.\"\"\"\n",
"        try:\n",
"            loader = DirectoryLoader(\n",
"                folder,\n",
"                glob=config['glob_pattern'],\n",
"                loader_cls=config['loader_cls'],\n",
"                loader_kwargs=config['loader_kwargs']\n",
"            )\n",
"            docs = loader.load()\n",
"            print(f\"  Found {len(docs)} .{file_type} files\")\n",
"\n",
"            # Apply post-processing if defined\n",
"            if config['post_process']:\n",
"                for doc in docs:\n",
"                    config['post_process'](doc)\n",
"\n",
"            return docs\n",
"\n",
"        except Exception as e:\n",
"            print(f\"  Error loading .{file_type} files: {e}\")\n",
"            return []\n",
"\n",
"    def load_all(self):\n",
"        \"\"\"Load all documents from configured folders.\"\"\"\n",
"        all_folders = [f for f in glob.glob(self.base_path) if os.path.isdir(f)]\n",
"\n",
"        # filter out excluded folders\n",
"        folders = []\n",
"        for folder in all_folders:\n",
"            folder_name = os.path.basename(folder)\n",
"            if folder_name not in self.exclude_folders:\n",
"                folders.append(folder)\n",
"            else:\n",
"                print(f\"Excluded folder: {folder_name}\")\n",
"\n",
"        print(\"Scanning folders (directories only):\", folders)\n",
"\n",
"        self.documents = []\n",
"\n",
"        for folder in folders:\n",
"            doc_type = os.path.basename(folder)\n",
"            print(f\"\\nProcessing folder: {doc_type}\")\n",
"\n",
"            for file_type, config in self.loader_config.items():\n",
"                docs = self._load_file_type(folder, file_type, config)\n",
"\n",
"                # Add doc_type metadata to all documents\n",
"                for doc in docs:\n",
"                    doc.metadata[\"doc_type\"] = doc_type\n",
"                    self.documents.append(doc)\n",
"\n",
"        print(f\"\\nTotal documents loaded: {len(self.documents)}\")\n",
"        return self.documents\n",
"\n",
"    def add_file_type(self, extension, loader_cls, glob_pattern=None,\n",
"                      loader_kwargs=None, post_process=None):\n",
"        \"\"\"Add support for a new file type.\"\"\"\n",
"        self.loader_config[extension] = {\n",
"            'loader_cls': loader_cls,\n",
"            'glob_pattern': glob_pattern or f\"**/*.{extension}\",\n",
"            'loader_kwargs': loader_kwargs or {},\n",
"            'post_process': post_process\n",
"        }\n",
"\n",
"# load\n",
"loader = DocumentLoader(\"D:/*\", exclude_folders=[\"Music\", \"Online Courses\", \"Fitness\"])\n",
"documents = loader.load_all()"
]
},
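{
"cell_type": "markdown",
"id": "added-add-file-type-note",
"metadata": {},
"source": [
"A quick example of the `add_file_type` extension hook: registering markdown files with `TextLoader`. This is illustrative only; `.md` is not one of the types configured above, and the encoding kwarg is an assumption."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-add-file-type-example",
"metadata": {},
"outputs": [],
"source": [
"# Example of extending the loader: register markdown files (illustrative only;\n",
"# .md is not one of the types configured above).\n",
"md_loader = DocumentLoader(\"D:/*\", exclude_folders=[\"Music\", \"Online Courses\", \"Fitness\"])\n",
"md_loader.add_file_type(\"md\", TextLoader, loader_kwargs={\"encoding\": \"utf-8\"})\n",
"# md_documents = md_loader.load_all()  # would now pick up **/*.md as well"
]
},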
{
"cell_type": "code",
"execution_count": null,
"id": "3fd43a4f-b623-4b08-89eb-27d3b3ba0f62",
"metadata": {},
"outputs": [],
"source": [
"# create batches (this was required as the number of tokens was exceeding the OpenAI request limit)\n",
"\n",
"def estimate_tokens(text, chars_per_token=4):\n",
"    \"\"\"Rough estimate of tokens from character count.\"\"\"\n",
"    return len(text) // chars_per_token\n",
"\n",
"def create_batches(chunks, max_tokens_per_batch=250000):\n",
"    batches = []\n",
"    current_batch = []\n",
"    current_tokens = 0\n",
"\n",
"    for chunk in chunks:\n",
"        chunk_tokens = estimate_tokens(chunk.page_content)\n",
"\n",
"        # If adding this chunk would exceed the limit, start a new batch\n",
"        if current_tokens + chunk_tokens > max_tokens_per_batch and current_batch:\n",
"            batches.append(current_batch)\n",
"            current_batch = [chunk]\n",
"            current_tokens = chunk_tokens\n",
"        else:\n",
"            current_batch.append(chunk)\n",
"            current_tokens += chunk_tokens\n",
"\n",
"    # Add the last batch if it has content\n",
"    if current_batch:\n",
"        batches.append(current_batch)\n",
"\n",
"    return batches\n",
"\n",
"def create_vectorstore_with_progress(chunks, embeddings, db_name, batch_size_tokens=250000):\n",
"\n",
"    # Delete existing database if it exists\n",
"    if os.path.exists(db_name):\n",
"        print(f\"Deleting existing database: {db_name}\")\n",
"        Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
"\n",
"    # Create batches\n",
"    batches = create_batches(chunks, batch_size_tokens)\n",
"    print(f\"Created {len(batches)} batches from {len(chunks)} chunks\")\n",
"\n",
"    # Show batch sizes\n",
"    for i, batch in enumerate(batches):\n",
"        total_chars = sum(len(chunk.page_content) for chunk in batch)\n",
"        estimated_tokens = estimate_tokens(''.join(chunk.page_content for chunk in batch))\n",
"        print(f\"  Batch {i+1}: {len(batch)} chunks, ~{estimated_tokens:,} tokens\")\n",
"\n",
"    vectorstore = None\n",
"    successful_batches = 0\n",
"    failed_batches = 0\n",
"\n",
"    for i, batch in enumerate(batches):\n",
"        print(f\"\\n{'='*50}\")\n",
"        print(f\"Processing batch {i+1}/{len(batches)}\")\n",
"        print(f\"{'='*50}\")\n",
"\n",
"        try:\n",
"            start_time = time.time()\n",
"\n",
"            if vectorstore is None:\n",
"                # Create the initial vectorstore\n",
"                vectorstore = Chroma.from_documents(\n",
"                    documents=batch,\n",
"                    embedding=embeddings,\n",
"                    persist_directory=db_name\n",
"                )\n",
"                print(f\"Created initial vectorstore with {len(batch)} documents\")\n",
"            else:\n",
"                # Add to existing vectorstore\n",
"                vectorstore.add_documents(batch)\n",
"                print(f\"Added {len(batch)} documents to vectorstore\")\n",
"\n",
"            successful_batches += 1\n",
"            elapsed = time.time() - start_time\n",
"            print(f\"Processed in {elapsed:.1f} seconds\")\n",
"            print(f\"Total documents in vectorstore: {vectorstore._collection.count()}\")\n",
"\n",
"            # Rate limiting delay\n",
"            time.sleep(2)\n",
"\n",
"        except Exception as e:\n",
"            failed_batches += 1\n",
"            print(f\"Error processing batch {i+1}: {e}\")\n",
"            print(\"Continuing with next batch...\")\n",
"            continue\n",
"\n",
"    print(f\"\\n{'='*50}\")\n",
"    print(\"SUMMARY\")\n",
"    print(f\"{'='*50}\")\n",
"    print(f\"Successful batches: {successful_batches}/{len(batches)}\")\n",
"    print(f\"Failed batches: {failed_batches}/{len(batches)}\")\n",
"\n",
"    if vectorstore:\n",
"        final_count = vectorstore._collection.count()\n",
"        print(f\"Final vectorstore contains: {final_count} documents\")\n",
"        return vectorstore\n",
"    else:\n",
"        print(\"Failed to create vectorstore\")\n",
"        return None\n",
"\n",
"# include metadata\n",
"def add_metadata_to_content(doc: Document) -> Document:\n",
"    metadata_lines = []\n",
"    if \"doc_type\" in doc.metadata:\n",
"        metadata_lines.append(f\"Document Type: {doc.metadata['doc_type']}\")\n",
"    if \"title\" in doc.metadata:\n",
"        metadata_lines.append(f\"Title: {doc.metadata['title']}\")\n",
"    if \"author\" in doc.metadata:\n",
"        metadata_lines.append(f\"Author: {doc.metadata['author']}\")\n",
"    metadata_text = \"\\n\".join(metadata_lines)\n",
"\n",
"    new_content = f\"{metadata_text}\\n\\n{doc.page_content}\"\n",
"    return Document(page_content=new_content, metadata=doc.metadata)\n",
"\n",
"# Apply to all documents before chunking\n",
"documents_with_metadata = [add_metadata_to_content(doc) for doc in documents]\n",
"\n",
"# Chunking\n",
"text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=200)\n",
"chunks = text_splitter.split_documents(documents_with_metadata)\n",
"\n",
"# Embedding\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# Store in vector DB\n",
"print(\"Creating vectorstore in batches...\")\n",
"vectorstore = create_vectorstore_with_progress(\n",
"    chunks=chunks,\n",
"    embeddings=embeddings,\n",
"    db_name=db_name,\n",
"    batch_size_tokens=250000\n",
")\n",
"\n",
"if vectorstore:\n",
"    print(f\"Successfully created vectorstore with {vectorstore._collection.count()} documents\")\n",
"else:\n",
"    print(\"Failed to create vectorstore\")"
]
},
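{
"cell_type": "markdown",
"id": "added-tiktoken-note",
"metadata": {},
"source": [
"The four-characters-per-token estimate above overshot once. One way to tighten it, assuming `pip install tiktoken` (not installed by this notebook), is to count tokens exactly with the `cl100k_base` encoding used by OpenAI's embedding models:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-tiktoken-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Exact token counting with tiktoken (assumes `pip install tiktoken`).\n",
"# Swapping this in for estimate_tokens() removes the guesswork that let one\n",
"# batch exceed the 300k-token request limit.\n",
"import tiktoken\n",
"\n",
"encoder = tiktoken.get_encoding(\"cl100k_base\")  # encoding used by OpenAI embedding models\n",
"\n",
"def count_tokens(text):\n",
"    return len(encoder.encode(text))\n",
"\n",
"# e.g. create_batches() could call count_tokens(chunk.page_content) instead of estimate_tokens(...)"
]
},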
{
"cell_type": "code",
"execution_count": null,
"id": "46c29b11-2ae3-4f6b-901d-5de67a09fd49",
"metadata": {},
"outputs": [],
"source": [
"# create a new Chat with OpenAI\n",
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# set up the conversation memory for the chat\n",
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"\n",
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 200})\n",
"\n",
"# putting it together: set up the conversation chain with the LLM, the vector store and memory\n",
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be163251-0dfa-4f50-ab05-43c6c0833405",
"metadata": {},
"outputs": [],
"source": [
"# Wrapping that in a function\n",
"\n",
"def chat(question, history):\n",
"    result = conversation_chain.invoke({\"question\": question})\n",
"    return result[\"answer\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6320402-8213-47ec-8b05-dda234052274",
"metadata": {},
"outputs": [],
"source": [
"# And in Gradio:\n",
"\n",
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "717e010b-8d7e-4a43-8cb1-9688ffdd76b6",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate what gets sent behind the scenes\n",
"\n",
"# from langchain_core.callbacks import StdOutCallbackHandler\n",
"\n",
"# llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"\n",
"# retriever = vectorstore.as_retriever(search_kwargs={\"k\": 200})\n",
"\n",
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory, callbacks=[StdOutCallbackHandler()])\n",
"\n",
"# query = \"Can you name some authors?\"\n",
"# result = conversation_chain.invoke({\"question\": query})\n",
"# answer = result[\"answer\"]\n",
"# print(\"\\nAnswer:\", answer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2333a77e-8d32-4cc2-8ae9-f8e7a979b3ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
472
week5/community-contributions/day5_gmailRAG.ipynb
Normal file
@@ -0,0 +1,472 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "dfe37963-1af6-44fc-a841-8e462443f5e6",
"metadata": {},
"source": [
"## Gmail RAG Assistant"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba2779af-84ef-4227-9e9e-6eaf0df87e77",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import glob\n",
"from dotenv import load_dotenv\n",
"import gradio as gr\n",
"# NEW IMPORTS FOR GMAIL\n",
"from google.auth.transport.requests import Request\n",
"from google.oauth2.credentials import Credentials\n",
"from google_auth_oauthlib.flow import InstalledAppFlow\n",
"from googleapiclient.discovery import build\n",
"from datetime import datetime\n",
"import base64\n",
"from email.mime.text import MIMEText\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "802137aa-8a74-45e0-a487-d1974927d7ca",
"metadata": {},
"outputs": [],
"source": [
"# imports for langchain, plotly and Chroma\n",
"\n",
"from langchain.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58c85082-e417-4708-9efe-81a5d55d1424",
"metadata": {},
"outputs": [],
"source": [
"# price is a factor for our company, so we're going to use a low-cost model\n",
"\n",
"MODEL = \"gpt-4o-mini\"\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee78efcb-60fe-449e-a944-40bab26261af",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables from a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"# NEW: Gmail API credentials\n",
"SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']\n",
"CREDENTIALS_FILE = 'credentials.json'  # Download from Google Cloud Console\n",
"TOKEN_FILE = 'token.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "730711a9-6ffe-4eee-8f48-d6cfb7314905",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Read in emails via the Gmail API and wrap them as LangChain Documents\n",
"# IMPORTANT: set the email date range, hard-coded below\n",
"\n",
"def authenticate_gmail():\n",
"    \"\"\"Authenticate and return Gmail service object\"\"\"\n",
"    creds = None\n",
"    if os.path.exists(TOKEN_FILE):\n",
"        creds = Credentials.from_authorized_user_file(TOKEN_FILE, SCOPES)\n",
"\n",
"    if not creds or not creds.valid:\n",
"        if creds and creds.expired and creds.refresh_token:\n",
"            creds.refresh(Request())\n",
"        else:\n",
"            flow = InstalledAppFlow.from_client_secrets_file(CREDENTIALS_FILE, SCOPES)\n",
"            creds = flow.run_local_server(port=0)\n",
"\n",
"        with open(TOKEN_FILE, 'w') as token:\n",
"            token.write(creds.to_json())\n",
"\n",
"    return build('gmail', 'v1', credentials=creds)\n",
"\n",
"def get_email_content(service, message_id):\n",
"    \"\"\"Extract email content from message\"\"\"\n",
"    try:\n",
"        message = service.users().messages().get(userId='me', id=message_id, format='full').execute()\n",
"\n",
"        # Extract basic info\n",
"        headers = message['payload'].get('headers', [])\n",
"        subject = next((h['value'] for h in headers if h['name'] == 'Subject'), 'No Subject')\n",
"        sender = next((h['value'] for h in headers if h['name'] == 'From'), 'Unknown Sender')\n",
"        date = next((h['value'] for h in headers if h['name'] == 'Date'), 'Unknown Date')\n",
"\n",
"        # Extract body\n",
"        body = \"\"\n",
"        if 'parts' in message['payload']:\n",
"            for part in message['payload']['parts']:\n",
"                if part['mimeType'] == 'text/plain':\n",
"                    data = part['body']['data']\n",
"                    body = base64.urlsafe_b64decode(data).decode('utf-8')\n",
"                    break\n",
"        else:\n",
"            if message['payload']['body'].get('data'):\n",
"                body = base64.urlsafe_b64decode(message['payload']['body']['data']).decode('utf-8')\n",
"\n",
"        # Clean up body text\n",
"        body = re.sub(r'\\s+', ' ', body).strip()\n",
"\n",
"        return {\n",
"            'subject': subject,\n",
"            'sender': sender,\n",
"            'date': date,\n",
"            'body': body,\n",
"            'id': message_id\n",
"        }\n",
"    except Exception as e:\n",
"        print(f\"Error processing message {message_id}: {str(e)}\")\n",
"        return None\n",
"\n",
"def load_gmail_documents(start_date, end_date, max_emails=100):\n",
"    \"\"\"Load emails from Gmail between specified dates\"\"\"\n",
"    service = authenticate_gmail()\n",
"\n",
"    # Format dates for Gmail API (YYYY/MM/DD)\n",
"    start_date_str = start_date.strftime('%Y/%m/%d')\n",
"    end_date_str = end_date.strftime('%Y/%m/%d')\n",
"\n",
"    # Build query\n",
"    query = f'after:{start_date_str} before:{end_date_str}'\n",
"\n",
"    # Get message list\n",
"    result = service.users().messages().list(userId='me', q=query, maxResults=max_emails).execute()\n",
"    messages = result.get('messages', [])\n",
"\n",
"    print(f\"Found {len(messages)} emails between {start_date_str} and {end_date_str}\")\n",
"\n",
"    # Convert to LangChain documents\n",
"    documents = []\n",
"    for i, message in enumerate(messages):\n",
"        print(f\"Processing email {i+1}/{len(messages)}\")\n",
"        email_data = get_email_content(service, message['id'])\n",
"\n",
"        if email_data and email_data['body']:\n",
"            # Create document content\n",
"            content = f\"\"\"Subject: {email_data['subject']}\n",
"From: {email_data['sender']}\n",
"Date: {email_data['date']}\n",
"\n",
"{email_data['body']}\"\"\"\n",
"\n",
"            # Create LangChain document\n",
"            doc = Document(\n",
"                page_content=content,\n",
"                metadata={\n",
"                    \"doc_type\": \"email\",\n",
"                    \"subject\": email_data['subject'],\n",
"                    \"sender\": email_data['sender'],\n",
"                    \"date\": email_data['date'],\n",
"                    \"message_id\": email_data['id']\n",
"                }\n",
"            )\n",
"            documents.append(doc)\n",
"\n",
"    return documents\n",
"\n",
"# SET YOUR DATE RANGE HERE\n",
"start_date = datetime(2025, 6, 20)  # YYYY, MM, DD\n",
"end_date = datetime(2025, 6, 26)  # YYYY, MM, DD\n",
"\n",
"# Load Gmail documents\n",
"documents = load_gmail_documents(start_date, end_date, max_emails=200)\n"
]
},
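{
"cell_type": "markdown",
"id": "added-gmail-query-note",
"metadata": {},
"source": [
"The query above filters only by date. Gmail's other search operators (`from:`, `label:`, `-label:`, ...) compose in the same `q` string, so narrowing the load is a one-line change. A sketch; the sender and label below are placeholders:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-gmail-query-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Narrowing the Gmail query with more search operators (the sender and label\n",
"# below are placeholders, not values used elsewhere in this notebook).\n",
"query = (\n",
"    f\"after:{start_date.strftime('%Y/%m/%d')} \"\n",
"    f\"before:{end_date.strftime('%Y/%m/%d')} \"\n",
"    \"from:newsletter@example.com -label:spam\"\n",
")\n",
"# then pass it the same way:\n",
"# service.users().messages().list(userId='me', q=query, maxResults=200).execute()"
]
},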
{
"cell_type": "code",
"execution_count": null,
"id": "c59de72d-f965-44b3-8487-283e4c623b1d",
"metadata": {},
"outputs": [],
"source": [
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
"chunks = text_splitter.split_documents(documents)\n",
"\n",
"print(f\"Total number of chunks: {len(chunks)}\")\n",
"print(f\"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78998399-ac17-4e28-b15f-0b5f51e6ee23",
"metadata": {},
"outputs": [],
"source": [
"# Put the chunks of data into a Vector Store that associates a Vector Embedding with each chunk\n",
"# Chroma is a popular open source Vector Database based on SQLite\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# If you would rather use the free Vector Embeddings from HuggingFace sentence-transformers\n",
"# Then replace embeddings = OpenAIEmbeddings()\n",
"# with:\n",
"# from langchain.embeddings import HuggingFaceEmbeddings\n",
"# embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"\n",
"# Delete if already exists\n",
"\n",
"if os.path.exists(db_name):\n",
"    Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
"\n",
"# Create vectorstore\n",
"\n",
"vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n",
"print(f\"Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff2e7687-60d4-4920-a1d7-a34b9f70a250",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate the vectors\n",
"\n",
"collection = vectorstore._collection\n",
"count = collection.count()\n",
"\n",
"sample_embedding = collection.get(limit=1, include=[\"embeddings\"])[\"embeddings\"][0]\n",
"dimensions = len(sample_embedding)\n",
"print(f\"There are {count:,} vectors with {dimensions:,} dimensions in the vector store\")"
]
},
{
"cell_type": "markdown",
"id": "b0d45462-a818-441c-b010-b85b32bcf618",
"metadata": {},
"source": [
"## Visualizing the Vector Store\n",
"\n",
"Let's take a minute to look at the documents and their embedding vectors to see what's going on."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b98adf5e-d464-4bd2-9bdf-bc5b6770263b",
"metadata": {},
"outputs": [],
"source": [
"# Prework (with thanks to Jon R for identifying and fixing a bug in this!)\n",
"\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"metadatas = result['metadatas']\n",
"\n",
"# Color by sender (every doc_type here is 'email', so sender is more informative):\n",
"senders = [metadata.get('sender', 'unknown') for metadata in metadatas]\n",
"unique_senders = list(set(senders))\n",
"sender_colors = ['blue', 'green', 'red', 'orange', 'purple', 'brown', 'pink', 'gray']\n",
"colors = [sender_colors[unique_senders.index(sender) % len(sender_colors)] for sender in senders]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "427149d5-e5d8-4abd-bb6f-7ef0333cca21",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visualize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
"    x=reduced_vectors[:, 0],\n",
"    y=reduced_vectors[:, 1],\n",
"    mode='markers',\n",
"    marker=dict(size=5, color=colors, opacity=0.8),\n",
"    text=[f\"Sender: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
"    hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
"    title='2D Chroma Vector Store Visualization',\n",
"    scene=dict(xaxis_title='x', yaxis_title='y'),\n",
"    width=800,\n",
"    height=600,\n",
"    margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1418e88-acd5-460a-bf2b-4e6efc88e3dd",
"metadata": {},
"outputs": [],
"source": [
"# Let's try 3D!\n",
"\n",
"tsne = TSNE(n_components=3, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 3D scatter plot\n",
"fig = go.Figure(data=[go.Scatter3d(\n",
"    x=reduced_vectors[:, 0],\n",
"    y=reduced_vectors[:, 1],\n",
"    z=reduced_vectors[:, 2],\n",
"    mode='markers',\n",
"    marker=dict(size=5, color=colors, opacity=0.8),\n",
"    text=[f\"Sender: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
"    hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
"    title='3D Chroma Vector Store Visualization',\n",
"    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
"    width=900,\n",
"    height=700,\n",
"    margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "bbbcb659-13ce-47ab-8a5e-01b930494964",
"metadata": {},
"source": [
"## LangChain and Gradio to prototype a chat with the LLM\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d72567e8-f891-4797-944b-4612dc6613b1",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains import create_retrieval_chain\n",
"\n",
"# create a new Chat with OpenAI\n",
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# Alternative - if you'd like to use Ollama locally, uncomment this line instead\n",
"# llm = ChatOpenAI(temperature=0.7, model_name='llama3.2', base_url='http://localhost:11434/v1', api_key='ollama')\n",
"\n",
"# change the LLM's standard prompt (the standard prompt defaults the answer to 'I don't know' too often, especially when using a small LLM)\n",
"\n",
"qa_prompt = PromptTemplate.from_template(\"Use the following pieces of context to answer the user's question. Answer as best you can given the information you have; \\\n",
"    if you have a reasonable idea of the answer, then explain it and mention that you're unsure. \\\n",
"    But if you don't know the answer, don't make it up. \\\n",
"    {context} \\\n",
"    Question: {question} \\\n",
"    Helpful Answer:\"\n",
")\n",
"\n",
"\n",
"# Wrap into a StuffDocumentsChain, matching the variable name 'context'\n",
"combine_docs_chain = create_stuff_documents_chain(\n",
"    llm=llm,\n",
"    prompt=qa_prompt,\n",
"    document_variable_name=\"context\"\n",
")\n",
"\n",
"# set up the conversation memory for the chat\n",
"# memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"memory = ConversationBufferMemory(\n",
"    memory_key='chat_history',\n",
"    return_messages=True,\n",
"    output_key='answer'\n",
")\n",
"\n",
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 10})\n",
"\n",
"# putting it together: set up the conversation chain with the LLM, the vector store and memory\n",
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)\n",
"\n",
"conversation_chain = ConversationalRetrievalChain.from_llm(\n",
"    llm=llm,\n",
"    retriever=retriever,\n",
"    memory=memory,\n",
"    combine_docs_chain_kwargs={\"prompt\": qa_prompt},\n",
"    return_source_documents=True\n",
")\n",
"\n",
"def chat(question, history):\n",
"    result = conversation_chain.invoke({\"question\": question})\n",
"    return result[\"answer\"]\n",
"\n",
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe4229aa-6afe-4592-93a4-71a47ab69846",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}