diff --git a/week5/community-contributions/w5d5_worker.py b/week5/community-contributions/w5d5_worker.py
new file mode 100644
index 0000000..822cfcd
--- /dev/null
+++ b/week5/community-contributions/w5d5_worker.py
@@ -0,0 +1,445 @@
+#!/usr/bin/env python3
+"""
+Knowledge Worker with Document Upload
+
+This script creates a knowledge worker that:
+1. Allows users to upload documents through a Gradio UI
+2. Uses a Chroma vector database for efficient document retrieval
+3. Implements RAG (Retrieval Augmented Generation) for accurate responses
+
+The system updates its context dynamically when new documents are uploaded.
+Note: the original Google Drive integration has been removed from this version.
+"""
+
+import os
+from dotenv import load_dotenv
+import gradio as gr
+
+# LangChain imports
+from langchain_community.document_loaders import TextLoader, PyPDFLoader
+from langchain_core.documents import Document
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain_chroma import Chroma
+
+# Visualization imports
+import numpy as np
+from sklearn.manifold import TSNE
+import plotly.graph_objects as go
+
+# Removed Google Drive API imports
+
+# Additional document loaders
+try:
+    from langchain_community.document_loaders import Docx2txtLoader, UnstructuredExcelLoader
+except ImportError:
+    print("Warning: Some document loaders are not available. PDF and text files will still work.")
+    Docx2txtLoader = None
+    UnstructuredExcelLoader = None
+
+# Configuration
+MODEL = "gpt-4o-mini"  # Using a cost-effective model
+DB_NAME = "knowledge_worker_db"
+UPLOAD_FOLDER = "uploaded_documents"
+
+# Create the upload folder if it doesn't exist
+os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+# Load environment variables
+load_dotenv(override=True)
+os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')
+
+# Removed Google Drive credentials configuration
+
+# Use a simple text splitter approach
+class SimpleTextSplitter:
+    def __init__(self, chunk_size=1000, chunk_overlap=200):
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+
+    def split_documents(self, documents):
+        # Slide a fixed-size window over each document, stepping back by
+        # chunk_overlap so adjacent chunks share context
+        chunks = []
+        for doc in documents:
+            text = doc.page_content
+            start = 0
+            while start < len(text):
+                end = start + self.chunk_size
+                chunk_text = text[start:end]
+                chunk_doc = Document(page_content=chunk_text, metadata=doc.metadata.copy())
+                chunks.append(chunk_doc)
+                start = end - self.chunk_overlap
+        return chunks
+
+# Alias so the rest of the script reads like the LangChain original
+CharacterTextSplitter = SimpleTextSplitter
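+
+# Illustrative sketch of the chunking behaviour above (not executed anywhere
+# in this script): with the defaults, a 1,500-character document yields two
+# overlapping chunks:
+#
+#   splitter = SimpleTextSplitter(chunk_size=1000, chunk_overlap=200)
+#   chunks = splitter.split_documents([Document(page_content="x" * 1500)])
+#   # chunks[0] covers characters 0-999, chunks[1] covers characters 800-1499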
+
+# Try different import paths for memory and chains
+try:
+    from langchain.memory import ConversationBufferMemory
+    from langchain.chains import ConversationalRetrievalChain
+except ImportError:
+    try:
+        from langchain_core.memory import ConversationBufferMemory
+        from langchain_core.chains import ConversationalRetrievalChain
+    except ImportError:
+        try:
+            from langchain_community.memory import ConversationBufferMemory
+            from langchain_community.chains import ConversationalRetrievalChain
+        except ImportError:
+            print("Warning: Memory and chains modules not found. Creating simple alternatives.")
+            # Create simple stand-ins with the same interfaces
+            class ConversationBufferMemory:
+                def __init__(self, memory_key='chat_history', return_messages=True):
+                    self.memory_key = memory_key
+                    self.return_messages = return_messages
+                    self.chat_memory = []
+
+                def save_context(self, inputs, outputs):
+                    self.chat_memory.append((inputs, outputs))
+
+                def load_memory_variables(self, inputs):
+                    return {self.memory_key: self.chat_memory}
+
+            class ConversationalRetrievalChain:
+                def __init__(self, llm, retriever, memory):
+                    self.llm = llm
+                    self.retriever = retriever
+                    self.memory = memory
+
+                def invoke(self, inputs):
+                    question = inputs.get("question", "")
+                    # Minimal placeholder - just acknowledge the question
+                    return {"answer": f"I received your question: {question}. This is a simplified response."}
+
+# Removed Google Drive integration functions
+
+# Document Processing Functions
+def get_loader_for_file(file_path):
+    """
+    Get the appropriate document loader based on file extension
+    """
+    file_extension = os.path.splitext(file_path)[1].lower()
+
+    if file_extension == '.pdf':
+        return PyPDFLoader(file_path)
+    elif file_extension in ['.docx', '.doc'] and Docx2txtLoader:
+        return Docx2txtLoader(file_path)
+    elif file_extension in ['.xlsx', '.xls'] and UnstructuredExcelLoader:
+        return UnstructuredExcelLoader(file_path)
+    elif file_extension in ['.txt', '.md']:
+        return TextLoader(file_path, encoding='utf-8')
+    else:
+        # Default to the text loader for unknown types
+        try:
+            return TextLoader(file_path, encoding='utf-8')
+        except Exception:
+            return None
+
+def load_document(file_path):
+    """
+    Load a document using the appropriate loader
+    """
+    loader = get_loader_for_file(file_path)
+    if loader:
+        try:
+            return loader.load()
+        except Exception as e:
+            print(f"Error loading document {file_path}: {e}")
+    # No loader available, or loading failed
+    return []
+
+def process_documents(documents):
+    """
+    Split documents into chunks for embedding
+    """
+    text_splitter = CharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200
+    )
+    chunks = text_splitter.split_documents(documents)
+    return chunks
+
+# Knowledge Base Class
+class KnowledgeBase:
+    def __init__(self, db_name=DB_NAME):
+        self.db_name = db_name
+        self.embeddings = OpenAIEmbeddings()
+        self.vectorstore = None
+        self.initialize_vectorstore()
+
+    def initialize_vectorstore(self):
+        """
+        Initialize the vector store, loading it from disk if it exists
+        """
+        if os.path.exists(self.db_name):
+            self.vectorstore = Chroma(persist_directory=self.db_name, embedding_function=self.embeddings)
+            print(f"Loaded existing vector store with {self.vectorstore._collection.count()} documents")
+        else:
+            # Create an empty vector store
+            self.vectorstore = Chroma(persist_directory=self.db_name, embedding_function=self.embeddings)
+            print("Created new vector store")
+
+    def add_documents(self, documents):
+        """
+        Process and add documents to the vector store
+        """
+        if not documents:
+            return False
+
+        chunks = process_documents(documents)
+        if not chunks:
+            return False
+
+        # Add to the existing vector store
+        self.vectorstore.add_documents(chunks)
+        print(f"Added {len(chunks)} chunks to vector store")
+        return True
+
+    def get_retriever(self, k=4):
+        """
+        Get a retriever for the vector store
+        """
+        return self.vectorstore.as_retriever(search_kwargs={"k": k})
+
+    def visualize_vectors(self):
+        """
+        Create a 3D visualization of the vector store
+        """
+        try:
+            collection = self.vectorstore._collection
+            result = collection.get(include=['embeddings', 'documents', 'metadatas'])
+
+            if result['embeddings'] is None or len(result['embeddings']) == 0:
+                print("No embeddings found in vector store")
+                return None
+
+            vectors = np.array(result['embeddings'])
+            documents = result['documents']
+            metadatas = result['metadatas']
+
+            if len(vectors) < 2:
+                print("Not enough vectors for visualization (need at least 2)")
+                return None
+
+            # Get source info for coloring
+            sources = [metadata.get('source', 'unknown') for metadata in metadatas]
+            unique_sources = list(set(sources))
+            colors = [['blue', 'green', 'red', 'orange', 'purple', 'cyan'][unique_sources.index(s) % 6] for s in sources]
+
+            # Reduce dimensions for visualization;
+            # t-SNE requires perplexity < n_samples, so cap it for small stores
+            n_samples = len(vectors)
+            perplexity = min(30, max(1, n_samples - 1))
+
+            tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity)
+            reduced_vectors = tsne.fit_transform(vectors)
+
+            # Create the 3D scatter plot
+            fig = go.Figure(data=[go.Scatter3d(
+                x=reduced_vectors[:, 0],
+                y=reduced_vectors[:, 1],
+                z=reduced_vectors[:, 2],
+                mode='markers',
+                marker=dict(size=5, color=colors, opacity=0.8),
+                text=[f"Source: {s}<br>Text: {d[:100]}..." for s, d in zip(sources, documents)],
+                hoverinfo='text'
+            )])
+
+            fig.update_layout(
+                title='3D Vector Store Visualization',
+                scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
+                width=900,
+                height=700,
+                margin=dict(r=20, b=10, l=10, t=40)
+            )
+
+            return fig
+
+        except Exception as e:
+            print(f"Error creating visualization: {e}")
+            return None
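+
+# Illustrative sketch of how the retriever is consumed downstream (hypothetical
+# question; SimpleConversationalChain below does the equivalent on every turn):
+#
+#   kb = KnowledgeBase("knowledge_worker_db")
+#   retriever = kb.get_retriever(k=4)
+#   docs = retriever.invoke("What do my documents say about pricing?")
+#   # docs holds the 4 chunks whose embeddings are closest to the question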
+
+# Simple fallback chain implementation
+class SimpleConversationalChain:
+    def __init__(self, llm, retriever, memory):
+        self.llm = llm
+        self.retriever = retriever
+        self.memory = memory
+
+    def invoke(self, inputs):
+        question = inputs.get("question", "")
+        # Get relevant documents - the retriever API differs across LangChain versions
+        try:
+            docs = self.retriever.get_relevant_documents(question)
+        except AttributeError:
+            try:
+                docs = self.retriever.invoke(question)
+            except Exception:
+                docs = []
+
+        context = "\n".join([doc.page_content for doc in docs[:3]]) if docs else "No relevant context found."
+
+        # Create a simple prompt
+        prompt = f"""Based on the following context, answer the question:
+
+Context: {context}
+
+Question: {question}
+
+Answer:"""
+
+        # Get the response from the LLM
+        response = self.llm.invoke(prompt)
+        return {"answer": response.content if hasattr(response, 'content') else str(response)}
+
+# Chat System Class
+class ChatSystem:
+    def __init__(self, knowledge_base, model_name=MODEL):
+        self.knowledge_base = knowledge_base
+        self.model_name = model_name
+        self.llm = ChatOpenAI(temperature=0.7, model_name=self.model_name)
+        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+        self.conversation_chain = self._create_conversation_chain()
+
+    def _create_conversation_chain(self):
+        """
+        Create a new conversation chain with the current retriever
+        """
+        retriever = self.knowledge_base.get_retriever()
+        # Skip the problematic ConversationalRetrievalChain and use the simple implementation
+        print("Using simple conversational chain implementation")
+        return SimpleConversationalChain(self.llm, retriever, self.memory)
+
+    def reset_conversation(self):
+        """
+        Reset the conversation memory and chain
+        """
+        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+        self.conversation_chain = self._create_conversation_chain()
+        return "Conversation has been reset."
+
+    def chat(self, question, history):
+        """
+        Process a question and return the answer
+        """
+        if not question.strip():
+            return "Please ask a question."
+
+        result = self.conversation_chain.invoke({"question": question})
+        return result["answer"]
+
+    def update_knowledge_base(self):
+        """
+        Rebuild the conversation chain against the latest knowledge base
+        """
+        self.conversation_chain = self._create_conversation_chain()
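+
+# Illustrative sketch of one chat turn (hypothetical question; this mirrors
+# what the Gradio ChatInterface below does for each message):
+#
+#   chat_system = ChatSystem(kb)
+#   answer = chat_system.chat("Summarize the uploaded report", history=[])
+#   chat_system.update_knowledge_base()  # call whenever documents are added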
+
+# UI Functions
+def handle_file_upload(files):
+    """
+    Process uploaded files and add them to the knowledge base
+    """
+    if not files:
+        return "No files uploaded."
+
+    documents = []
+    for file in files:
+        try:
+            docs = load_document(file.name)
+            if docs:
+                # Tag each document with upload metadata
+                for doc in docs:
+                    doc.metadata['source'] = 'upload'
+                    doc.metadata['filename'] = os.path.basename(file.name)
+                documents.extend(docs)
+        except Exception as e:
+            print(f"Error processing file {file.name}: {e}")
+
+    if documents:
+        success = kb.add_documents(documents)
+        if success:
+            # Update the chat system with the new knowledge
+            chat_system.update_knowledge_base()
+            return f"Successfully processed {len(documents)} documents."
+
+    return "No documents could be processed. Please check the file formats."
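+
+# Illustrative upload flow (hypothetical local path; this is what the Upload
+# Documents tab triggers end to end):
+#
+#   docs = load_document("uploaded_documents/report.pdf")
+#   kb.add_documents(docs)                # chunk, embed, and persist
+#   chat_system.update_knowledge_base()   # refresh the retriever for chat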
+
+def create_ui():
+    """
+    Create the Gradio UI
+    """
+    with gr.Blocks(theme=gr.themes.Soft()) as app:
+        gr.Markdown("""
+        # Knowledge Worker
+        Upload documents or ask questions about your knowledge base.
+        """)
+
+        with gr.Tabs():
+            with gr.TabItem("Chat"):
+                chatbot = gr.ChatInterface(
+                    chat_system.chat,
+                    chatbot=gr.Chatbot(height=500, type="messages"),
+                    textbox=gr.Textbox(placeholder="Ask a question about your documents...", container=False),
+                    title="Knowledge Worker Chat",
+                    type="messages"
+                )
+                reset_btn = gr.Button("Reset Conversation")
+                reset_status = gr.Textbox(label="Status")
+                reset_btn.click(chat_system.reset_conversation, inputs=None, outputs=reset_status)
+
+            with gr.TabItem("Upload Documents"):
+                with gr.Column():
+                    file_output = gr.Textbox(label="Upload Status")
+                    upload_button = gr.UploadButton(
+                        "Click to Upload Files",
+                        file_types=[".pdf", ".docx", ".txt", ".md", ".xlsx"],
+                        file_count="multiple"
+                    )
+                    upload_button.upload(handle_file_upload, upload_button, file_output)
+
+            with gr.TabItem("Visualize Knowledge"):
+                visualize_btn = gr.Button("Generate Vector Visualization")
+                plot_output = gr.Plot(label="Vector Space Visualization")
+                visualize_btn.click(kb.visualize_vectors, inputs=None, outputs=plot_output)
+
+    return app
+
+def main():
+    """
+    Main function to initialize and run the knowledge worker
+    """
+    global kb, chat_system
+
+    print("=" * 60)
+    print("Initializing Knowledge Worker...")
+    print("=" * 60)
+
+    try:
+        # Initialize the knowledge base
+        print("Setting up vector database...")
+        kb = KnowledgeBase(DB_NAME)
+        print("Vector database initialized successfully")
+
+        # Google Drive integration removed
+
+        # Initialize the chat system
+        print("\nSetting up chat system...")
+        chat_system = ChatSystem(kb)
+        print("Chat system initialized successfully")
+
+        # Launch the Gradio app
+        print("\nLaunching Gradio interface...")
+        print("=" * 60)
+        print("The web interface will open in your browser")
+        print("You can also access it at the URL shown below")
+        print("=" * 60)
+
+        app = create_ui()
+        app.launch(inbrowser=True)
+
+    except Exception as e:
+        print(f"Error initializing Knowledge Worker: {e}")
+        print("Please check your configuration and try again.")
+        return
+
+if __name__ == "__main__":
+    main()