W5

community-contributions/abdoul/week_five_exercise.ipynb (new file, 922 lines)

@@ -0,0 +1,922 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e2c01a17",
   "metadata": {},
   "source": [
    "### Search your Google Drive knowledge base with fully local processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7609cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.3\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "%pip install -qU \"langchain==0.3.27\" \"langchain-core<1.0.0,>=0.3.78\" \"langchain-text-splitters<1.0.0,>=0.3.9\" langchain_ollama langchain_chroma langchain_community google-auth google-auth-oauthlib google-auth-httplib2 google-api-python-client python-docx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144bdf7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "import hashlib\n",
    "from pathlib import Path\n",
    "from enum import StrEnum\n",
    "from typing import Iterable, Optional\n",
    "\n",
    "import gradio as gr\n",
    "from langchain_core.documents import Document\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter\n",
    "from langchain_ollama import OllamaEmbeddings, ChatOllama\n",
    "from langchain.storage import InMemoryStore\n",
    "from langchain_chroma import Chroma\n",
    "from langchain_community.document_loaders import TextLoader\n",
    "from google.oauth2.credentials import Credentials\n",
    "from google_auth_oauthlib.flow import InstalledAppFlow\n",
    "from google.auth.transport.requests import Request\n",
    "from googleapiclient.discovery import build\n",
    "from googleapiclient.http import MediaIoBaseDownload\n",
    "from googleapiclient.errors import HttpError\n",
    "from docx import Document as DocxDocument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfdb143d",
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = logging.getLogger('drive_sage')\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "if not logger.handlers:\n",
    "    handler = logging.StreamHandler(sys.stdout)\n",
    "    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "    handler.setFormatter(formatter)\n",
    "    logger.addHandler(handler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41df43aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "SCOPES = ['https://www.googleapis.com/auth/drive.readonly']\n",
    "APP_ROOT = Path.cwd()\n",
    "DATA_DIR = APP_ROOT / '.drive_sage'\n",
    "DOWNLOAD_DIR = DATA_DIR / 'downloads'\n",
    "VECTORSTORE_DIR = DATA_DIR / 'chroma'\n",
    "TOKEN_PATH = DATA_DIR / 'token.json'\n",
    "MANIFEST_PATH = DATA_DIR / 'manifest.json'\n",
    "CLIENT_SECRET_FILE = APP_ROOT / 'client_secret_202216035337-4qson0c08g71u8uuihv6v46arv64nhvg.apps.googleusercontent.com.json'\n",
    "\n",
    "for path in (DATA_DIR, DOWNLOAD_DIR, VECTORSTORE_DIR):\n",
    "    path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "FILE_TYPE_OPTIONS = {\n",
    "    'txt': {\n",
    "        'label': '.txt - Plain text',\n",
    "        'extensions': ['.txt'],\n",
    "        'mime_types': ['text/plain'],\n",
    "    },\n",
    "    'md': {\n",
    "        'label': '.md - Markdown',\n",
    "        'extensions': ['.md'],\n",
    "        'mime_types': ['text/markdown', 'text/plain'],\n",
    "    },\n",
    "    'docx': {\n",
    "        'label': '.docx - Word (OpenXML)',\n",
    "        'extensions': ['.docx'],\n",
    "        'mime_types': ['application/vnd.openxmlformats-officedocument.wordprocessingml.document'],\n",
    "    },\n",
    "    'doc': {\n",
    "        'label': '.doc - Word 97-2003',\n",
    "        'extensions': ['.doc'],\n",
    "        'mime_types': ['application/msword', 'application/vnd.ms-word.document.macroenabled.12'],\n",
    "    },\n",
    "    'gdoc': {\n",
    "        'label': 'Google Docs (exported)',\n",
    "        'extensions': ['.docx'],\n",
    "        'mime_types': ['application/vnd.google-apps.document'],\n",
    "    },\n",
    "}\n",
    "\n",
    "FILE_TYPE_LABEL_TO_KEY = {config['label']: key for key, config in FILE_TYPE_OPTIONS.items()}\n",
    "DEFAULT_FILE_TYPE_KEYS = ['txt', 'md', 'docx', 'doc', 'gdoc']\n",
    "DEFAULT_FILE_TYPE_LABELS = [FILE_TYPE_OPTIONS[key]['label'] for key in DEFAULT_FILE_TYPE_KEYS]\n",
    "\n",
    "MIME_TYPE_TO_EXTENSION = {}\n",
    "for key, config in FILE_TYPE_OPTIONS.items():\n",
    "    extension = config['extensions'][0]\n",
    "    for mime in config['mime_types']:\n",
    "        MIME_TYPE_TO_EXTENSION[mime] = extension\n",
    "\n",
    "GOOGLE_EXPORT_FORMATS = {\n",
    "    'application/vnd.google-apps.document': (\n",
    "        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',\n",
    "        '.docx'\n",
    "    ),\n",
    "}\n",
    "\n",
    "SIMILARITY_DISTANCE_MAX = float(os.getenv('DRIVE_SAGE_DISTANCE_MAX', '1.2'))\n",
    "MAX_CONTEXT_SNIPPET_CHARS = 1200\n",
    "HASH_BLOCK_SIZE = 65536\n",
    "EMBED_MODEL = os.getenv('DRIVE_SAGE_EMBED_MODEL', 'nomic-embed-text')\n",
    "CHAT_MODEL = os.getenv('DRIVE_SAGE_CHAT_MODEL', 'llama3.1:latest')\n",
    "\n",
    "CUSTOM_CSS = \"\"\"\n",
    "#chat-column {\n",
    "    height: 80vh;\n",
    "}\n",
    "#chat-column > div {\n",
    "    height: 100%;\n",
    "}\n",
    "#chat-column .gradio-chatbot,\n",
    "#chat-column .gradio-chat-interface,\n",
    "#chat-column .gradio-chatinterface {\n",
    "    height: 100%;\n",
    "}\n",
    "#chat-output {\n",
    "    height: 100%;\n",
    "}\n",
    "#chat-output .overflow-y-auto {\n",
    "    max-height: 100% !important;\n",
    "}\n",
    "#chat-output .h-full {\n",
    "    height: 100% !important;\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "225a921a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_drive_service():\n",
    "    creds = None\n",
    "    if TOKEN_PATH.exists():\n",
    "        try:\n",
    "            creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)\n",
    "        except Exception as exc:\n",
    "            logger.warning('Failed to load cached credentials: %s', exc)\n",
    "            TOKEN_PATH.unlink(missing_ok=True)\n",
    "            creds = None\n",
    "\n",
    "    if not creds or not creds.valid:\n",
    "        if creds and creds.expired and creds.refresh_token:\n",
    "            try:\n",
    "                creds.refresh(Request())\n",
    "            except Exception as exc:\n",
    "                logger.warning('Refreshing credentials failed: %s', exc)\n",
    "                creds = None\n",
    "\n",
    "    if not creds or not creds.valid:\n",
    "        if not CLIENT_SECRET_FILE.exists():\n",
    "            raise FileNotFoundError(\n",
    "                f'{CLIENT_SECRET_FILE.name} not found. Download it from Google Cloud Console and place it next to this notebook.'\n",
    "            )\n",
    "        flow = InstalledAppFlow.from_client_secrets_file(str(CLIENT_SECRET_FILE), SCOPES)\n",
    "        creds = flow.run_local_server(port=0)\n",
    "\n",
    "    with TOKEN_PATH.open('w', encoding='utf-8') as token_file:\n",
    "        token_file.write(creds.to_json())\n",
    "\n",
    "    return build('drive', 'v3', credentials=creds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0acb8ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_manifest() -> dict:\n",
    "    if MANIFEST_PATH.exists():\n",
    "        try:\n",
    "            with MANIFEST_PATH.open('r', encoding='utf-8') as fp:\n",
    "                raw = json.load(fp)\n",
    "                if isinstance(raw, dict):\n",
    "                    normalized: dict[str, dict] = {}\n",
    "                    for file_id, entry in raw.items():\n",
    "                        if isinstance(entry, dict):\n",
    "                            normalized[file_id] = entry\n",
    "                        else:\n",
    "                            normalized[file_id] = {'modified': str(entry)}\n",
    "                    return normalized\n",
    "        except json.JSONDecodeError:\n",
    "            logger.warning('Manifest file is corrupted; resetting cache.')\n",
    "    return {}\n",
    "\n",
    "def save_manifest(manifest: dict) -> None:\n",
    "    with MANIFEST_PATH.open('w', encoding='utf-8') as fp:\n",
    "        json.dump(manifest, fp, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43098d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Metadata(StrEnum):\n",
    "    ID = 'id'\n",
    "    SOURCE = 'source'\n",
    "    PARENT_ID = 'parent_id'\n",
    "    FILE_TYPE = 'file_type'\n",
    "    TITLE = 'title'\n",
    "    MODIFIED = 'modified'\n",
    "\n",
    "def metadata_key(key: Metadata) -> str:\n",
    "    return key.value\n",
    "\n",
    "embeddings = OllamaEmbeddings(model=EMBED_MODEL)\n",
    "\n",
    "try:\n",
    "    vectorstore = Chroma(\n",
    "        collection_name='drive_sage',\n",
    "        embedding_function=embeddings,\n",
    "    )\n",
    "except Exception as exc:\n",
    "    logger.exception('Failed to initialize in-memory Chroma vector store')\n",
    "    raise RuntimeError('Unable to initialize Chroma vector store without persistence.') from exc\n",
    "\n",
    "docstore = InMemoryStore()\n",
    "model = ChatOllama(model=CHAT_MODEL)\n",
    "\n",
    "DEFAULT_TEXT_SPLITTER = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=1000,\n",
    "    chunk_overlap=150,\n",
    "    separators=['\\n\\n', '\\n', ' ', '']\n",
    ")\n",
    "MARKDOWN_HEADERS = [('#', 'Header 1'), ('##', 'Header 2'), ('###', 'Header 3')]\n",
    "MARKDOWN_SPLITTER = MarkdownHeaderTextSplitter(headers_to_split_on=MARKDOWN_HEADERS, strip_headers=False)"
   ]
  },
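  {
   "cell_type": "markdown",
   "id": "3fa1b2c4",
   "metadata": {},
   "source": [
    "Before syncing anything, a quick sanity check of the local model stack can save debugging time. The cell below is a minimal sketch: it assumes an Ollama server is running locally and that the model named in `EMBED_MODEL` has already been pulled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa1b2c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: confirm the local embedding model responds before indexing anything.\n",
    "# Assumes `ollama serve` is running and EMBED_MODEL (e.g. nomic-embed-text) is pulled.\n",
    "sample_vector = embeddings.embed_query('hello drive sage')\n",
    "print(f'Embedding dimension: {len(sample_vector)}')"
   ]
  },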
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116be5f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safe_filename(name: str, max_length: int = 120) -> str:\n",
    "    sanitized = re.sub(r'[^A-Za-z0-9._-]', '_', name)\n",
    "    sanitized = sanitized.strip('._') or 'untitled'\n",
    "    return sanitized[:max_length]\n",
    "\n",
    "def determine_extension(metadata: dict) -> str:\n",
    "    mime_type = metadata.get('mimeType', '')\n",
    "    name = metadata.get('name')\n",
    "    if name and Path(name).suffix:\n",
    "        return Path(name).suffix.lower()\n",
    "    if mime_type in GOOGLE_EXPORT_FORMATS:\n",
    "        return GOOGLE_EXPORT_FORMATS[mime_type][1]\n",
    "    return MIME_TYPE_TO_EXTENSION.get(mime_type, '.txt')\n",
    "\n",
    "def cached_file_path(metadata: dict) -> Path:\n",
    "    file_id = metadata.get('id', 'unknown')\n",
    "    extension = determine_extension(metadata)\n",
    "    safe_name = safe_filename(Path(metadata.get('name', file_id)).stem)\n",
    "    return DOWNLOAD_DIR / f'{safe_name}_{file_id}{extension}'\n",
    "\n",
    "def hash_file(path: Path) -> str:\n",
    "    digest = hashlib.sha1()\n",
    "    with path.open('rb') as fh:\n",
    "        while True:\n",
    "            block = fh.read(HASH_BLOCK_SIZE)\n",
    "            if not block:\n",
    "                break\n",
    "            digest.update(block)\n",
    "    return digest.hexdigest()\n",
    "\n",
    "def manifest_version(entry: dict | str | None) -> Optional[str]:\n",
    "    if entry is None:\n",
    "        return None\n",
    "    if isinstance(entry, str):\n",
    "        return entry\n",
    "    if isinstance(entry, dict):\n",
    "        return entry.get('modified')\n",
    "    return None\n",
    "\n",
    "def update_manifest_entry(manifest: dict, *, file_id: str, modified: str, path: Path, mime_type: str, name: str) -> None:\n",
    "    manifest[file_id] = {\n",
    "        'modified': modified,\n",
    "        'path': str(path),\n",
    "        'mimeType': mime_type,\n",
    "        'name': name,\n",
    "        'file_type': Path(path).suffix.lower(),\n",
    "    }"
   ]
  },
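  {
   "cell_type": "markdown",
   "id": "a2d45e17",
   "metadata": {},
   "source": [
    "A small worked example of the path helpers above, using a made-up Drive metadata record (nothing here touches the network):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2d45e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exercise the filename helpers on a fabricated Drive metadata record.\n",
    "demo_meta = {\n",
    "    'id': 'abc123',\n",
    "    'name': 'Q3 Report (draft).docx',\n",
    "    'mimeType': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',\n",
    "}\n",
    "print(safe_filename('Q3 Report (draft)'))   # Q3_Report__draft\n",
    "print(determine_extension(demo_meta))       # .docx\n",
    "print(cached_file_path(demo_meta).name)     # Q3_Report__draft_abc123.docx"
   ]
  },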
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fe85b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_drive_text_files(service, folder_id: Optional[str], allowed_mime_types: list[str], limit: Optional[int]) -> list[dict]:\n",
    "    query_parts = [\"trashed = false\"]\n",
    "    mime_types = allowed_mime_types or list(MIME_TYPE_TO_EXTENSION.keys())\n",
    "    mime_clause = ' or '.join([f\"mimeType = '{mime}'\" for mime in mime_types])\n",
    "    query_parts.append(f'({mime_clause})')\n",
    "    if folder_id:\n",
    "        query_parts.append(f\"'{folder_id}' in parents\")\n",
    "    query = ' and '.join(query_parts)\n",
    "\n",
    "    files: list[dict] = []\n",
    "    page_token: Optional[str] = None\n",
    "\n",
    "    while True:\n",
    "        page_size = min(100, limit - len(files)) if limit else 100\n",
    "        if page_size <= 0:\n",
    "            break\n",
    "        try:\n",
    "            response = service.files().list(\n",
    "                q=query,\n",
    "                spaces='drive',\n",
    "                fields='nextPageToken, files(id, name, mimeType, modifiedTime)',\n",
    "                orderBy='modifiedTime desc',\n",
    "                pageToken=page_token,\n",
    "                pageSize=page_size,\n",
    "            ).execute()\n",
    "        except HttpError as exc:\n",
    "            raise RuntimeError(f'Google Drive API error: {exc}') from exc\n",
    "\n",
    "        batch = response.get('files', [])\n",
    "        files.extend(batch)\n",
    "        if limit and len(files) >= limit:\n",
    "            return files[:limit]\n",
    "        page_token = response.get('nextPageToken')\n",
    "        if not page_token:\n",
    "            break\n",
    "    return files\n",
    "\n",
    "def download_drive_file(service, metadata: dict, manifest: dict) -> Path:\n",
    "    file_id = metadata['id']\n",
    "    mime_type = metadata.get('mimeType', '')\n",
    "    cache_path = cached_file_path(metadata)\n",
    "    export_mime = None\n",
    "    if mime_type in GOOGLE_EXPORT_FORMATS:\n",
    "        export_mime, extension = GOOGLE_EXPORT_FORMATS[mime_type]\n",
    "        if cache_path.suffix.lower() != extension:\n",
    "            cache_path = cache_path.with_suffix(extension)\n",
    "\n",
    "    request = (\n",
    "        service.files().export_media(fileId=file_id, mimeType=export_mime)\n",
    "        if export_mime\n",
    "        else service.files().get_media(fileId=file_id)\n",
    "    )\n",
    "\n",
    "    logger.debug('Downloading %s (%s) -> %s', metadata.get('name', file_id), file_id, cache_path)\n",
    "    with cache_path.open('wb') as fh:\n",
    "        downloader = MediaIoBaseDownload(fh, request)\n",
    "        done = False\n",
    "        while not done:\n",
    "            status, done = downloader.next_chunk()\n",
    "            if status:\n",
    "                logger.debug('Download progress %.0f%%', status.progress() * 100)\n",
    "\n",
    "    update_manifest_entry(\n",
    "        manifest,\n",
    "        file_id=file_id,\n",
    "        modified=metadata.get('modifiedTime', ''),\n",
    "        path=cache_path,\n",
    "        mime_type=mime_type,\n",
    "        name=metadata.get('name', cache_path.name),\n",
    "    )\n",
    "    return cache_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f50b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_docx_text(path: Path) -> str:\n",
    "    doc = DocxDocument(str(path))\n",
    "    lines = [paragraph.text.strip() for paragraph in doc.paragraphs if paragraph.text.strip()]\n",
    "    return '\\n'.join(lines)\n",
    "\n",
    "def load_documents(\n",
    "    path: Path,\n",
    "    *,\n",
    "    source_id: Optional[str] = None,\n",
    "    file_type: Optional[str] = None,\n",
    "    modified: Optional[str] = None,\n",
    "    display_name: Optional[str] = None,\n",
    ") -> list[Document]:\n",
    "    suffix = (file_type or path.suffix or '.txt').lower()\n",
    "    try:\n",
    "        if suffix in {'.txt', '.md'}:\n",
    "            loader = TextLoader(str(path), encoding='utf-8')\n",
    "            documents = loader.load()\n",
    "        elif suffix == '.docx':\n",
    "            documents = [Document(page_content=extract_docx_text(path), metadata={'source': str(path)})]\n",
    "        else:\n",
    "            # Note: legacy .doc files can be synced but are not parsed here; they land in this branch.\n",
    "            raise ValueError(f'Unsupported file type: {suffix}')\n",
    "    except UnicodeDecodeError as exc:\n",
    "        raise ValueError(f'Failed to read {path}: {exc}') from exc\n",
    "\n",
    "    base_metadata = {\n",
    "        metadata_key(Metadata.SOURCE): str(path),\n",
    "        metadata_key(Metadata.FILE_TYPE): suffix,\n",
    "        metadata_key(Metadata.TITLE): display_name or path.name,\n",
    "    }\n",
    "    if source_id:\n",
    "        base_metadata[metadata_key(Metadata.ID)] = source_id\n",
    "    if modified:\n",
    "        base_metadata[metadata_key(Metadata.MODIFIED)] = modified\n",
    "\n",
    "    cleaned: list[Document] = []\n",
    "    for doc in documents:\n",
    "        content = doc.page_content.strip()\n",
    "        if not content:\n",
    "            continue\n",
    "        merged_metadata = {**doc.metadata, **base_metadata}\n",
    "        doc.page_content = content\n",
    "        doc.metadata = merged_metadata\n",
    "        cleaned.append(doc)\n",
    "    return cleaned\n",
    "\n",
    "def preprocess(documents: Iterable[Document]) -> list[Document]:\n",
    "    return [doc for doc in documents if doc.page_content]\n",
    "\n",
    "def chunk_documents(doc: Document) -> list[Document]:\n",
    "    parent_id = doc.metadata.get(metadata_key(Metadata.ID))\n",
    "    if not parent_id:\n",
    "        raise ValueError('Document is missing a stable identifier for chunking.')\n",
    "\n",
    "    if doc.metadata.get(metadata_key(Metadata.FILE_TYPE)) == '.md':\n",
    "        markdown_docs = MARKDOWN_SPLITTER.split_text(doc.page_content)\n",
    "        seed_docs = [\n",
    "            Document(page_content=section.page_content, metadata={**doc.metadata, **section.metadata})\n",
    "            for section in markdown_docs\n",
    "        ]\n",
    "    else:\n",
    "        seed_docs = [doc]\n",
    "\n",
    "    chunks = DEFAULT_TEXT_SPLITTER.split_documents(seed_docs)\n",
    "    for idx, chunk in enumerate(chunks):\n",
    "        chunk.metadata[metadata_key(Metadata.PARENT_ID)] = parent_id\n",
    "        chunk.metadata[metadata_key(Metadata.ID)] = f'{parent_id}::chunk-{idx:04d}'\n",
    "        chunk.metadata.setdefault(metadata_key(Metadata.SOURCE), doc.metadata.get(metadata_key(Metadata.SOURCE)))\n",
    "        chunk.metadata.setdefault(metadata_key(Metadata.TITLE), doc.metadata.get(metadata_key(Metadata.TITLE)))\n",
    "    return chunks"
   ]
  },
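  {
   "cell_type": "markdown",
   "id": "c91f0d33",
   "metadata": {},
   "source": [
    "To see the chunking behaviour in isolation, the sketch below splits a small in-memory markdown document; it needs no Drive access and no Ollama models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c91f0d34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chunk an in-memory markdown document; header metadata and parent ids\n",
    "# are attached exactly as they would be for a synced file.\n",
    "demo_doc = Document(\n",
    "    page_content='# Intro\\nDrive Sage indexes documents locally.\\n## Details\\nEach chunk carries its parent id.',\n",
    "    metadata={'id': 'demo-001', 'file_type': '.md', 'source': 'inline-demo', 'title': 'demo.md'},\n",
    ")\n",
    "for chunk in chunk_documents(demo_doc):\n",
    "    print(chunk.metadata['id'], '|', chunk.page_content[:40])"
   ]
  },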
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f135f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sync_drive_and_index(folder_id=None, selected_types=None, file_limit=None, _state: bool = False, progress=gr.Progress(track_tqdm=False)):\n",
    "    folder = (folder_id or '').strip() or None\n",
    "\n",
    "    selections = selected_types if selected_types is not None else DEFAULT_FILE_TYPE_LABELS\n",
    "    if not isinstance(selections, (list, tuple)):\n",
    "        selections = [selections]\n",
    "    selections = list(selections)\n",
    "\n",
    "    if len(selections) == 0:\n",
    "        yield 'Select at least one file type before syncing.', False\n",
    "        return\n",
    "\n",
    "    chosen_keys: list[str] = []\n",
    "    for item in selections:\n",
    "        key = FILE_TYPE_LABEL_TO_KEY.get(item, item)\n",
    "        if key in FILE_TYPE_OPTIONS:\n",
    "            chosen_keys.append(key)\n",
    "\n",
    "    if not chosen_keys:\n",
    "        yield 'Select at least one file type before syncing.', False\n",
    "        return\n",
    "\n",
    "    allowed_mime_types = sorted({mime for key in chosen_keys for mime in FILE_TYPE_OPTIONS[key]['mime_types']})\n",
    "\n",
    "    limit: Optional[int] = None\n",
    "    limit_warning: Optional[str] = None\n",
    "    if file_limit not in (None, '', 0):\n",
    "        try:\n",
    "            parsed_limit = int(file_limit)\n",
    "            if parsed_limit > 0:\n",
    "                limit = parsed_limit\n",
    "            else:\n",
    "                raise ValueError\n",
    "        except (TypeError, ValueError):\n",
    "            limit_warning = 'File limit must be a positive integer. Syncing all matching files instead.'\n",
    "\n",
    "    log_lines: list[str] = []\n",
    "\n",
    "    def push(message: str) -> str:\n",
    "        log_lines.append(message)\n",
    "        return '\\n'.join(log_lines)\n",
    "\n",
    "    if limit_warning:\n",
    "        logger.warning(limit_warning)\n",
    "        yield push(limit_warning), False\n",
    "\n",
    "    progress(0, 'Authorizing Google Drive access...')\n",
    "    yield push('Authorizing Google Drive access...'), False\n",
    "\n",
    "    try:\n",
    "        service = build_drive_service()\n",
    "    except FileNotFoundError as exc:\n",
    "        error_msg = f'Error: {exc}'\n",
    "        logger.error(error_msg)\n",
    "        yield push(error_msg), False\n",
    "        return\n",
    "    except Exception as exc:\n",
    "        logger.exception('Drive authorization failed')\n",
    "        error_msg = f'Error authenticating with Google Drive: {exc}'\n",
    "        yield push(error_msg), False\n",
    "        return\n",
    "\n",
    "    list_message = 'Listing documents' + (f' (limit {limit})' if limit else '') + '...'\n",
    "    progress(0, list_message)\n",
    "    yield push(list_message), False\n",
    "\n",
    "    try:\n",
    "        files = list_drive_text_files(service, folder, allowed_mime_types, limit)\n",
    "    except Exception as exc:\n",
    "        logger.exception('Listing Drive files failed')\n",
    "        error_msg = f'Error listing Google Drive files: {exc}'\n",
    "        yield push(error_msg), False\n",
    "        return\n",
    "\n",
    "    total = len(files)\n",
    "    if total == 0:\n",
    "        info = 'No documents matching the selected types were found in Google Drive.'\n",
    "        yield push(info), True\n",
    "        return\n",
    "\n",
    "    manifest = load_manifest()\n",
    "    downloaded_count = 0\n",
    "\n",
    "    for index, metadata in enumerate(files, start=1):\n",
    "        file_id = metadata['id']\n",
    "        name = metadata.get('name', file_id)\n",
    "        remote_version = metadata.get('modifiedTime', '')\n",
    "        manifest_entry = manifest.get(file_id)\n",
    "        cache_path = cached_file_path(metadata)\n",
    "        if isinstance(manifest_entry, dict) and manifest_entry.get('path'):\n",
    "            cache_path = Path(manifest_entry['path'])\n",
    "        cached_version = manifest_version(manifest_entry)\n",
    "\n",
    "        if cached_version == remote_version and cache_path.exists():\n",
    "            message = f\"{index}/{total} Skipping cached file: {name} -> {cache_path}\"\n",
    "            progress(index / total, message)\n",
    "            yield push(message), False\n",
    "            continue\n",
    "\n",
    "        download_message = f\"{index}/{total} Downloading {name} -> {cache_path}\"\n",
    "        progress(max((index - 0.5) / total, 0), download_message)\n",
    "        yield push(download_message), False\n",
    "\n",
    "        try:\n",
    "            downloaded_path = download_drive_file(service, metadata, manifest)\n",
    "            index_message = f\"{index}/{total} Indexing {downloaded_path.name}\"\n",
    "            progress(index / total, index_message)\n",
    "            yield push(index_message), False\n",
    "            index_document(\n",
    "                downloaded_path,\n",
    "                source_id=file_id,\n",
    "                file_type=downloaded_path.suffix,\n",
    "                modified=remote_version,\n",
    "                display_name=name,\n",
    "                manifest=manifest,\n",
    "            )\n",
    "            downloaded_count += 1\n",
    "        except Exception as exc:\n",
    "            error_message = f\"{index}/{total} Failed to sync {name}: {exc}\"\n",
    "            logger.exception(error_message)\n",
    "            progress(index / total, error_message)\n",
    "            yield push(error_message), False\n",
    "\n",
    "    if downloaded_count > 0:\n",
    "        save_manifest(manifest)\n",
    "        summary = f'Indexed {downloaded_count} new document(s) from Google Drive.'\n",
    "    else:\n",
    "        summary = 'Google Drive is already in sync.'\n",
    "\n",
    "    progress(1, summary)\n",
    "    yield push(summary), True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e2f176b",
   "metadata": {},
   "source": [
    "## RAG Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ad0e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def persist_vectorstore(_store) -> None:\n",
    "    \"\"\"In-memory mode: Chroma client does not persist between sessions.\"\"\"\n",
    "    return\n",
    "\n",
    "\n",
    "def index_document(\n",
    "    file_path: Path | str,\n",
    "    *,\n",
    "    source_id: Optional[str] = None,\n",
    "    file_type: Optional[str] = None,\n",
    "    modified: Optional[str] = None,\n",
    "    display_name: Optional[str] = None,\n",
    "    manifest: Optional[dict] = None,\n",
    ") -> tuple[str, int]:\n",
    "    path = Path(file_path)\n",
    "    path = path.expanduser().resolve()\n",
    "    resolved_id = source_id or f'local::{hash_file(path)}'\n",
    "    documents = load_documents(\n",
    "        path,\n",
    "        source_id=resolved_id,\n",
    "        file_type=file_type,\n",
    "        modified=modified,\n",
    "        display_name=display_name,\n",
    "    )\n",
    "    documents = preprocess(documents)\n",
    "    if not documents:\n",
    "        logger.warning('No readable content found in %s; skipping.', path)\n",
    "        return resolved_id, 0\n",
    "\n",
    "    total_chunks = 0\n",
    "    for doc in documents:\n",
    "        doc_id = doc.metadata.get(metadata_key(Metadata.ID), resolved_id)\n",
    "        doc.metadata[metadata_key(Metadata.ID)] = doc_id\n",
    "        vectorstore.delete(where={metadata_key(Metadata.PARENT_ID): doc_id})\n",
    "        chunks = chunk_documents(doc)\n",
    "        if not chunks:\n",
    "            continue\n",
    "        vectorstore.add_documents(chunks)\n",
    "        docstore.mset([(doc_id, doc)])\n",
    "        total_chunks += len(chunks)\n",
    "\n",
    "    persist_vectorstore(vectorstore)\n",
    "    if manifest is not None and not source_id:\n",
    "        update_manifest_entry(\n",
    "            manifest,\n",
    "            file_id=resolved_id,\n",
    "            modified=hash_file(path),\n",
    "            path=path,\n",
    "            mime_type=file_type or Path(path).suffix or '.txt',\n",
    "            name=display_name or path.name,\n",
    "        )\n",
    "    return resolved_id, total_chunks"
   ]
  },
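  {
   "cell_type": "markdown",
   "id": "e51a7b29",
   "metadata": {},
   "source": [
    "As a quick end-to-end test of `index_document` (a sketch, assuming the Ollama embedding model is reachable, since `add_documents` embeds every chunk), the cell below indexes a throwaway text file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51a7b30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: index a throwaway local file end to end (writes into .drive_sage/).\n",
    "# Assumes the local Ollama embedding model is available.\n",
    "sample_path = DATA_DIR / 'sample_note.txt'\n",
    "sample_path.write_text('Drive Sage keeps retrieval and generation fully on the local machine.', encoding='utf-8')\n",
    "doc_id, n_chunks = index_document(sample_path, display_name='sample_note.txt')\n",
    "print(f'Indexed {doc_id} into {n_chunks} chunk(s)')"
   ]
  },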
  {
   "cell_type": "markdown",
   "id": "a90db6ee",
   "metadata": {},
   "source": [
    "### LLM Interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e15e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_context(query: str, *, top_k: int = 8, distance_threshold: Optional[float] = SIMILARITY_DISTANCE_MAX):\n",
    "    results_with_scores = vectorstore.similarity_search_with_score(query, k=top_k)\n",
    "    logger.info(f'Matching records: {len(results_with_scores)}')\n",
    "\n",
    "    filtered: list[tuple[Document, float]] = []\n",
    "    for doc, score in results_with_scores:\n",
    "        if score is None:\n",
    "            continue\n",
    "        score_value = float(score)\n",
    "        logger.debug('Retrieved doc source=%s distance=%.4f', doc.metadata.get(metadata_key(Metadata.SOURCE)), score_value)\n",
    "        if distance_threshold is not None and score_value > distance_threshold:\n",
    "            logger.debug(\n",
    "                'Skipping %s with distance %.4f (above threshold %.4f)',\n",
    "                doc.metadata.get(metadata_key(Metadata.SOURCE)),\n",
    "                score_value,\n",
    "                distance_threshold,\n",
    "            )\n",
    "            continue\n",
    "        filtered.append((doc, score_value))\n",
    "\n",
    "    if not filtered:\n",
    "        return []\n",
    "\n",
    "    for doc, score_value in filtered:\n",
    "        parent_id = doc.metadata.get(metadata_key(Metadata.PARENT_ID))\n",
    "        if parent_id:\n",
    "            parent_doc = docstore.mget([parent_id])[0]\n",
    "            if parent_doc and parent_doc.page_content:\n",
    "                logger.debug(\n",
    "                    'Parent preview (%s | %.3f): %s',\n",
    "                    doc.metadata.get(metadata_key(Metadata.SOURCE), 'unknown'),\n",
    "                    score_value,\n",
    "                    parent_doc.page_content[:400].replace('\\n', ' '),\n",
    "                )\n",
    "\n",
    "    return filtered\n",
    "\n",
    "\n",
    "def build_prompt_sections(relevant_docs: list[tuple[Document, float]]) -> str:\n",
    "    sections: list[str] = []\n",
    "    for idx, (doc, score) in enumerate(relevant_docs, start=1):\n",
    "        source = doc.metadata.get(metadata_key(Metadata.SOURCE), 'unknown')\n",
    "        snippet = doc.page_content.strip()[:MAX_CONTEXT_SNIPPET_CHARS]\n",
    "        section = (\n",
    "            f'[{idx}] Source: {source}\\n'\n",
    "            f'Distance: {score:.3f}\\n'\n",
    "            f'Content:\\n{snippet}'\n",
    "        )\n",
    "        sections.append(section)\n",
    "    return '\\n\\n'.join(sections)\n",
    "\n",
    "\n",
    "def ask(message, history):\n",
    "    relevant_docs = retrieve_context(message)\n",
    "    if not relevant_docs:\n",
    "        yield \"I don't have enough information in the synced documents to answer that yet. Please sync additional files or adjust the filters.\"\n",
    "        return\n",
    "\n",
    "    context = build_prompt_sections(relevant_docs)\n",
    "    prompt = f'''\n",
    "    You are a retrieval-augmented assistant. Use ONLY the facts provided in the context to answer the user.\n",
    "    If the context does not contain the answer, reply exactly: \"I don't have enough information in the synced documents to answer that yet. Please sync additional files.\"\n",
    "\n",
    "    Context:\\n{context}\n",
    "    '''\n",
    "\n",
    "    messages = [\n",
    "        ('system', prompt),\n",
    "        ('user', message)\n",
    "    ]\n",
    "\n",
    "    stream = model.stream(messages)\n",
    "    response_text = ''\n",
    "\n",
    "    for chunk in stream:\n",
    "        response_text += chunk.content or ''\n",
    "        if not response_text:\n",
    "            continue\n",
    "\n",
    "        yield response_text"
   ]
  },
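  {
   "cell_type": "markdown",
   "id": "6b83f0d1",
   "metadata": {},
   "source": [
    "The context block that `ask` feeds the model can be previewed offline with fabricated retrieval results, as in this sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b83f0d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the prompt context block without calling the chat model.\n",
    "fake_hits = [\n",
    "    (Document(page_content='Drive Sage runs embeddings and chat locally via Ollama.',\n",
    "              metadata={'source': 'notes/demo.md'}), 0.42),\n",
    "]\n",
    "print(build_prompt_sections(fake_hits))"
   ]
  },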
  {
   "cell_type": "markdown",
   "id": "c3e632dc-9e87-4510-9fcd-aa699c27e82b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d68a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history, sync_ready):\n",
    "    if message is None:\n",
    "        return ''\n",
    "\n",
    "    text_input = message.get('text', '')\n",
    "    files_uploaded = message.get('files', [])\n",
    "    latest_file_path = Path(files_uploaded[-1]) if files_uploaded else None\n",
    "    if latest_file_path:\n",
    "        manifest = load_manifest()\n",
    "        doc_id, chunk_count = index_document(\n",
    "            latest_file_path,\n",
    "            file_type=latest_file_path.suffix,\n",
    "            display_name=latest_file_path.name,\n",
    "            manifest=manifest,\n",
    "        )\n",
    "        save_manifest(manifest)\n",
    "        logger.info('Indexed upload %s as %s with %s chunk(s)', latest_file_path, doc_id, chunk_count)\n",
    "        if not text_input:\n",
    "            yield f'Indexed document from upload ({chunk_count} chunk(s)).'\n",
    "            return\n",
    "\n",
    "    if not text_input:\n",
    "        return ''\n",
    "\n",
    "    if not sync_ready and not files_uploaded:\n",
    "        yield 'Sync Google Drive before chatting or upload a document first.'\n",
    "        return\n",
    "\n",
    "    for chunk in ask(text_input, history):\n",
    "        yield chunk\n",
    "\n",
    "title = \"Drive Sage\"\n",
    "with gr.Blocks(title=title, fill_height=True, css=CUSTOM_CSS) as ui:\n",
    "    gr.Markdown(f'# {title}')\n",
    "    gr.Markdown('## Search your Google Drive knowledge base with fully local processing.')\n",
    "    sync_state = gr.State(False)\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=3, elem_id='chat-column'):\n",
    "            gr.ChatInterface(\n",
    "                fn=chat,\n",
    "                chatbot=gr.Chatbot(height='80vh', elem_id='chat-output'),\n",
    "                type='messages',\n",
    "                textbox=gr.MultimodalTextbox(\n",
    "                    file_types=['text', '.txt', '.md'],\n",
    "                    autofocus=True,\n",
    "                    elem_id='chat-input',\n",
    "                ),\n",
    "                additional_inputs=[sync_state],\n",
    "            )\n",
    "        with gr.Column(scale=2, min_width=320):\n",
    "            gr.Markdown('### Google Drive Sync')\n",
    "            drive_folder = gr.Textbox(\n",
    "                label='Folder ID (optional)',\n",
    "                placeholder='Leave blank to scan My Drive root',\n",
    "            )\n",
    "            file_types = gr.CheckboxGroup(\n",
    "                label='File types to sync',\n",
    "                choices=[config['label'] for config in FILE_TYPE_OPTIONS.values()],\n",
    "                value=DEFAULT_FILE_TYPE_LABELS,\n",
    "            )\n",
    "            file_limit = gr.Number(\n",
    "                label='Max files to sync (leave blank for all)',\n",
    "                value=20,\n",
    "            )\n",
    "            sync_btn = gr.Button('Sync Google Drive')\n",
    "            sync_status = gr.Markdown('No sync performed yet.')\n",
    "\n",
    "    sync_btn.click(\n",
    "        sync_drive_and_index,\n",
    "        inputs=[drive_folder, file_types, file_limit, sync_state],\n",
    "        outputs=[sync_status, sync_state],\n",
    "    )\n",
    "\n",
    "ui.launch(debug=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}