week5/community-contributions/w5d5_worker.py (new file)
@@ -0,0 +1,445 @@
#!/usr/bin/env python3
"""
Knowledge Worker with Document Upload

This script creates a knowledge worker that:
1. Allows users to upload documents through a Gradio UI
2. Uses a Chroma vector database for efficient document retrieval
3. Implements RAG (Retrieval Augmented Generation) for accurate responses

The system updates its context dynamically when new documents are uploaded.
Google Drive integration has been removed from this version.
"""

import os
import glob
import tempfile
from pathlib import Path
from dotenv import load_dotenv
import gradio as gr

# LangChain imports
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma

# Visualization imports
import numpy as np
from sklearn.manifold import TSNE
import plotly.graph_objects as go

# Removed Google Drive API imports

# Additional document loaders
try:
    from langchain_community.document_loaders import Docx2txtLoader, UnstructuredExcelLoader
except ImportError:
    print("Warning: some document loaders are not available. PDF and text files will still work.")
    Docx2txtLoader = None
    UnstructuredExcelLoader = None

# Configuration
MODEL = "gpt-4o-mini"  # Using a cost-effective model
DB_NAME = "knowledge_worker_db"
UPLOAD_FOLDER = "uploaded_documents"

# Create the upload folder if it doesn't exist
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Load environment variables
load_dotenv(override=True)
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')

# Removed Google Drive credentials configuration

# Use a simple text splitter approach
class SimpleTextSplitter:
    def __init__(self, chunk_size=1000, chunk_overlap=200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_documents(self, documents):
        chunks = []
        for doc in documents:
            text = doc.page_content
            start = 0
            while start < len(text):
                end = start + self.chunk_size
                chunk_text = text[start:end]
                chunk_doc = Document(page_content=chunk_text, metadata=doc.metadata.copy())
                chunks.append(chunk_doc)
                start = end - self.chunk_overlap
        return chunks

CharacterTextSplitter = SimpleTextSplitter
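
# Illustrative example (not executed): with the defaults chunk_size=1000 and
# chunk_overlap=200, the window advances 800 characters per chunk, so a
# 2,600-character document yields four chunks:
#
#   splitter = SimpleTextSplitter()
#   chunks = splitter.split_documents([Document(page_content="x" * 2600)])
#   # chunk spans: [0:1000], [800:1800], [1600:2600], [2400:2600]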

# Try different import paths for memory and chains
try:
    from langchain.memory import ConversationBufferMemory
    from langchain.chains import ConversationalRetrievalChain
except ImportError:
    try:
        from langchain_core.memory import ConversationBufferMemory
        from langchain_core.chains import ConversationalRetrievalChain
    except ImportError:
        try:
            from langchain_community.memory import ConversationBufferMemory
            from langchain_community.chains import ConversationalRetrievalChain
        except ImportError:
            print("Warning: Memory and chains modules not found. Creating simple alternatives.")
            # Create simple alternatives
            class ConversationBufferMemory:
                def __init__(self, memory_key='chat_history', return_messages=True):
                    self.memory_key = memory_key
                    self.return_messages = return_messages
                    self.chat_memory = []

                def save_context(self, inputs, outputs):
                    self.chat_memory.append((inputs, outputs))

                def load_memory_variables(self, inputs):
                    return {self.memory_key: self.chat_memory}

            class ConversationalRetrievalChain:
                def __init__(self, llm, retriever, memory):
                    self.llm = llm
                    self.retriever = retriever
                    self.memory = memory

                def invoke(self, inputs):
                    question = inputs.get("question", "")
                    # Simple implementation - just return a basic response
                    return {"answer": f"I received your question: {question}. This is a simplified response."}

# Removed Google Drive Integration Functions

# Document Processing Functions
def get_loader_for_file(file_path):
    """
    Get the appropriate document loader based on file extension
    """
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension == '.pdf':
        return PyPDFLoader(file_path)
    elif file_extension in ['.docx', '.doc'] and Docx2txtLoader:
        return Docx2txtLoader(file_path)
    elif file_extension in ['.xlsx', '.xls'] and UnstructuredExcelLoader:
        return UnstructuredExcelLoader(file_path)
    elif file_extension in ['.txt', '.md']:
        return TextLoader(file_path, encoding='utf-8')
    else:
        # Default to the text loader for unknown types
        try:
            return TextLoader(file_path, encoding='utf-8')
        except Exception:
            return None

def load_document(file_path):
    """
    Load a document using the appropriate loader
    """
    loader = get_loader_for_file(file_path)
    if loader:
        try:
            return loader.load()
        except Exception as e:
            print(f"Error loading document {file_path}: {e}")
    return []

def process_documents(documents):
    """
    Split documents into chunks for embedding
    """
    text_splitter = CharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_documents(documents)
    return chunks

# Knowledge Base Class
class KnowledgeBase:
    def __init__(self, db_name=DB_NAME):
        self.db_name = db_name
        self.embeddings = OpenAIEmbeddings()
        self.vectorstore = None
        self.initialize_vectorstore()

    def initialize_vectorstore(self):
        """
        Initialize the vector store, loading from disk if it exists
        """
        if os.path.exists(self.db_name):
            self.vectorstore = Chroma(persist_directory=self.db_name, embedding_function=self.embeddings)
            print(f"Loaded existing vector store with {self.vectorstore._collection.count()} documents")
        else:
            # Create an empty vector store
            self.vectorstore = Chroma(persist_directory=self.db_name, embedding_function=self.embeddings)
            print("Created new vector store")

    def add_documents(self, documents):
        """
        Process and add documents to the vector store
        """
        if not documents:
            return False

        chunks = process_documents(documents)
        if not chunks:
            return False

        # Add to the existing vector store
        self.vectorstore.add_documents(chunks)
        print(f"Added {len(chunks)} chunks to vector store")
        return True

    def get_retriever(self, k=4):
        """
        Get a retriever for the vector store
        """
        return self.vectorstore.as_retriever(search_kwargs={"k": k})
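
    # Illustrative usage (assuming documents have been added):
    #   retriever = kb.get_retriever(k=4)
    #   top_chunks = retriever.invoke("some question")  # the 4 most similar chunks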

    def visualize_vectors(self):
        """
        Create a 3D visualization of the vector store
        """
        try:
            collection = self.vectorstore._collection
            result = collection.get(include=['embeddings', 'documents', 'metadatas'])

            if result['embeddings'] is None or len(result['embeddings']) == 0:
                print("No embeddings found in vector store")
                return None

            vectors = np.array(result['embeddings'])
            documents = result['documents']
            metadatas = result['metadatas']

            if len(vectors) < 2:
                print("Not enough vectors for visualization (need at least 2)")
                return None

            # Get source info for coloring
            sources = [metadata.get('source', 'unknown') for metadata in metadatas]
            unique_sources = list(set(sources))
            colors = [['blue', 'green', 'red', 'orange', 'purple', 'cyan'][unique_sources.index(s) % 6] for s in sources]

            # Reduce dimensions for visualization,
            # adjusting perplexity to the number of samples
            n_samples = len(vectors)
            perplexity = min(30, max(1, n_samples - 1))
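            # scikit-learn requires perplexity < n_samples, hence the cap at n_samples - 1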

            tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity)
            reduced_vectors = tsne.fit_transform(vectors)

            # Create the 3D scatter plot
            fig = go.Figure(data=[go.Scatter3d(
                x=reduced_vectors[:, 0],
                y=reduced_vectors[:, 1],
                z=reduced_vectors[:, 2],
                mode='markers',
                marker=dict(size=5, color=colors, opacity=0.8),
                text=[f"Source: {s}<br>Text: {d[:100]}..." for s, d in zip(sources, documents)],
                hoverinfo='text'
            )])

            fig.update_layout(
                title='3D Vector Store Visualization',
                scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
                width=900,
                height=700,
                margin=dict(r=20, b=10, l=10, t=40)
            )

            return fig

        except Exception as e:
            print(f"Error creating visualization: {e}")
            return None

# Simple fallback chain implementation
class SimpleConversationalChain:
    def __init__(self, llm, retriever, memory):
        self.llm = llm
        self.retriever = retriever
        self.memory = memory

    def invoke(self, inputs):
        question = inputs.get("question", "")
        # Get relevant documents - try different retriever APIs
        try:
            docs = self.retriever.get_relevant_documents(question)
        except AttributeError:
            try:
                docs = self.retriever.invoke(question)
            except Exception:
                docs = []

        context = "\n".join([doc.page_content for doc in docs[:3]]) if docs else "No relevant context found."

        # Create a simple prompt
        prompt = f"""Based on the following context, answer the question:

Context: {context}

Question: {question}

Answer:"""

        # Get the response from the LLM
        response = self.llm.invoke(prompt)
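        # ChatOpenAI returns an AIMessage whose text is in .content; the str()
        # fallback covers LLM objects that return a plain string instead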
        return {"answer": response.content if hasattr(response, 'content') else str(response)}

# Chat System Class
class ChatSystem:
    def __init__(self, knowledge_base, model_name=MODEL):
        self.knowledge_base = knowledge_base
        self.model_name = model_name
        self.llm = ChatOpenAI(temperature=0.7, model_name=self.model_name)
        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
        self.conversation_chain = self._create_conversation_chain()

    def _create_conversation_chain(self):
        """
        Create a new conversation chain with the current retriever
        """
        retriever = self.knowledge_base.get_retriever()
        # Skip the problematic ConversationalRetrievalChain and use the simple implementation
        print("Using simple conversational chain implementation")
        return SimpleConversationalChain(self.llm, retriever, self.memory)

    def reset_conversation(self):
        """
        Reset the conversation memory and chain
        """
        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
        self.conversation_chain = self._create_conversation_chain()
        return "Conversation has been reset."

    def chat(self, question, history):
        """
        Process a question and return the answer
        """
        if not question.strip():
            return "Please ask a question."

        result = self.conversation_chain.invoke({"question": question})
        return result["answer"]

    def update_knowledge_base(self):
        """
        Update the conversation chain with the latest knowledge base
        """
        self.conversation_chain = self._create_conversation_chain()

# UI Functions
def handle_file_upload(files):
    """
    Process uploaded files and add them to the knowledge base
    """
    if not files:
        return "No files uploaded."

    documents = []
    for file in files:
        try:
            docs = load_document(file.name)
            if docs:
                # Add upload source metadata
                for doc in docs:
                    doc.metadata['source'] = 'upload'
                    doc.metadata['filename'] = os.path.basename(file.name)
                documents.extend(docs)
        except Exception as e:
            print(f"Error processing file {file.name}: {e}")

    if documents:
        success = kb.add_documents(documents)
        if success:
            # Update the chat system with the new knowledge
            chat_system.update_knowledge_base()
            return f"Successfully processed {len(documents)} documents."

    return "No documents could be processed. Please check the file formats."

def create_ui():
    """
    Create the Gradio UI
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # Knowledge Worker
        Upload documents or ask questions about your knowledge base.
        """)

        with gr.Tabs():
            with gr.TabItem("Chat"):
                chatbot = gr.ChatInterface(
                    chat_system.chat,
                    chatbot=gr.Chatbot(height=500, type="messages"),
                    textbox=gr.Textbox(placeholder="Ask a question about your documents...", container=False),
                    title="Knowledge Worker Chat",
                    type="messages"
                )
                reset_btn = gr.Button("Reset Conversation")
                reset_status = gr.Textbox(label="Status", interactive=False)
                reset_btn.click(chat_system.reset_conversation, inputs=None, outputs=reset_status)

            with gr.TabItem("Upload Documents"):
                with gr.Column():
                    file_output = gr.Textbox(label="Upload Status")
                    upload_button = gr.UploadButton(
                        "Click to Upload Files",
                        file_types=[".pdf", ".docx", ".txt", ".md", ".xlsx"],
                        file_count="multiple"
                    )
                    upload_button.upload(handle_file_upload, upload_button, file_output)

            with gr.TabItem("Visualize Knowledge"):
                visualize_btn = gr.Button("Generate Vector Visualization")
                plot_output = gr.Plot(label="Vector Space Visualization")
                visualize_btn.click(kb.visualize_vectors, inputs=None, outputs=plot_output)

    return app

def main():
    """
    Main function to initialize and run the knowledge worker
    """
    global kb, chat_system

    print("=" * 60)
    print("Initializing Knowledge Worker...")
    print("=" * 60)

    try:
        # Initialize the knowledge base
        print("Setting up vector database...")
        kb = KnowledgeBase(DB_NAME)
        print("Vector database initialized successfully")

        # Google Drive integration removed

        # Initialize the chat system
        print("\nSetting up chat system...")
        chat_system = ChatSystem(kb)
        print("Chat system initialized successfully")

        # Launch the Gradio app
        print("\nLaunching Gradio interface...")
        print("=" * 60)
        print("The web interface will open in your browser")
        print("You can also access it at the URL shown below")
        print("=" * 60)

        app = create_ui()
        app.launch(inbrowser=True)

    except Exception as e:
        print(f"Error initializing Knowledge Worker: {e}")
        print("Please check your configuration and try again.")
        return

if __name__ == "__main__":
    main()