From 0286066d18a8a453d824de36918d1b7eb9774f36 Mon Sep 17 00:00:00 2001 From: abdoulrasheed Date: Wed, 29 Oct 2025 21:06:59 +0000 Subject: [PATCH] W5 --- .gitignore | 1 + .../abdoul/week_five_exercise.ipynb | 922 ++++++++++++++++++ 2 files changed, 923 insertions(+) create mode 100644 community-contributions/abdoul/week_five_exercise.ipynb diff --git a/.gitignore b/.gitignore index 6f98227..c070fba 100644 --- a/.gitignore +++ b/.gitignore @@ -202,3 +202,4 @@ week4/main.rs local/ +*.apps.googleusercontent.com.json \ No newline at end of file diff --git a/community-contributions/abdoul/week_five_exercise.ipynb b/community-contributions/abdoul/week_five_exercise.ipynb new file mode 100644 index 0000000..f2e6c43 --- /dev/null +++ b/community-contributions/abdoul/week_five_exercise.ipynb @@ -0,0 +1,922 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2c01a17", + "metadata": {}, + "source": [ + "### Search your Google Drive knowledge base with fully local processing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df7609cf", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qU \"langchain==0.3.27\" \"langchain-core<1.0.0,>=0.3.78\" \"langchain-text-splitters<1.0.0,>=0.3.9\" langchain_ollama langchain_chroma langchain_community google-auth google-auth-oauthlib google-auth-httplib2 google-api-python-client python-docx gradio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "144bdf7c", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import logging\n", + "import os\n", + "import re\n", + "import sys\n", + "import hashlib\n", + "from pathlib import Path\n", + "from enum import StrEnum\n", + "from typing import Iterable, Optional\n", + "\n", + "import gradio as gr\n", + "from langchain_core.documents import Document\n", + "from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter\n", + "from langchain_ollama import OllamaEmbeddings, ChatOllama\n", + "from langchain.storage import InMemoryStore\n", + "from langchain_chroma import Chroma\n", + "from langchain_community.document_loaders import TextLoader\n", + "from google.oauth2.credentials import Credentials\n", + "from google_auth_oauthlib.flow import InstalledAppFlow\n", + "from google.auth.transport.requests import Request\n", + "from googleapiclient.discovery import build\n", + "from googleapiclient.http import MediaIoBaseDownload\n", + "from googleapiclient.errors import HttpError\n", + "from docx import Document as DocxDocument" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfdb143d", + "metadata": {}, + "outputs": [], + "source": [ + "logger = logging.getLogger('drive_sage')\n", + "logger.setLevel(logging.DEBUG)\n", + "\n", + "if not logger.handlers:\n", + " handler = logging.StreamHandler(sys.stdout)\n", + " formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n", + " handler.setFormatter(formatter)\n", + " logger.addHandler(handler)" + ] + },
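+ { + "cell_type": "markdown", + "id": "3b8d5f21", + "metadata": {}, + "source": [ + "Everything here runs locally: embeddings and chat both go through [Ollama](https://ollama.com), so the server must be running with the default models pulled (`ollama pull nomic-embed-text` and `ollama pull llama3.1`), and the Drive sync expects the OAuth client secret JSON downloaded from the Google Cloud Console to sit next to this notebook. The next cell is an optional sanity check." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c1a9e4d", + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: confirm the local Ollama server is reachable and the required models are pulled.\n", + "# Assumes the ollama CLI is installed (https://ollama.com).\n", + "!ollama list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 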
"41df43aa", + "metadata": {}, + "outputs": [], + "source": [ + "SCOPES = ['https://www.googleapis.com/auth/drive.readonly']\n", + "APP_ROOT = Path.cwd()\n", + "DATA_DIR = APP_ROOT / '.drive_sage'\n", + "DOWNLOAD_DIR = DATA_DIR / 'downloads'\n", + "VECTORSTORE_DIR = DATA_DIR / 'chroma'\n", + "TOKEN_PATH = DATA_DIR / 'token.json'\n", + "MANIFEST_PATH = DATA_DIR / 'manifest.json'\n", + "CLIENT_SECRET_FILE = APP_ROOT / 'client_secret_202216035337-4qson0c08g71u8uuihv6v46arv64nhvg.apps.googleusercontent.com.json'\n", + "\n", + "for path in (DATA_DIR, DOWNLOAD_DIR, VECTORSTORE_DIR):\n", + " path.mkdir(parents=True, exist_ok=True)\n", + "\n", + "FILE_TYPE_OPTIONS = {\n", + " 'txt': {\n", + " 'label': '.txt - Plain text',\n", + " 'extensions': ['.txt'],\n", + " 'mime_types': ['text/plain'],\n", + " },\n", + " 'md': {\n", + " 'label': '.md - Markdown',\n", + " 'extensions': ['.md'],\n", + " 'mime_types': ['text/markdown', 'text/plain'],\n", + " },\n", + " 'docx': {\n", + " 'label': '.docx - Word (OpenXML)',\n", + " 'extensions': ['.docx'],\n", + " 'mime_types': ['application/vnd.openxmlformats-officedocument.wordprocessingml.document'],\n", + " },\n", + " 'doc': {\n", + " 'label': '.doc - Word 97-2003',\n", + " 'extensions': ['.doc'],\n", + " 'mime_types': ['application/msword', 'application/vnd.ms-word.document.macroenabled.12'],\n", + " },\n", + " 'gdoc': {\n", + " 'label': 'Google Docs (exported)',\n", + " 'extensions': ['.docx'],\n", + " 'mime_types': ['application/vnd.google-apps.document'],\n", + " },\n", + "}\n", + "\n", + "FILE_TYPE_LABEL_TO_KEY = {config['label']: key for key, config in FILE_TYPE_OPTIONS.items()}\n", + "DEFAULT_FILE_TYPE_KEYS = ['txt', 'md', 'docx', 'doc', 'gdoc']\n", + "DEFAULT_FILE_TYPE_LABELS = [FILE_TYPE_OPTIONS[key]['label'] for key in DEFAULT_FILE_TYPE_KEYS]\n", + "\n", + "MIME_TYPE_TO_EXTENSION = {}\n", + "for key, config in FILE_TYPE_OPTIONS.items():\n", + " extension = config['extensions'][0]\n", + " for mime in config['mime_types']:\n", + " MIME_TYPE_TO_EXTENSION[mime] = extension\n", + "\n", + "GOOGLE_EXPORT_FORMATS = {\n", + " 'application/vnd.google-apps.document': (\n", + " 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',\n", + " '.docx'\n", + " ),\n", + "}\n", + "\n", + "SIMILARITY_DISTANCE_MAX = float(os.getenv('DRIVE_SAGE_DISTANCE_MAX', '1.2'))\n", + "MAX_CONTEXT_SNIPPET_CHARS = 1200\n", + "HASH_BLOCK_SIZE = 65536\n", + "EMBED_MODEL = os.getenv('DRIVE_SAGE_EMBED_MODEL', 'nomic-embed-text')\n", + "CHAT_MODEL = os.getenv('DRIVE_SAGE_CHAT_MODEL', 'llama3.1:latest')\n", + "\n", + "CUSTOM_CSS = \"\"\"\n", + "#chat-column {\n", + " height: 80vh;\n", + "}\n", + "#chat-column > div {\n", + " height: 100%;\n", + "}\n", + "#chat-column .gradio-chatbot,\n", + "#chat-column .gradio-chat-interface,\n", + "#chat-column .gradio-chatinterface {\n", + " height: 100%;\n", + "}\n", + "#chat-output {\n", + " height: 100%;\n", + "}\n", + "#chat-output .overflow-y-auto {\n", + " max-height: 100% !important;\n", + "}\n", + "#chat-output .h-full {\n", + " height: 100% !important;\n", + "}\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "225a921a", + "metadata": {}, + "outputs": [], + "source": [ + "def build_drive_service():\n", + " creds = None\n", + " if TOKEN_PATH.exists():\n", + " try:\n", + " creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)\n", + " except Exception as exc:\n", + " logger.warning('Failed to load cached credentials: %s', exc)\n", + " TOKEN_PATH.unlink(missing_ok=True)\n", + " 
creds = None\n", + "\n", + " if not creds or not creds.valid:\n", + " if creds and creds.expired and creds.refresh_token:\n", + " try:\n", + " creds.refresh(Request())\n", + " except Exception as exc:\n", + " logger.warning('Refreshing credentials failed: %s', exc)\n", + " creds = None\n", + "\n", + " if not creds or not creds.valid:\n", + " if not CLIENT_SECRET_FILE.exists():\n", + " raise FileNotFoundError(\n", + " 'client_secret.json not found. Download it from Google Cloud Console and place it next to this notebook.'\n", + " )\n", + " flow = InstalledAppFlow.from_client_secrets_file(str(CLIENT_SECRET_FILE), SCOPES)\n", + " creds = flow.run_local_server(port=0)\n", + "\n", + " with TOKEN_PATH.open('w', encoding='utf-8') as token_file:\n", + " token_file.write(creds.to_json())\n", + " \n", + " return build('drive', 'v3', credentials=creds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0acb8ec", + "metadata": {}, + "outputs": [], + "source": [ + "def load_manifest() -> dict:\n", + " if MANIFEST_PATH.exists():\n", + " try:\n", + " with MANIFEST_PATH.open('r', encoding='utf-8') as fp:\n", + " raw = json.load(fp)\n", + " if isinstance(raw, dict):\n", + " normalized: dict[str, dict] = {}\n", + " for file_id, entry in raw.items():\n", + " if isinstance(entry, dict):\n", + " normalized[file_id] = entry\n", + " else:\n", + " normalized[file_id] = {'modified': str(entry)}\n", + " return normalized\n", + " except json.JSONDecodeError:\n", + " logger.warning('Manifest file is corrupted; resetting cache.')\n", + " return {}\n", + "\n", + "def save_manifest(manifest: dict) -> None:\n", + " with MANIFEST_PATH.open('w', encoding='utf-8') as fp:\n", + " json.dump(manifest, fp, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43098d19", + "metadata": {}, + "outputs": [], + "source": [ + "class Metadata(StrEnum):\n", + " ID = 'id'\n", + " SOURCE = 'source'\n", + " PARENT_ID = 'parent_id'\n", + " FILE_TYPE = 'file_type'\n", + " TITLE = 'title'\n", + " MODIFIED = 'modified'\n", + "\n", + "def metadata_key(key: Metadata) -> str:\n", + " return key.value\n", + "\n", + "embeddings = OllamaEmbeddings(model=EMBED_MODEL)\n", + "\n", + "try:\n", + " vectorstore = Chroma(\n", + " collection_name='drive_sage',\n", + " embedding_function=embeddings,\n", + " )\n", + "except Exception as exc:\n", + " logger.exception('Failed to initialize in-memory Chroma vector store')\n", + " raise RuntimeError('Unable to initialize Chroma vector store without persistence.') from exc\n", + "\n", + "docstore = InMemoryStore()\n", + "model = ChatOllama(model=CHAT_MODEL)\n", + "\n", + "DEFAULT_TEXT_SPLITTER = RecursiveCharacterTextSplitter(\n", + " chunk_size=1000,\n", + " chunk_overlap=150,\n", + " separators=['\\n\\n', '\\n', ' ', '']\n", + ")\n", + "MARKDOWN_HEADERS = [('#', 'Header 1'), ('##', 'Header 2'), ('###', 'Header 3')]\n", + "MARKDOWN_SPLITTER = MarkdownHeaderTextSplitter(headers_to_split_on=MARKDOWN_HEADERS, strip_headers=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "116be5f6", + "metadata": {}, + "outputs": [], + "source": [ + "def safe_filename(name: str, max_length: int = 120) -> str:\n", + " sanitized = re.sub(r'[^A-Za-z0-9._-]', '_', name)\n", + " sanitized = sanitized.strip('._') or 'untitled'\n", + " return sanitized[:max_length]\n", + "\n", + "def determine_extension(metadata: dict) -> str:\n", + " mime_type = metadata.get('mimeType', '')\n", + " name = metadata.get('name')\n", + " if name and Path(name).suffix:\n", + " 
return Path(name).suffix.lower()\n", + " if mime_type in GOOGLE_EXPORT_FORMATS:\n", + " return GOOGLE_EXPORT_FORMATS[mime_type][1]\n", + " return MIME_TYPE_TO_EXTENSION.get(mime_type, '.txt')\n", + "\n", + "def cached_file_path(metadata: dict) -> Path:\n", + " file_id = metadata.get('id', 'unknown')\n", + " extension = determine_extension(metadata)\n", + " safe_name = safe_filename(Path(metadata.get('name', file_id)).stem)\n", + " return DOWNLOAD_DIR / f'{safe_name}_{file_id}{extension}'\n", + "\n", + "def hash_file(path: Path) -> str:\n", + " digest = hashlib.sha1()\n", + " with path.open('rb') as fh:\n", + " while True:\n", + " block = fh.read(HASH_BLOCK_SIZE)\n", + " if not block:\n", + " break\n", + " digest.update(block)\n", + " return digest.hexdigest()\n", + "\n", + "def manifest_version(entry: dict | str | None) -> Optional[str]:\n", + " if entry is None:\n", + " return None\n", + " if isinstance(entry, str):\n", + " return entry\n", + " if isinstance(entry, dict):\n", + " return entry.get('modified')\n", + " return None\n", + "\n", + "def update_manifest_entry(manifest: dict, *, file_id: str, modified: str, path: Path, mime_type: str, name: str) -> None:\n", + " manifest[file_id] = {\n", + " 'modified': modified,\n", + " 'path': str(path),\n", + " 'mimeType': mime_type,\n", + " 'name': name,\n", + " 'file_type': Path(path).suffix.lower(),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5fe85b9", + "metadata": {}, + "outputs": [], + "source": [ + "def list_drive_text_files(service, folder_id: Optional[str], allowed_mime_types: list[str], limit: Optional[int]) -> list[dict]:\n", + " query_parts = [\"trashed = false\"]\n", + " mime_types = allowed_mime_types or list(MIME_TYPE_TO_EXTENSION.keys())\n", + " mime_clause = ' or '.join([f\"mimeType = '{mime}'\" for mime in mime_types])\n", + " query_parts.append(f'({mime_clause})')\n", + " if folder_id:\n", + " query_parts.append(f\"'{folder_id}' in parents\")\n", + " query = ' and '.join(query_parts)\n", + "\n", + " files: list[dict] = []\n", + " page_token: Optional[str] = None\n", + "\n", + " while True:\n", + " page_size = min(100, limit - len(files)) if limit else 100\n", + " if page_size <= 0:\n", + " break\n", + " try:\n", + " response = service.files().list(\n", + " q=query,\n", + " spaces='drive',\n", + " fields='nextPageToken, files(id, name, mimeType, modifiedTime)',\n", + " orderBy='modifiedTime desc',\n", + " pageToken=page_token,\n", + " pageSize=page_size,\n", + " ).execute()\n", + " except HttpError as exc:\n", + " raise RuntimeError(f'Google Drive API error: {exc}') from exc\n", + "\n", + " batch = response.get('files', [])\n", + " files.extend(batch)\n", + " if limit and len(files) >= limit:\n", + " return files[:limit]\n", + " page_token = response.get('nextPageToken')\n", + " if not page_token:\n", + " break\n", + " return files\n", + "\n", + "def download_drive_file(service, metadata: dict, manifest: dict) -> Path:\n", + " file_id = metadata['id']\n", + " mime_type = metadata.get('mimeType', '')\n", + " cache_path = cached_file_path(metadata)\n", + " export_mime = None\n", + " if mime_type in GOOGLE_EXPORT_FORMATS:\n", + " export_mime, extension = GOOGLE_EXPORT_FORMATS[mime_type]\n", + " if cache_path.suffix.lower() != extension:\n", + " cache_path = cache_path.with_suffix(extension)\n", + "\n", + "\n", + " request = (\n", + " service.files().export_media(fileId=file_id, mimeType=export_mime)\n", + " if export_mime\n", + " else service.files().get_media(fileId=file_id)\n", + " )\n", 
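+ " # Either request type is consumed the same way below: MediaIoBaseDownload streams it to disk in chunks.\n",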
+ "\n", + " logger.debug('Downloading %s (%s) -> %s', metadata.get('name', file_id), file_id, cache_path)\n", + " with cache_path.open('wb') as fh:\n", + " downloader = MediaIoBaseDownload(fh, request)\n", + " done = False\n", + " while not done:\n", + " status, done = downloader.next_chunk()\n", + " if status:\n", + " logger.debug('Download progress %.0f%%', status.progress() * 100)\n", + "\n", + " update_manifest_entry(\n", + " manifest,\n", + " file_id=file_id,\n", + " modified=metadata.get('modifiedTime', ''),\n", + " path=cache_path,\n", + " mime_type=mime_type,\n", + " name=metadata.get('name', cache_path.name),\n", + " )\n", + " return cache_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27f50b9b", + "metadata": {}, + "outputs": [], + "source": [ + "def extract_docx_text(path: Path) -> str:\n", + " doc = DocxDocument(str(path))\n", + " lines = [paragraph.text.strip() for paragraph in doc.paragraphs if paragraph.text.strip()]\n", + " return '\\n'.join(lines)\n", + "\n", + "def load_documents(\n", + " path: Path,\n", + " *,\n", + " source_id: Optional[str] = None,\n", + " file_type: Optional[str] = None,\n", + " modified: Optional[str] = None,\n", + " display_name: Optional[str] = None,\n", + " ) -> list[Document]:\n", + " suffix = (file_type or path.suffix or '.txt').lower()\n", + " try:\n", + " if suffix in {'.txt', '.md'}:\n", + " loader = TextLoader(str(path), encoding='utf-8')\n", + " documents = loader.load()\n", + " elif suffix == '.docx':\n", + " documents = [Document(page_content=extract_docx_text(path), metadata={'source': str(path)})]\n", + " else:\n", + " raise ValueError(f'Unsupported file type: {suffix}')\n", + " except UnicodeDecodeError as exc:\n", + " raise ValueError(f'Failed to read {path}: {exc}') from exc\n", + "\n", + " base_metadata = {\n", + " metadata_key(Metadata.SOURCE): str(path),\n", + " metadata_key(Metadata.FILE_TYPE): suffix,\n", + " metadata_key(Metadata.TITLE): display_name or path.name,\n", + " }\n", + " if source_id:\n", + " base_metadata[metadata_key(Metadata.ID)] = source_id\n", + " if modified:\n", + " base_metadata[metadata_key(Metadata.MODIFIED)] = modified\n", + "\n", + " cleaned: list[Document] = []\n", + " for doc in documents:\n", + " content = doc.page_content.strip()\n", + " if not content:\n", + " continue\n", + " merged_metadata = {**doc.metadata, **base_metadata}\n", + " doc.page_content = content\n", + " doc.metadata = merged_metadata\n", + " cleaned.append(doc)\n", + " return cleaned\n", + "\n", + "def preprocess(documents: Iterable[Document]) -> list[Document]:\n", + " return [doc for doc in documents if doc.page_content]\n", + "\n", + "def chunk_documents(doc: Document) -> list[Document]:\n", + " parent_id = doc.metadata.get(metadata_key(Metadata.ID))\n", + " if not parent_id:\n", + " raise ValueError('Document is missing a stable identifier for chunking.')\n", + "\n", + " if doc.metadata.get(metadata_key(Metadata.FILE_TYPE)) == '.md':\n", + " markdown_docs = MARKDOWN_SPLITTER.split_text(doc.page_content)\n", + " seed_docs = [\n", + " Document(page_content=section.page_content, metadata={**doc.metadata, **section.metadata})\n", + " for section in markdown_docs\n", + " ]\n", + " else:\n", + " seed_docs = [doc]\n", + "\n", + " chunks = DEFAULT_TEXT_SPLITTER.split_documents(seed_docs)\n", + " for idx, chunk in enumerate(chunks):\n", + " chunk.metadata[metadata_key(Metadata.PARENT_ID)] = parent_id\n", + " chunk.metadata[metadata_key(Metadata.ID)] = f'{parent_id}::chunk-{idx:04d}'\n", + " 
chunk.metadata.setdefault(metadata_key(Metadata.SOURCE), doc.metadata.get(metadata_key(Metadata.SOURCE)))\n", + " chunk.metadata.setdefault(metadata_key(Metadata.TITLE), doc.metadata.get(metadata_key(Metadata.TITLE)))\n", + " return chunks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f135f35", + "metadata": {}, + "outputs": [], + "source": [ + "def sync_drive_and_index(folder_id=None, selected_types=None, file_limit=None, _state: bool = False, progress=gr.Progress(track_tqdm=False)):\n", + " folder = (folder_id or '').strip() or None\n", + "\n", + " selections = selected_types if selected_types is not None else DEFAULT_FILE_TYPE_LABELS\n", + " if not isinstance(selections, (list, tuple)):\n", + " selections = [selections]\n", + " selections = list(selections)\n", + "\n", + " if len(selections) == 0:\n", + " yield 'Select at least one file type before syncing.', False\n", + " return\n", + "\n", + " chosen_keys: list[str] = []\n", + " for item in selections:\n", + " key = FILE_TYPE_LABEL_TO_KEY.get(item, item)\n", + " if key in FILE_TYPE_OPTIONS:\n", + " chosen_keys.append(key)\n", + "\n", + " if not chosen_keys:\n", + " yield 'Select at least one file type before syncing.', False\n", + " return\n", + "\n", + " allowed_mime_types = sorted({mime for key in chosen_keys for mime in FILE_TYPE_OPTIONS[key]['mime_types']})\n", + "\n", + " limit: Optional[int] = None\n", + " limit_warning: Optional[str] = None\n", + " if file_limit not in (None, '', 0):\n", + " try:\n", + " parsed_limit = int(file_limit)\n", + " if parsed_limit > 0:\n", + " limit = parsed_limit\n", + " else:\n", + " raise ValueError\n", + " except (TypeError, ValueError):\n", + " limit_warning = 'File limit must be a positive integer. Syncing all matching files instead.'\n", + "\n", + " log_lines: list[str] = []\n", + "\n", + " def push(message: str) -> str:\n", + " log_lines.append(message)\n", + " return '\\n'.join(log_lines)\n", + "\n", + " if limit_warning:\n", + " logger.warning(limit_warning)\n", + " yield push(limit_warning), False\n", + "\n", + " progress(0, 'Authorizing Google Drive access...')\n", + " yield push('Authorizing Google Drive access...'), False\n", + "\n", + " try:\n", + " service = build_drive_service()\n", + " except FileNotFoundError as exc:\n", + " error_msg = f'Error: {exc}'\n", + " logger.error(error_msg)\n", + " yield push(error_msg), False\n", + " return\n", + " except Exception as exc:\n", + " logger.exception('Drive authorization failed')\n", + " error_msg = f'Error authenticating with Google Drive: {exc}'\n", + " yield push(error_msg), False\n", + " return\n", + "\n", + " list_message = 'Listing documents' + (f' (limit {limit})' if limit else '') + '...'\n", + " progress(0, list_message)\n", + " yield push(list_message), False\n", + "\n", + " try:\n", + " files = list_drive_text_files(service, folder, allowed_mime_types, limit)\n", + " except Exception as exc:\n", + " logger.exception('Listing Drive files failed')\n", + " error_msg = f'Error listing Google Drive files: {exc}'\n", + " yield push(error_msg), False\n", + " return\n", + "\n", + " total = len(files)\n", + " if total == 0:\n", + " info = 'No documents matching the selected types were found in Google Drive.'\n", + " yield push(info), True\n", + " return\n", + "\n", + " manifest = load_manifest()\n", + " downloaded_count = 0\n", + "\n", + " for index, metadata in enumerate(files, start=1):\n", + " file_id = metadata['id']\n", + " name = metadata.get('name', file_id)\n", + " remote_version = 
metadata.get('modifiedTime', '')\n", + " manifest_entry = manifest.get(file_id)\n", + " cache_path = cached_file_path(metadata)\n", + " if isinstance(manifest_entry, dict) and manifest_entry.get('path'):\n", + " cache_path = Path(manifest_entry['path'])\n", + " cached_version = manifest_version(manifest_entry)\n", + "\n", + " if cached_version == remote_version and cache_path.exists():\n", + " message = f\"{index}/{total} Skipping cached file: {name} -> {cache_path}\"\n", + " progress(index / total, message)\n", + " yield push(message), False\n", + " continue\n", + "\n", + " download_message = f\"{index}/{total} Downloading {name} -> {cache_path}\"\n", + " progress(max((index - 0.5) / total, 0), download_message)\n", + " yield push(download_message), False\n", + "\n", + " try:\n", + " downloaded_path = download_drive_file(service, metadata, manifest)\n", + " index_message = f\"{index}/{total} Indexing {downloaded_path.name}\"\n", + " progress(index / total, index_message)\n", + " yield push(index_message), False\n", + " index_document(\n", + " downloaded_path,\n", + " source_id=file_id,\n", + " file_type=downloaded_path.suffix,\n", + " modified=remote_version,\n", + " display_name=name,\n", + " manifest=manifest,\n", + " )\n", + " downloaded_count += 1\n", + " except Exception as exc:\n", + " error_message = f\"{index}/{total} Failed to sync {name}: {exc}\"\n", + " logger.exception(error_message)\n", + " progress(index / total, error_message)\n", + " yield push(error_message), False\n", + "\n", + " if downloaded_count > 0:\n", + " save_manifest(manifest)\n", + " summary = f'Indexed {downloaded_count} new document(s) from Google Drive.'\n", + " else:\n", + " summary = 'Google Drive is already in sync.'\n", + "\n", + " progress(1, summary)\n", + " yield push(summary), True" + ] + }, + { + "cell_type": "markdown", + "id": "0e2f176b", + "metadata": {}, + "source": [ + "## RAG Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20ad0e80", + "metadata": {}, + "outputs": [], + "source": [ + "def persist_vectorstore(_store) -> None:\n", + " \"\"\"In-memory mode: Chroma client does not persist between sessions.\"\"\"\n", + " return\n", + "\n", + "\n", + "def index_document(\n", + " file_path: Path | str,\n", + " *,\n", + " source_id: Optional[str] = None,\n", + " file_type: Optional[str] = None,\n", + " modified: Optional[str] = None,\n", + " display_name: Optional[str] = None,\n", + " manifest: Optional[dict] = None,\n", + " ) -> tuple[str, int]:\n", + " path = Path(file_path)\n", + " path = path.expanduser().resolve()\n", + " resolved_id = source_id or f'local::{hash_file(path)}'\n", + " documents = load_documents(\n", + " path,\n", + " source_id=resolved_id,\n", + " file_type=file_type,\n", + " modified=modified,\n", + " display_name=display_name,\n", + " )\n", + " documents = preprocess(documents)\n", + " if not documents:\n", + " logger.warning('No readable content found in %s; skipping.', path)\n", + " return resolved_id, 0\n", + "\n", + " total_chunks = 0\n", + " for doc in documents:\n", + " doc_id = doc.metadata.get(metadata_key(Metadata.ID), resolved_id)\n", + " doc.metadata[metadata_key(Metadata.ID)] = doc_id\n", + " vectorstore.delete(where={metadata_key(Metadata.PARENT_ID): doc_id})\n", + " chunks = chunk_documents(doc)\n", + " if not chunks:\n", + " continue\n", + " vectorstore.add_documents(chunks)\n", + " docstore.mset([(doc_id, doc)])\n", + " total_chunks += len(chunks)\n", + "\n", + " persist_vectorstore(vectorstore)\n", + " if manifest is not None 
and not source_id:\n", + " update_manifest_entry(\n", + " manifest,\n", + " file_id=resolved_id,\n", + " modified=hash_file(path),\n", + " path=path,\n", + " mime_type=file_type or Path(path).suffix or '.txt',\n", + " name=display_name or path.name,\n", + " )\n", + " return resolved_id, total_chunks" ] },
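{ "cell_type": "markdown", "id": "b4f7a2c9", "metadata": {}, "source": [ "Before wiring up the LLM, a quick end-to-end check of the indexing path. This is a minimal sketch: the sample file name is made up, and it assumes the Ollama embedding model configured above is available so `index_document` can embed the chunks." ] }, { "cell_type": "code", "execution_count": null, "id": "d9c3e6f1", "metadata": {}, "outputs": [], "source": [ "# Optional smoke test: index a tiny throwaway file, then query the vector store directly.\n", + "sample_path = DOWNLOAD_DIR / 'smoke_test.txt'\n", + "sample_path.write_text('Drive Sage answers questions about files synced from Google Drive.', encoding='utf-8')\n", + "sample_id, sample_chunks = index_document(sample_path, display_name='smoke_test.txt')\n", + "logger.info('Smoke test indexed %s as %s (%s chunk(s))', sample_path.name, sample_id, sample_chunks)\n", + "for doc, distance in vectorstore.similarity_search_with_score('What does Drive Sage do?', k=1):\n", + " logger.info('distance=%.3f source=%s', distance, doc.metadata.get(metadata_key(Metadata.SOURCE)))" ] },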
{ "cell_type": "markdown", "id": "a90db6ee", "metadata": {}, "source": [ "### LLM Interaction" ] }, { "cell_type": "code", "execution_count": null, "id": "e2e15e99", "metadata": {}, "outputs": [], "source": [ "def retrieve_context(query: str, *, top_k: int = 8, distance_threshold: Optional[float] = SIMILARITY_DISTANCE_MAX):\n", + " results_with_scores = vectorstore.similarity_search_with_score(query, k=top_k)\n", + " logger.info(f'Matching records: {len(results_with_scores)}')\n", + "\n", + " filtered: list[tuple[Document, float]] = []\n", + " for doc, score in results_with_scores:\n", + " if score is None:\n", + " continue\n", + " score_value = float(score)\n", + " logger.debug('Retrieved %s (distance %.4f)', doc.metadata.get(metadata_key(Metadata.SOURCE)), score_value)\n", + " # Scores are distances (lower = more similar), so drop anything above the cutoff.\n", + " if distance_threshold is not None and score_value > distance_threshold:\n", + " logger.debug(\n", + " 'Skipping %s with distance %.4f (above threshold %.4f)',\n", + " doc.metadata.get(metadata_key(Metadata.SOURCE)),\n", + " score_value,\n", + " distance_threshold,\n", + " )\n", + " continue\n", + " filtered.append((doc, score_value))\n", + "\n", + " if not filtered:\n", + " return []\n", + "\n", + " for doc, score_value in filtered:\n", + " parent_id = doc.metadata.get(metadata_key(Metadata.PARENT_ID))\n", + " if parent_id:\n", + " parent_doc = docstore.mget([parent_id])[0]\n", + " if parent_doc and parent_doc.page_content:\n", + " logger.debug(\n", + " 'Parent preview (%s | %.3f): %s',\n", + " doc.metadata.get(metadata_key(Metadata.SOURCE), 'unknown'),\n", + " score_value,\n", + " parent_doc.page_content[:400].replace('\\n', ' '),\n", + " )\n", + "\n", + " return filtered\n", + "\n", + "\n", + "def build_prompt_sections(relevant_docs: list[tuple[Document, float]]) -> str:\n", + " sections: list[str] = []\n", + " for idx, (doc, score) in enumerate(relevant_docs, start=1):\n", + " source = doc.metadata.get(metadata_key(Metadata.SOURCE), 'unknown')\n", + " snippet = doc.page_content.strip()[:MAX_CONTEXT_SNIPPET_CHARS]\n", + " section = (\n", + " f'[{idx}] Source: {source}\\n'\n", + " f'Distance: {score:.3f}\\n'\n", + " f'Content:\\n{snippet}'\n", + " )\n", + " sections.append(section)\n", + " return '\\n\\n'.join(sections)\n", + "\n", + "\n", + "def ask(message, history):\n", + " relevant_docs = retrieve_context(message)\n", + " if not relevant_docs:\n", + " yield \"I don't have enough information in the synced documents to answer that yet. Please sync additional files or adjust the filters.\"\n", + " return\n", + "\n", + " context = build_prompt_sections(relevant_docs)\n", + " prompt = f'''\n", + " You are a retrieval-augmented assistant. Use ONLY the facts provided in the context to answer the user.\n", + " If the context does not contain the answer, reply exactly: \"I don't have enough information in the synced documents to answer that yet. Please sync additional files.\"\n", + " \n", + " Context:\\n{context}\n", + " '''\n", + "\n", + " messages = [\n", + " ('system', prompt),\n", + " ('user', message)\n", + " ]\n", + "\n", + " stream = model.stream(messages)\n", + " response_text = ''\n", + "\n", + " for chunk in stream:\n", + " response_text += chunk.content or ''\n", + " if not response_text:\n", + " continue\n", + "\n", + " yield response_text" ] }, { "cell_type": "markdown", "id": "c3e632dc-9e87-4510-9fcd-aa699c27e82b", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## Gradio UI" ] }, { "cell_type": "code", "execution_count": null, "id": "a3d68a74", "metadata": {}, "outputs": [], "source": [ "def chat(message, history, sync_ready):\n", + " if message is None:\n", + " return\n", + "\n", + " text_input = message.get('text', '')\n", + " files_uploaded = message.get('files', [])\n", + " latest_file_path = Path(files_uploaded[-1]) if files_uploaded else None\n", + " if latest_file_path:\n", + " manifest = load_manifest()\n", + " doc_id, chunk_count = index_document(\n", + " latest_file_path,\n", + " file_type=latest_file_path.suffix,\n", + " display_name=latest_file_path.name,\n", + " manifest=manifest,\n", + " )\n", + " save_manifest(manifest)\n", + " logger.info('Indexed upload %s as %s with %s chunk(s)', latest_file_path, doc_id, chunk_count)\n", + " if not text_input:\n", + " yield f'Indexed document from upload ({chunk_count} chunk(s)).'\n", + " return\n", + "\n", + " if not text_input:\n", + " return\n", + "\n", + " if not sync_ready and not files_uploaded:\n", + " yield 'Sync Google Drive before chatting or upload a document first.'\n", + " return\n", + "\n", + " for chunk in ask(text_input, history):\n", + " yield chunk\n", + "\n", + "title = \"Drive Sage\"\n", + "with gr.Blocks(title=title, fill_height=True, css=CUSTOM_CSS) as ui:\n", + " gr.Markdown(f'# {title}')\n", + " gr.Markdown('## Search your Google Drive knowledge base with fully local processing.')\n", + " sync_state = gr.State(False)\n", + "\n", + " with gr.Row():\n", + " with gr.Column(scale=3, elem_id='chat-column'):\n", + " gr.ChatInterface(\n", + " fn=chat,\n", + " multimodal=True,\n", + " chatbot=gr.Chatbot(height='80vh', elem_id='chat-output'),\n", + " type='messages',\n", + " textbox=gr.MultimodalTextbox(\n", + " file_types=['text', '.txt', '.md', '.docx'],\n", + " autofocus=True,\n", + " elem_id='chat-input',\n", + " ),\n", + " additional_inputs=[sync_state],\n", + " )\n", + " with gr.Column(scale=2, min_width=320):\n", + " gr.Markdown('### Google Drive Sync')\n", + " drive_folder = gr.Textbox(\n", + " label='Folder ID (optional)',\n", + " placeholder='Leave blank to scan My Drive root',\n", + " )\n", + " file_types = gr.CheckboxGroup(\n", + " label='File types to sync',\n", + " choices=[config['label'] for config in FILE_TYPE_OPTIONS.values()],\n", + " value=DEFAULT_FILE_TYPE_LABELS,\n", + " )\n", + " file_limit = gr.Number(\n", + " label='Max files to sync (leave blank for all)',\n", + " value=20,\n", + " )\n", + " sync_btn = gr.Button('Sync Google Drive')\n", + " sync_status = gr.Markdown('No sync performed yet.')\n", + "\n", + " sync_btn.click(\n", + " sync_drive_and_index,\n", + " inputs=[drive_folder, file_types, file_limit, sync_state],\n", + " outputs=[sync_status, sync_state],\n", + " )\n", + "\n", + "ui.launch(debug=True)" ] } ], "metadata": { "kernelspec": { "display_name": "env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython",
"version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}