{
"cells": [
{
"cell_type": "markdown",
"id": "97a93fee-6bbd-477b-aba8-577d318a9f9d",
"metadata": {},
"source": [
"# AI-Powered Academic Knowledge Assistant\n",
"AI-powered RAG (Retrieval-Augmented Generation) system that transforms document collections into queryable knowledge bases using OpenAI embeddings and vector search. Features configurable chunking, file size limits, and retrieval parameters with a Gradio interface for processing PDFs and generating contextually-aware responses via LangChain and ChromaDB."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3589eee0-ce34-42f4-b538-b43f3b0d9f6f",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"from dotenv import load_dotenv\n",
"import gradio as gr\n",
"import shutil\n",
"import tiktoken\n",
"import time\n",
"import uuid\n",
"from typing import List, Tuple, Optional\n",
"\n",
"# imports for langchain and Chroma\n",
"from langchain.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings\n",
"\n",
"from langchain_community.document_loaders import PyPDFLoader, TextLoader\n",
"from langchain.docstore.document import Document\n",
"\n",
"# Load environment variables\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"\n",
"# Global variables to store the current setup\n",
"current_vectorstore = None\n",
"current_conversation_chain = None\n",
"processing_status = \"\"\n",
"\n",
"def count_tokens(text: str, model: str = \"gpt-4o-mini\") -> int:\n",
" \"\"\"Count tokens in text using tiktoken\"\"\"\n",
" try:\n",
" encoding = tiktoken.encoding_for_model(model)\n",
" return len(encoding.encode(text))\n",
" except:\n",
" # Fallback estimation: roughly 4 characters per token\n",
" return len(text) // 4\n",
"\n",
"def filter_chunks_by_tokens(chunks: List[Document], max_total_tokens: int = 250000) -> List[Document]:\n",
" \"\"\"Filter chunks to stay within token limits\"\"\"\n",
" filtered_chunks = []\n",
" total_tokens = 0\n",
" \n",
" for chunk in chunks:\n",
" chunk_tokens = count_tokens(chunk.page_content)\n",
" \n",
" # Skip individual chunks that are too large (shouldn't happen with proper splitting)\n",
" if chunk_tokens > 8000: # Individual chunk limit\n",
" continue\n",
" \n",
" if total_tokens + chunk_tokens <= max_total_tokens:\n",
" filtered_chunks.append(chunk)\n",
" total_tokens += chunk_tokens\n",
" else:\n",
" break\n",
" \n",
" return filtered_chunks\n",
"\n",
"def add_metadata(doc, doc_type, file_path):\n",
" \"\"\"Add metadata including document type and file information\"\"\"\n",
" doc.metadata[\"doc_type\"] = doc_type\n",
" doc.metadata[\"file_path\"] = file_path\n",
" doc.metadata[\"file_name\"] = os.path.basename(file_path)\n",
" return doc\n",
"\n",
"def check_file_size(file_path, max_size_bytes):\n",
" \"\"\"Check if file size is within the limit\"\"\"\n",
" try:\n",
" file_size = os.path.getsize(file_path)\n",
" return file_size <= max_size_bytes, file_size\n",
" except OSError:\n",
" return False, 0\n",
"\n",
"def load_pdfs_with_size_limit(folder_path, doc_type, max_size_bytes):\n",
" \"\"\"Load PDF files from a folder with size restrictions\"\"\"\n",
" pdf_files = glob.glob(os.path.join(folder_path, \"**/*.pdf\"), recursive=True)\n",
" loaded_docs = []\n",
" skipped_files = []\n",
" \n",
" for pdf_file in pdf_files:\n",
" is_valid_size, file_size = check_file_size(pdf_file, max_size_bytes)\n",
" \n",
" if is_valid_size:\n",
" try:\n",
" loader = PyPDFLoader(pdf_file)\n",
" docs = loader.load()\n",
" docs_with_metadata = [add_metadata(doc, doc_type, pdf_file) for doc in docs]\n",
" loaded_docs.extend(docs_with_metadata)\n",
" except Exception as e:\n",
" skipped_files.append((pdf_file, f\"Loading error: {str(e)}\"))\n",
" else:\n",
" file_size_mb = file_size / 1024 / 1024\n",
" skipped_files.append((pdf_file, f\"File too large: {file_size_mb:.2f} MB\"))\n",
" \n",
" return loaded_docs, skipped_files\n",
"\n",
"def process_documents(knowledge_base_dir: str, max_file_size_mb: float, chunk_size: int, chunk_overlap: int) -> Tuple[str, str]:\n",
" \"\"\"Process documents and create vector store\"\"\"\n",
" global current_vectorstore, current_conversation_chain\n",
" \n",
" try:\n",
" # Validate directory\n",
" if not knowledge_base_dir or not knowledge_base_dir.strip():\n",
" return \"❌ Error: Please enter a directory path!\", \"\"\n",
" \n",
" directory_path = knowledge_base_dir.strip()\n",
" \n",
" if not os.path.exists(directory_path):\n",
" return \"❌ Error: Directory does not exist! Please check the path.\", \"\"\n",
" \n",
" # Configuration\n",
" MAX_FILE_SIZE_BYTES = int(max_file_size_mb * 1024 * 1024)\n",
" \n",
" # Find folders\n",
" if directory_path.endswith('*'):\n",
" folders = glob.glob(directory_path)\n",
" else:\n",
" folders = glob.glob(os.path.join(directory_path, \"*\"))\n",
" \n",
" if not folders:\n",
" return \"❌ Error: No folders found in the specified directory!\", \"\"\n",
" \n",
" # Process documents\n",
" documents = []\n",
" all_skipped_files = []\n",
" status_lines = []\n",
" \n",
" status_lines.append(f\"🔍 Processing folders with {max_file_size_mb} MB file size limit...\")\n",
" status_lines.append(\"-\" * 60)\n",
" \n",
" for folder in folders:\n",
" if os.path.isdir(folder):\n",
" doc_type = os.path.basename(folder)\n",
" status_lines.append(f\"📁 Processing folder: {doc_type}\")\n",
" \n",
" folder_docs, skipped_files = load_pdfs_with_size_limit(folder, doc_type, MAX_FILE_SIZE_BYTES)\n",
" documents.extend(folder_docs)\n",
" all_skipped_files.extend(skipped_files)\n",
" \n",
" if folder_docs:\n",
" status_lines.append(f\" ✅ Loaded {len(folder_docs)} document pages\")\n",
" if skipped_files:\n",
" status_lines.append(f\" ⚠️ Skipped {len(skipped_files)} files\")\n",
" \n",
" if not documents:\n",
" error_msg = \"❌ No PDF documents were loaded successfully.\"\n",
" if all_skipped_files:\n",
" error_msg += f\"\\n\\nAll {len(all_skipped_files)} files were skipped:\"\n",
" for file_path, reason in all_skipped_files[:10]: # Show first 10\n",
" error_msg += f\"\\n • {os.path.basename(file_path)}: {reason}\"\n",
" if len(all_skipped_files) > 10:\n",
" error_msg += f\"\\n ... and {len(all_skipped_files) - 10} more\"\n",
" return error_msg, \"\"\n",
" \n",
" # Text splitting\n",
" status_lines.append(\"\\n\" + \"=\"*40)\n",
" status_lines.append(\"✂️ TEXT SPLITTING\")\n",
" status_lines.append(\"=\"*40)\n",
" \n",
" text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n",
" chunks = text_splitter.split_documents(documents)\n",
" \n",
" # Filter chunks by token count to prevent API errors\n",
" status_lines.append(\"🔢 Checking token limits...\")\n",
" original_chunk_count = len(chunks)\n",
" chunks = filter_chunks_by_tokens(chunks, max_total_tokens=250000)\n",
" \n",
" if len(chunks) < original_chunk_count:\n",
" status_lines.append(f\"⚠️ Filtered from {original_chunk_count} to {len(chunks)} chunks to stay within token limits\")\n",
" \n",
" # Create vectorstore\n",
" status_lines.append(\"🧮 Creating vector embeddings...\")\n",
" embeddings = OpenAIEmbeddings()\n",
" \n",
" # Use a temporary database name\n",
" db_name = \"temp_vector_db\"\n",
" \n",
" # Delete if already exists\n",
" if os.path.exists(db_name):\n",
" shutil.rmtree(db_name)\n",
" \n",
" # Create vectorstore\n",
" vectorstore = Chroma.from_documents(\n",
" documents=chunks, \n",
" embedding=embeddings, \n",
" persist_directory=db_name\n",
" )\n",
" \n",
" # Update global variables\n",
" current_vectorstore = vectorstore\n",
" \n",
" # Create conversation chain\n",
" llm = ChatOpenAI(temperature=0.7, model_name=\"gpt-4o-mini\")\n",
" memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
" retriever = vectorstore.as_retriever(search_kwargs={\"k\": 25})\n",
" current_conversation_chain = ConversationalRetrievalChain.from_llm(\n",
" llm=llm, \n",
" retriever=retriever, \n",
" memory=memory\n",
" )\n",
" \n",
" # Summary statistics\n",
" status_lines.append(\"\\n\" + \"=\"*40)\n",
" status_lines.append(\"📊 SUMMARY\")\n",
" status_lines.append(\"=\"*40)\n",
" status_lines.append(f\"✅ Total PDFs processed: {len(set(doc.metadata['file_path'] for doc in documents))}\")\n",
" status_lines.append(f\"📄 Total document pages: {len(documents)}\")\n",
" status_lines.append(f\"🧩 Total text chunks: {len(chunks)}\")\n",
" status_lines.append(f\"📁 Document types: {', '.join(set(doc.metadata['doc_type'] for doc in documents))}\")\n",
" status_lines.append(f\"🗃️ Vector store size: {vectorstore._collection.count()} embeddings\")\n",
" \n",
" if all_skipped_files:\n",
" status_lines.append(f\"\\n⚠ Skipped files: {len(all_skipped_files)}\")\n",
" for file_path, reason in all_skipped_files[:5]: # Show first 5\n",
" status_lines.append(f\" • {os.path.basename(file_path)}: {reason}\")\n",
" if len(all_skipped_files) > 5:\n",
" status_lines.append(f\" ... and {len(all_skipped_files) - 5} more\")\n",
" \n",
" success_msg = \"✅ Knowledge base successfully created and ready for questions!\"\n",
" detailed_status = \"\\n\".join(status_lines)\n",
" \n",
" return success_msg, detailed_status\n",
" \n",
" except Exception as e:\n",
" error_msg = f\"❌ Error processing documents: {str(e)}\"\n",
" return error_msg, \"\"\n",
"\n",
"def chat_with_documents(message, history, num_chunks):\n",
" \"\"\"Chat with the processed documents\"\"\"\n",
" global current_conversation_chain, current_vectorstore\n",
" \n",
" if current_conversation_chain is None:\n",
" return \"❌ Please process documents first before asking questions!\"\n",
" \n",
" try:\n",
" # Update retriever with new chunk count\n",
" if current_vectorstore is not None:\n",
" retriever = current_vectorstore.as_retriever(search_kwargs={\"k\": num_chunks})\n",
" current_conversation_chain.retriever = retriever\n",
" \n",
" result = current_conversation_chain.invoke({\"question\": message})\n",
" return result[\"answer\"]\n",
" \n",
" except Exception as e:\n",
" return f\"❌ Error generating response: {str(e)}\"\n",
"\n",
"def reset_conversation():\n",
" \"\"\"Reset the conversation memory\"\"\"\n",
" global current_conversation_chain\n",
" if current_conversation_chain is not None:\n",
" current_conversation_chain.memory.clear()\n",
" return \"✅ Conversation history cleared!\"\n",
" return \"No active conversation to reset.\"\n",
"\n",
"# Create Gradio Interface\n",
"with gr.Blocks(title=\"AI-Powered Academic Knowledge Assistant\", theme=gr.themes.Soft()) as app:\n",
" gr.Markdown(\"# 🎓 AI-Powered Academic Knowledge Assistant\")\n",
" gr.Markdown(\"Transform your entire document library into an intelligent, searchable AI tutor that answers questions instantly.\")\n",
" \n",
" with gr.Tabs():\n",
" # Configuration Tab\n",
" with gr.Tab(\"⚙️ Configuration\"):\n",
" gr.Markdown(\"### 📁 Document Processing Settings\")\n",
" \n",
" gr.Markdown(\"💡 **Tip:** Copy and paste your folder path here. On mobile, you can use file manager apps to copy folder paths.\")\n",
" \n",
" with gr.Row():\n",
" with gr.Column():\n",
" knowledge_dir = gr.Textbox(\n",
" label=\"Knowledge Base Directory\",\n",
" value=r\"C:\\Users\\Documents\\Syllabi\\Georgia Tech\\Spring 22\\Microwave Design\",\n",
" placeholder=\"Enter or paste your document directory path\",\n",
" lines=1\n",
" )\n",
" \n",
" max_file_size = gr.Slider(\n",
" label=\"Max File Size (MB)\",\n",
" minimum=0.5,\n",
" maximum=50,\n",
" value=4,\n",
" step=0.5\n",
" )\n",
" \n",
" with gr.Column():\n",
" chunk_size = gr.Slider(\n",
" label=\"Chunk Size (characters)\",\n",
" minimum=200,\n",
" maximum=1500,\n",
" value=800,\n",
" step=100,\n",
" info=\"Smaller chunks = better token management\"\n",
" )\n",
" \n",
" chunk_overlap = gr.Slider(\n",
" label=\"Chunk Overlap (characters)\",\n",
" minimum=0,\n",
" maximum=300,\n",
" value=150,\n",
" step=25,\n",
" info=\"Overlap preserves context between chunks\"\n",
" )\n",
" \n",
" process_btn = gr.Button(\"🚀 Process Documents\", variant=\"primary\", size=\"lg\")\n",
" \n",
" with gr.Row():\n",
" status_output = gr.Textbox(\n",
" label=\"Status\",\n",
" lines=2,\n",
" max_lines=2\n",
" )\n",
" \n",
" detailed_output = gr.Textbox(\n",
" label=\"Detailed Processing Log\",\n",
" lines=15,\n",
" max_lines=20\n",
" )\n",
" \n",
" # Chat Tab\n",
" with gr.Tab(\"💬 Chat\"):\n",
" gr.Markdown(\"### 🤖 Ask Questions About Your Documents\")\n",
" \n",
" with gr.Row():\n",
" with gr.Column(scale=1):\n",
" num_chunks = gr.Slider(\n",
" label=\"Number of chunks to retrieve\",\n",
" minimum=1,\n",
" maximum=50,\n",
" value=25,\n",
" step=1\n",
" )\n",
" \n",
" reset_btn = gr.Button(\"🗑️ Clear Chat History\", variant=\"secondary\")\n",
" reset_output = gr.Textbox(label=\"Reset Status\", lines=1)\n",
" \n",
" with gr.Column(scale=3):\n",
" chatbot = gr.ChatInterface(\n",
" fn=lambda msg, history: chat_with_documents(msg, history, num_chunks.value),\n",
" type=\"messages\",\n",
" title=\"Academic Assistant Chat\",\n",
" description=\"Ask questions about your processed documents\"\n",
" )\n",
" \n",
" # Event handlers\n",
" process_btn.click(\n",
" fn=process_documents,\n",
" inputs=[knowledge_dir, max_file_size, chunk_size, chunk_overlap],\n",
" outputs=[status_output, detailed_output]\n",
" )\n",
" \n",
" reset_btn.click(\n",
" fn=reset_conversation,\n",
" outputs=reset_output\n",
" )\n"
]
},
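{
"cell_type": "markdown",
"id": "f2a61c3e-8b4d-4e5f-9a0b-1c2d3e4f5a6b",
"metadata": {},
"source": [
"### Optional: inspect raw retrieval results\n",
"Once documents have been processed in the Configuration tab, the helper below queries the vector store directly, bypassing the chat chain and the LLM. It is a minimal debugging sketch that assumes `current_vectorstore` has been populated; the function name `preview_retrieval` and the example query are illustrative, not part of the app."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b9d8c7e-6f5a-4b3c-8d2e-1f0a9b8c7d6e",
"metadata": {},
"outputs": [],
"source": [
"def preview_retrieval(query: str, k: int = 5):\n",
" \"\"\"Print the top-k chunks returned for a query (debugging aid, bypasses the LLM)\"\"\"\n",
" if current_vectorstore is None:\n",
" print(\"No vector store yet - process documents in the Configuration tab first.\")\n",
" return\n",
" # similarity_search returns the k nearest chunks without calling the chat model\n",
" for i, doc in enumerate(current_vectorstore.similarity_search(query, k=k), 1):\n",
" source = doc.metadata.get(\"file_name\", \"unknown\")\n",
" print(f\"[{i}] {source}: {doc.page_content[:200]}...\")\n",
"\n",
"# Example usage (hypothetical query):\n",
"# preview_retrieval(\"What topics does the syllabus cover?\")"
]
},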
{
"cell_type": "code",
"execution_count": null,
"id": "9eb807e0-194b-48dd-a1e9-b1b9b8a99620",
"metadata": {},
"outputs": [],
"source": [
"app.launch(share=True, inbrowser=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}