473 lines
17 KiB
Plaintext
473 lines
17 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "dfe37963-1af6-44fc-a841-8e462443f5e6",
|
|
"metadata": {},
|
|
"source": [
|
|
"## gmail RAG assistant"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ba2779af-84ef-4227-9e9e-6eaf0df87e77",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import glob\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"import gradio as gr\n",
|
|
"# NEW IMPORTS FOR GMAIL\n",
|
|
"from google.auth.transport.requests import Request\n",
|
|
"from google.oauth2.credentials import Credentials\n",
|
|
"from google_auth_oauthlib.flow import InstalledAppFlow\n",
|
|
"from googleapiclient.discovery import build\n",
|
|
"from datetime import datetime\n",
|
|
"import base64\n",
|
|
"from email.mime.text import MIMEText\n",
|
|
"import re"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "802137aa-8a74-45e0-a487-d1974927d7ca",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports for langchain, plotly and Chroma\n",
|
|
"\n",
|
|
"from langchain.document_loaders import DirectoryLoader, TextLoader\n",
|
|
"from langchain.text_splitter import CharacterTextSplitter\n",
|
|
"from langchain.schema import Document\n",
|
|
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
|
|
"from langchain_chroma import Chroma\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from sklearn.manifold import TSNE\n",
|
|
"import numpy as np\n",
|
|
"import plotly.graph_objects as go\n",
|
|
"from langchain.memory import ConversationBufferMemory\n",
|
|
"from langchain.chains import ConversationalRetrievalChain\n",
|
|
"from langchain.embeddings import HuggingFaceEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "58c85082-e417-4708-9efe-81a5d55d1424",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# price is a factor for our company, so we're going to use a low cost model\n",
|
|
"\n",
|
|
"MODEL = \"gpt-4o-mini\"\n",
|
|
"db_name = \"vector_db\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ee78efcb-60fe-449e-a944-40bab26261af",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load environment variables in a file called .env\n",
|
|
"\n",
|
|
"load_dotenv(override=True)\n",
|
|
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
|
|
"# NEW: Gmail API credentials\n",
|
|
"SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']\n",
|
|
"CREDENTIALS_FILE = 'credentials.json' # Download from Google Cloud Console\n",
|
|
"TOKEN_FILE = 'token.json'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "730711a9-6ffe-4eee-8f48-d6cfb7314905",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Read in emails using LangChain's loaders\n",
|
|
"# IMPORTANT: set the email received date range hard-coded below\n",
|
|
"\n",
|
|
"def authenticate_gmail():\n",
|
|
" \"\"\"Authenticate and return Gmail service object\"\"\"\n",
|
|
" creds = None\n",
|
|
" if os.path.exists(TOKEN_FILE):\n",
|
|
" creds = Credentials.from_authorized_user_file(TOKEN_FILE, SCOPES)\n",
|
|
" \n",
|
|
" if not creds or not creds.valid:\n",
|
|
" if creds and creds.expired and creds.refresh_token:\n",
|
|
" creds.refresh(Request())\n",
|
|
" else:\n",
|
|
" flow = InstalledAppFlow.from_client_secrets_file(CREDENTIALS_FILE, SCOPES)\n",
|
|
" creds = flow.run_local_server(port=0)\n",
|
|
" \n",
|
|
" with open(TOKEN_FILE, 'w') as token:\n",
|
|
" token.write(creds.to_json())\n",
|
|
" \n",
|
|
" return build('gmail', 'v1', credentials=creds)\n",
|
|
"\n",
|
|
"def get_email_content(service, message_id):\n",
|
|
" \"\"\"Extract email content from message\"\"\"\n",
|
|
" try:\n",
|
|
" message = service.users().messages().get(userId='me', id=message_id, format='full').execute()\n",
|
|
" \n",
|
|
" # Extract basic info\n",
|
|
" headers = message['payload'].get('headers', [])\n",
|
|
" subject = next((h['value'] for h in headers if h['name'] == 'Subject'), 'No Subject')\n",
|
|
" sender = next((h['value'] for h in headers if h['name'] == 'From'), 'Unknown Sender')\n",
|
|
" date = next((h['value'] for h in headers if h['name'] == 'Date'), 'Unknown Date')\n",
|
|
" \n",
|
|
" # Extract body\n",
|
|
" body = \"\"\n",
|
|
" if 'parts' in message['payload']:\n",
|
|
" for part in message['payload']['parts']:\n",
|
|
" if part['mimeType'] == 'text/plain':\n",
|
|
" data = part['body']['data']\n",
|
|
" body = base64.urlsafe_b64decode(data).decode('utf-8')\n",
|
|
" break\n",
|
|
" else:\n",
|
|
" if message['payload']['body'].get('data'):\n",
|
|
" body = base64.urlsafe_b64decode(message['payload']['body']['data']).decode('utf-8')\n",
|
|
" \n",
|
|
" # Clean up body text\n",
|
|
" body = re.sub(r'\\s+', ' ', body).strip()\n",
|
|
" \n",
|
|
" return {\n",
|
|
" 'subject': subject,\n",
|
|
" 'sender': sender,\n",
|
|
" 'date': date,\n",
|
|
" 'body': body,\n",
|
|
" 'id': message_id\n",
|
|
" }\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error processing message {message_id}: {str(e)}\")\n",
|
|
" return None\n",
|
|
"\n",
|
|
"def load_gmail_documents(start_date, end_date, max_emails=100):\n",
|
|
" \"\"\"Load emails from Gmail between specified dates\"\"\"\n",
|
|
" service = authenticate_gmail()\n",
|
|
" \n",
|
|
" # Format dates for Gmail API (YYYY/MM/DD)\n",
|
|
" start_date_str = start_date.strftime('%Y/%m/%d')\n",
|
|
" end_date_str = end_date.strftime('%Y/%m/%d')\n",
|
|
" \n",
|
|
" # Build query\n",
|
|
" query = f'after:{start_date_str} before:{end_date_str}'\n",
|
|
" \n",
|
|
" # Get message list\n",
|
|
" result = service.users().messages().list(userId='me', q=query, maxResults=max_emails).execute()\n",
|
|
" messages = result.get('messages', [])\n",
|
|
" \n",
|
|
" print(f\"Found {len(messages)} emails between {start_date_str} and {end_date_str}\")\n",
|
|
" \n",
|
|
" # Convert to LangChain documents\n",
|
|
" documents = []\n",
|
|
" for i, message in enumerate(messages):\n",
|
|
" print(f\"Processing email {i+1}/{len(messages)}\")\n",
|
|
" email_data = get_email_content(service, message['id'])\n",
|
|
" \n",
|
|
" if email_data and email_data['body']:\n",
|
|
" # Create document content\n",
|
|
" content = f\"\"\"Subject: {email_data['subject']}\n",
|
|
"From: {email_data['sender']}\n",
|
|
"Date: {email_data['date']}\n",
|
|
"\n",
|
|
"{email_data['body']}\"\"\"\n",
|
|
" \n",
|
|
" # Create LangChain document\n",
|
|
" doc = Document(\n",
|
|
" page_content=content,\n",
|
|
" metadata={\n",
|
|
" \"doc_type\": \"email\",\n",
|
|
" \"subject\": email_data['subject'],\n",
|
|
" \"sender\": email_data['sender'],\n",
|
|
" \"date\": email_data['date'],\n",
|
|
" \"message_id\": email_data['id']\n",
|
|
" }\n",
|
|
" )\n",
|
|
" documents.append(doc)\n",
|
|
" \n",
|
|
" return documents\n",
|
|
"\n",
|
|
"# SET YOUR DATE RANGE HERE\n",
|
|
"start_date = datetime(2025, 6, 20) # YYYY, MM, DD\n",
|
|
"end_date = datetime(2025, 6, 26) # YYYY, MM, DD\n",
|
|
"\n",
|
|
"# Load Gmail documents \n",
|
|
"documents = load_gmail_documents(start_date, end_date, max_emails=200)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c59de72d-f965-44b3-8487-283e4c623b1d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
|
|
"chunks = text_splitter.split_documents(documents)\n",
|
|
"\n",
|
|
"print(f\"Total number of chunks: {len(chunks)}\")\n",
|
|
"print(f\"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "78998399-ac17-4e28-b15f-0b5f51e6ee23",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Put the chunks of data into a Vector Store that associates a Vector Embedding with each chunk\n",
|
|
"# Chroma is a popular open source Vector Database based on SQLLite\n",
|
|
"\n",
|
|
"embeddings = OpenAIEmbeddings()\n",
|
|
"\n",
|
|
"# If you would rather use the free Vector Embeddings from HuggingFace sentence-transformers\n",
|
|
"# Then replace embeddings = OpenAIEmbeddings()\n",
|
|
"# with:\n",
|
|
"# from langchain.embeddings import HuggingFaceEmbeddings\n",
|
|
"# embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n",
|
|
"\n",
|
|
"# Delete if already exists\n",
|
|
"\n",
|
|
"if os.path.exists(db_name):\n",
|
|
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
|
|
"\n",
|
|
"# Create vectorstore\n",
|
|
"\n",
|
|
"vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n",
|
|
"print(f\"Vectorstore created with {vectorstore._collection.count()} documents\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ff2e7687-60d4-4920-a1d7-a34b9f70a250",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Let's investigate the vectors\n",
|
|
"\n",
|
|
"collection = vectorstore._collection\n",
|
|
"count = collection.count()\n",
|
|
"\n",
|
|
"sample_embedding = collection.get(limit=1, include=[\"embeddings\"])[\"embeddings\"][0]\n",
|
|
"dimensions = len(sample_embedding)\n",
|
|
"print(f\"There are {count:,} vectors with {dimensions:,} dimensions in the vector store\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b0d45462-a818-441c-b010-b85b32bcf618",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Visualizing the Vector Store\n",
|
|
"\n",
|
|
"Let's take a minute to look at the documents and their embedding vectors to see what's going on."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b98adf5e-d464-4bd2-9bdf-bc5b6770263b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Prework (with thanks to Jon R for identifying and fixing a bug in this!)\n",
|
|
"\n",
|
|
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
|
|
"vectors = np.array(result['embeddings'])\n",
|
|
"documents = result['documents']\n",
|
|
"metadatas = result['metadatas']\n",
|
|
"\n",
|
|
"# Alternatively, color by sender:\n",
|
|
"senders = [metadata.get('sender', 'unknown') for metadata in metadatas]\n",
|
|
"unique_senders = list(set(senders))\n",
|
|
"sender_colors = ['blue', 'green', 'red', 'orange', 'purple', 'brown', 'pink', 'gray']\n",
|
|
"colors = [sender_colors[unique_senders.index(sender) % len(sender_colors)] for sender in senders]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "427149d5-e5d8-4abd-bb6f-7ef0333cca21",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# We humans find it easier to visalize things in 2D!\n",
|
|
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
|
|
"# (t-distributed stochastic neighbor embedding)\n",
|
|
"\n",
|
|
"tsne = TSNE(n_components=2, random_state=42)\n",
|
|
"reduced_vectors = tsne.fit_transform(vectors)\n",
|
|
"\n",
|
|
"# Create the 2D scatter plot\n",
|
|
"fig = go.Figure(data=[go.Scatter(\n",
|
|
" x=reduced_vectors[:, 0],\n",
|
|
" y=reduced_vectors[:, 1],\n",
|
|
" mode='markers',\n",
|
|
" marker=dict(size=5, color=colors, opacity=0.8),\n",
|
|
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
|
|
" hoverinfo='text'\n",
|
|
")])\n",
|
|
"\n",
|
|
"fig.update_layout(\n",
|
|
" title='2D Chroma Vector Store Visualization',\n",
|
|
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
|
|
" width=800,\n",
|
|
" height=600,\n",
|
|
" margin=dict(r=20, b=10, l=10, t=40)\n",
|
|
")\n",
|
|
"\n",
|
|
"fig.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e1418e88-acd5-460a-bf2b-4e6efc88e3dd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Let's try 3D!\n",
|
|
"\n",
|
|
"tsne = TSNE(n_components=3, random_state=42)\n",
|
|
"reduced_vectors = tsne.fit_transform(vectors)\n",
|
|
"\n",
|
|
"# Create the 3D scatter plot\n",
|
|
"fig = go.Figure(data=[go.Scatter3d(\n",
|
|
" x=reduced_vectors[:, 0],\n",
|
|
" y=reduced_vectors[:, 1],\n",
|
|
" z=reduced_vectors[:, 2],\n",
|
|
" mode='markers',\n",
|
|
" marker=dict(size=5, color=colors, opacity=0.8),\n",
|
|
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
|
|
" hoverinfo='text'\n",
|
|
")])\n",
|
|
"\n",
|
|
"fig.update_layout(\n",
|
|
" title='3D Chroma Vector Store Visualization',\n",
|
|
" scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
|
|
" width=900,\n",
|
|
" height=700,\n",
|
|
" margin=dict(r=20, b=10, l=10, t=40)\n",
|
|
")\n",
|
|
"\n",
|
|
"fig.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bbbcb659-13ce-47ab-8a5e-01b930494964",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Langchain and Gradio to prototype a chat with the LLM\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d72567e8-f891-4797-944b-4612dc6613b1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"from langchain.prompts import PromptTemplate\n",
|
|
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
|
|
"from langchain.chains import create_retrieval_chain\n",
|
|
"\n",
|
|
"# create a new Chat with OpenAI\n",
|
|
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
|
|
"\n",
|
|
"# Alternative - if you'd like to use Ollama locally, uncomment this line instead\n",
|
|
"# llm = ChatOpenAI(temperature=0.7, model_name='llama3.2', base_url='http://localhost:11434/v1', api_key='ollama')\n",
|
|
"\n",
|
|
"# change LLM standard prompt (standard prompt defaults the answer to be 'I don't know' too often, especially when using a small LLM\n",
|
|
"\n",
|
|
"qa_prompt=PromptTemplate.from_template(\"Use the following pieces of context to answer the user's question. Answer as best you can given the information you have;\\\n",
|
|
" if you have a reasonable idea of the answer,/then explain it and mention that you're unsure. \\\n",
|
|
" But if you don't know the answer, don't make it up. \\\n",
|
|
" {context} \\\n",
|
|
" Question: {question} \\\n",
|
|
" Helpful Answer:\"\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"# Wrap into a StuffDocumentsChain, matching the variable name 'context'\n",
|
|
"combine_docs_chain = create_stuff_documents_chain(\n",
|
|
" llm=llm,\n",
|
|
" prompt=qa_prompt,\n",
|
|
" document_variable_name=\"context\"\n",
|
|
")\n",
|
|
"\n",
|
|
"# set up the conversation memory for the chat\n",
|
|
"#memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
|
|
"memory = ConversationBufferMemory(\n",
|
|
" memory_key='chat_history', \n",
|
|
" return_messages=True,\n",
|
|
" output_key='answer' \n",
|
|
")\n",
|
|
"\n",
|
|
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
|
|
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 10})\n",
|
|
"\n",
|
|
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
|
|
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)\n",
|
|
"\n",
|
|
"conversation_chain = ConversationalRetrievalChain.from_llm(\n",
|
|
" llm=llm,\n",
|
|
" retriever=retriever,\n",
|
|
" memory=memory,\n",
|
|
" combine_docs_chain_kwargs={\"prompt\": qa_prompt},\n",
|
|
" return_source_documents=True\n",
|
|
")\n",
|
|
"\n",
|
|
"def chat(question, history):\n",
|
|
" result = conversation_chain.invoke({\"question\": question})\n",
|
|
" return result[\"answer\"]\n",
|
|
"\n",
|
|
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fe4229aa-6afe-4592-93a4-71a47ab69846",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|