diff --git a/week2/community-contributions/ai_domain_finder/ai_domain_finder.ipynb b/week2/community-contributions/ai_domain_finder/ai_domain_finder.ipynb new file mode 100644 index 0000000..c0fbbcc --- /dev/null +++ b/week2/community-contributions/ai_domain_finder/ai_domain_finder.ipynb @@ -0,0 +1,721 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1633a440", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "Week 2 Assignment: LLM Engineering\n", + "Author: Nikhil Raut\n", + "\n", + "Notebook: ai_domain_finder.ipynb\n", + "\n", + "Purpose:\n", + "Build an agentic AI Domain Finder that proposes short, brandable .com names, verifies availability via RDAP, \n", + "then returns: \n", + " a list of available .coms, \n", + " one preferred pick, \n", + " and a brief audio rationale.\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da528fbe", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import requests\n", + "from typing import Dict, List, Tuple, Any, Optional\n", + "import re\n", + "\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "OPENAI_MODEL = \"gpt-5-nano-2025-08-07\"\n", + "TTS_MODEL = \"gpt-4o-mini-tts\"\n", + "\n", + "openai = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "361f7fe3", + "metadata": {}, + "outputs": [], + "source": [ + "# --- robust logging that works inside VS Code notebooks + Gradio threads ---\n", + "import sys, logging, threading\n", + "from collections import deque\n", + "from typing import Any\n", + "\n", + "DEBUG_LLM = True # toggle on/off noisy logs\n", + "CLEAR_LOG_ON_RUN = True # clear panel before each submit\n", + "\n", + "_LOG_BUFFER = deque(maxlen=2000) # keep ~2000 lines in memory\n", + "_LOG_LOCK = threading.Lock()\n", + "\n", + "class 
GradioBufferHandler(logging.Handler):\n", + " def emit(self, record: logging.LogRecord) -> None:\n", + " try:\n", + " msg = self.format(record)\n", + " except Exception:\n", + " msg = record.getMessage()\n", + " with _LOG_LOCK:\n", + " for line in (msg.splitlines() or [\"\"]):\n", + " _LOG_BUFFER.append(line)\n", + "\n", + "def get_log_text() -> str:\n", + " with _LOG_LOCK:\n", + " return \"\\n\".join(_LOG_BUFFER)\n", + "\n", + "def clear_log_buffer() -> None:\n", + " with _LOG_LOCK:\n", + " _LOG_BUFFER.clear()\n", + "\n", + "def _setup_logger() -> logging.Logger:\n", + " logger = logging.getLogger(\"aidf\")\n", + " logger.setLevel(logging.DEBUG if DEBUG_LLM else logging.INFO)\n", + " logger.handlers.clear()\n", + " fmt = logging.Formatter(\"%(asctime)s | %(levelname)s | %(message)s\", \"%H:%M:%S\")\n", + "\n", + " stream = logging.StreamHandler(stream=sys.stdout) # captured by VS Code notebook\n", + " stream.setFormatter(fmt)\n", + "\n", + " buf = GradioBufferHandler() # shown inside the Gradio panel\n", + " buf.setFormatter(fmt)\n", + "\n", + " logger.addHandler(stream)\n", + " logger.addHandler(buf)\n", + " logger.propagate = False\n", + " return logger\n", + "\n", + "logger = _setup_logger()\n", + "\n", + "def dbg_json(obj: Any, title: str = \"\") -> None:\n", + " \"\"\"Convenience: pretty-print JSON-ish objects to the logger.\"\"\"\n", + " try:\n", + " txt = json.dumps(obj, ensure_ascii=False, indent=2)\n", + " except Exception:\n", + " txt = str(obj)\n", + " if title:\n", + " logger.debug(\"%s\\n%s\", title, txt)\n", + " else:\n", + " logger.debug(\"%s\", txt)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "519674b2", + "metadata": {}, + "outputs": [], + "source": [ + "RDAP_URL = \"https://rdap.verisign.com/com/v1/domain/{}\"\n", + "\n", + "_ALPHA_RE = re.compile(r\"^[a-z]+$\", re.IGNORECASE)\n", + "\n", + "def _to_com(domain: str) -> str:\n", + " d = domain.strip().lower()\n", + " return d if d.endswith(\".com\") else 
f\"{d}.com\"\n", + "\n", + "def _sld_is_english_alpha(fqdn: str) -> bool:\n", + " \"\"\"\n", + " True only if the second-level label (just before .com) is made up\n", + " exclusively of English letters (a-z).\n", + " Examples:\n", + " foo.com -> True\n", + " foo-bar.com -> False\n", + " foo1.com -> False\n", + " café.com -> False\n", + " xn--cafe.com -> False\n", + " www.foo.com -> True (checks 'foo')\n", + " \"\"\"\n", + " if not fqdn.endswith(\".com\"):\n", + " return False\n", + " sld = fqdn[:-4].split(\".\")[-1] # take label immediately before .com\n", + " return bool(sld) and bool(_ALPHA_RE.fullmatch(sld))\n", + "\n", + "def check_com_availability(domain: str) -> Dict:\n", + " fqdn = _to_com(domain)\n", + " # Skip API if not strictly English letters\n", + " if not _sld_is_english_alpha(fqdn):\n", + " return {\"domain\": fqdn, \"available\": False, \"status\": 0}\n", + "\n", + " try:\n", + " r = requests.get(RDAP_URL.format(fqdn), timeout=6)\n", + " return {\"domain\": fqdn, \"available\": (r.status_code == 404), \"status\": r.status_code}\n", + " except requests.RequestException:\n", + " return {\"domain\": fqdn, \"available\": False, \"status\": 0}\n", + "\n", + "def check_com_availability_bulk(domains: List[str]) -> Dict:\n", + " \"\"\"\n", + " Input: list of domain roots or FQDNs.\n", + " Returns:\n", + " {\n", + " \"results\": [{\"domain\": \"...\", \"available\": bool, \"status\": int}, ...],\n", + " \"available\": [\"...\"], # convenience\n", + " \"count_available\": int\n", + " }\n", + " \"\"\"\n", + " session = requests.Session()\n", + " results: List[Dict] = []\n", + "\n", + " for d in domains:\n", + " fqdn = _to_com(d)\n", + "\n", + " # Skip API if not strictly English letters\n", + " if not _sld_is_english_alpha(fqdn):\n", + " results.append({\"domain\": fqdn, \"available\": False, \"status\": 0})\n", + " continue\n", + "\n", + " try:\n", + " r = session.get(RDAP_URL.format(fqdn), timeout=6)\n", + " ok = (r.status_code == 404)\n", + " 
results.append({\"domain\": fqdn, \"available\": ok, \"status\": r.status_code})\n", + " except requests.RequestException:\n", + " results.append({\"domain\": fqdn, \"available\": False, \"status\": 0})\n", + "\n", + " available = [x[\"domain\"] for x in results if x[\"available\"]]\n", + " return {\"results\": results, \"available\": available, \"count_available\": len(available)}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd20c262", + "metadata": {}, + "outputs": [], + "source": [ + "check_tool_bulk = {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"check_com_availability_bulk\",\n", + " \"description\": \"Batch check .com availability via RDAP for a list of domains (roots or FQDNs).\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"domains\": {\n", + " \"type\": \"array\",\n", + " \"items\": {\"type\": \"string\"},\n", + " \"minItems\": 1,\n", + " \"maxItems\": 50,\n", + " \"description\": \"List of domain roots or .com FQDNs.\"\n", + " }\n", + " },\n", + " \"required\": [\"domains\"],\n", + " \"additionalProperties\": False\n", + " }\n", + " }\n", + "}\n", + "\n", + "TOOLS = [check_tool_bulk]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a9138b6", + "metadata": {}, + "outputs": [], + "source": [ + "def handle_tool_calls(message) -> List[Dict]:\n", + " results = []\n", + " for call in (message.tool_calls or []):\n", + " fn = getattr(call.function, \"name\", None)\n", + " args_raw = getattr(call.function, \"arguments\", \"\") or \"{}\"\n", + " try:\n", + " args = json.loads(args_raw)\n", + " except Exception:\n", + " args = {}\n", + "\n", + " logger.debug(\"TOOL CALL -> %s | args=%s\", fn, json.dumps(args, ensure_ascii=False))\n", + "\n", + " if fn == \"check_com_availability_bulk\":\n", + " payload = check_com_availability_bulk(args.get(\"domains\", []))\n", + " elif fn == \"check_com_availability\":\n", + " payload = 
check_com_availability(args.get(\"domain\", \"\"))\n", + " else:\n", + " payload = {\"error\": f\"unknown tool {fn}\"}\n", + "\n", + " logger.debug(\"TOOL RESULT <- %s | %s\", fn, json.dumps(payload, ensure_ascii=False))\n", + "\n", + " results.append({\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": call.id,\n", + " \"content\": json.dumps(payload),\n", + " })\n", + " return results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b80c860", + "metadata": {}, + "outputs": [], + "source": [ + "SYSTEM_PROMPT = \"\"\"You are the Agent for project \"AI Domain Finder\".\n", + "Goal: suggest .com domains and verify availability using the tool ONLY (no guessing).\n", + "\n", + "Do this each interaction:\n", + "- Generate up to ~20 short, brandable .com candidates from:\n", + " (1) Industry, (2) Target Customers, (3) Description.\n", + "- Use the BULK tool `check_com_availability_bulk` with a list of candidates\n", + " (roots or FQDNs). Prefer a single call or very few batched calls.\n", + "- If >= 5 available .coms are found, STOP checking and finalize the answer.\n", + "\n", + "Output Markdown with EXACT section headings:\n", + "1) Available .com domains:\n", + " - itemized list of available .coms only (root + .com)\n", + "2) Preferred domain:\n", + " - a single best pick\n", + "3) Audio explanation:\n", + " - 1–2 concise sentences explaining the preference\n", + "\n", + "Constraints:\n", + "- Use customer-familiar words where helpful.\n", + "- Keep names short, simple, pronounceable; avoid hyphens/numbers unless meaningful.\n", + "- Never include TLDs other than .com.\n", + "- domain is made up of english alphabets in lower case only no symbols or spaces to use\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72e9d8c2", + "metadata": {}, + "outputs": [], + "source": [ + "def _asdict_tool_call(tc: Any) -> dict:\n", + " try:\n", + " return {\n", + " \"id\": getattr(tc, \"id\", None),\n", + " 
\"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": getattr(tc.function, \"name\", None),\n", + " \"arguments\": getattr(tc.function, \"arguments\", None),\n", + " },\n", + " }\n", + " except Exception:\n", + " return {\"type\": \"function\", \"function\": {\"name\": None, \"arguments\": None}}\n", + "\n", + "def _asdict_message(msg: Any) -> dict:\n", + " if isinstance(msg, dict):\n", + " return msg\n", + " role = getattr(msg, \"role\", None)\n", + " content = getattr(msg, \"content\", None)\n", + " tool_calls = getattr(msg, \"tool_calls\", None)\n", + " out = {\"role\": role, \"content\": content}\n", + " if tool_calls:\n", + " out[\"tool_calls\"] = [_asdict_tool_call(tc) for tc in tool_calls]\n", + " return out\n", + "\n", + "def _sanitized_messages_for_log(messages: list[dict | Any]) -> list[dict]:\n", + " return [_asdict_message(m) for m in messages]\n", + "\n", + "def _limit_text(s: str, limit: int = 40000) -> str:\n", + " return s if len(s) <= limit else (s[:limit] + \"\\n... 
[truncated]\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b45c6382", + "metadata": {}, + "outputs": [], + "source": [ + "def run_agent_with_tools(history: List[Dict]) -> Tuple[str, List[str], str]:\n", + " \"\"\"\n", + " Returns:\n", + " reply_md: final assistant markdown\n", + " tool_available: .coms marked available by RDAP tools (order-preserving, deduped)\n", + " dbg_text: concatenated log buffer (for the UI panel)\n", + " \"\"\"\n", + " messages: List[Dict] = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history\n", + " tool_available: List[str] = []\n", + "\n", + " dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (initial messages) ===\")\n", + " resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n", + "\n", + " while resp.choices[0].finish_reason == \"tool_calls\":\n", + " tool_msg_sdk = resp.choices[0].message\n", + " tool_msg = _asdict_message(tool_msg_sdk)\n", + " dbg_json(tool_msg, \"=== ASSISTANT (tool_calls) ===\")\n", + "\n", + " tool_results = handle_tool_calls(tool_msg_sdk)\n", + "\n", + " # Accumulate authoritative availability directly from tool outputs\n", + " for tr in tool_results:\n", + " try:\n", + " data = json.loads(tr[\"content\"])\n", + " if isinstance(data, dict) and isinstance(data.get(\"available\"), list):\n", + " for d in data[\"available\"]:\n", + " tool_available.append(_to_com(d))\n", + " except Exception:\n", + " pass\n", + "\n", + " dbg_json([json.loads(tr[\"content\"]) for tr in tool_results], \"=== TOOL RESULTS ===\")\n", + "\n", + " messages.append(tool_msg)\n", + " messages.extend(tool_results)\n", + " dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (messages + tools) ===\")\n", + "\n", + " resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n", + "\n", + " # Dedup preserve order\n", + " seen, uniq = set(), []\n", + " for d in tool_available:\n", + " if d not in seen:\n", + 
" seen.add(d)\n", + " uniq.append(d)\n", + "\n", + " reply_md = resp.choices[0].message.content\n", + " logger.debug(\"=== FINAL ASSISTANT ===\\n%s\", _limit_text(reply_md))\n", + " dbg_json(uniq, \"=== AVAILABLE FROM TOOLS (authoritative) ===\")\n", + "\n", + " # Return current buffer text for the UI panel\n", + " dbg_text = _limit_text(get_log_text(), 40000)\n", + " return reply_md, uniq, dbg_text\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92306515", + "metadata": {}, + "outputs": [], + "source": [ + "def extract_audio_text(markdown_reply: str) -> str:\n", + " \"\"\"\n", + " Pulls the 'Audio explanation:' section; falls back to first sentence.\n", + " \"\"\"\n", + " marker = \"Audio explanation:\"\n", + " lower = markdown_reply.lower()\n", + " idx = lower.find(marker.lower())\n", + " if idx != -1:\n", + " segment = markdown_reply[idx + len(marker):].strip()\n", + " parts = segment.split(\".\")\n", + " return (\". \".join([p.strip() for p in parts if p.strip()][:2]) + \".\").strip()\n", + " return \"This domain is the clearest, most memorable fit for the audience and brand goals.\"\n", + "\n", + "def synth_audio(text: str) -> bytes:\n", + " audio = openai.audio.speech.create(\n", + " model=TTS_MODEL,\n", + " voice=\"alloy\",\n", + " input=text\n", + " )\n", + " return audio.content\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc6c0650", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "_DOMAIN_RE = re.compile(r\"\\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\\.com\\b\", re.I)\n", + "_HDR_AVAIL = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*available\\s+.*\\.com\\s+domains\", re.I)\n", + "_HDR_PREF = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*preferred\\s+domain\", re.I)\n", + "\n", + "def _norm_domain(s: str) -> str:\n", + " s = s.strip().lower()\n", + " return s if s.endswith(\".com\") else f\"{s}.com\"\n", + "\n", + "def parse_available(md: str) -> list[str]:\n", + " lines = md.splitlines()\n", + " out = 
[]\n", + " in_section = False\n", + " for ln in lines:\n", + " if _HDR_AVAIL.search(ln):\n", + " in_section = True\n", + " continue\n", + " if in_section and _HDR_PREF.search(ln):\n", + " break\n", + " if in_section:\n", + " for m in _DOMAIN_RE.findall(ln):\n", + " out.append(_norm_domain(m))\n", + " # Fallback: if the header wasn't found, collect all .coms then we'll still\n", + " # rely on agent instruction to list only available, which should be safe.\n", + " if not out:\n", + " out = [_norm_domain(m) for m in _DOMAIN_RE.findall(md)]\n", + " # dedupe preserve order\n", + " seen, uniq = set(), []\n", + " for d in out:\n", + " if d not in seen:\n", + " seen.add(d)\n", + " uniq.append(d)\n", + " return uniq\n", + "\n", + "def parse_preferred(md: str) -> str:\n", + " # search the preferred section first\n", + " lines = md.splitlines()\n", + " start = None\n", + " for i, ln in enumerate(lines):\n", + " if _HDR_PREF.search(ln):\n", + " start = i\n", + " break\n", + " segment = \"\\n\".join(lines[start:start+8]) if start is not None else md[:500]\n", + " m = _DOMAIN_RE.search(segment)\n", + " if m:\n", + " return _norm_domain(m.group(0))\n", + " m = _DOMAIN_RE.search(md)\n", + " return _norm_domain(m.group(0)) if m else \"\"\n", + "\n", + "def merge_and_sort(old: list[str], new: list[str]) -> list[str]:\n", + " merged = {d.lower() for d in old} | {d.lower() for d in new}\n", + " return sorted(merged, key=lambda s: (len(s), s))\n", + "\n", + "def fmt_available_md(domains: list[str]) -> str:\n", + " if not domains:\n", + " return \"### Available .com domains (cumulative)\\n\\n*– none yet –*\"\n", + " items = \"\\n\".join(f\"- `{d}`\" for d in domains)\n", + " return f\"### Available .com domains (cumulative)\\n\\n{items}\"\n", + "\n", + "def fmt_preferred_md(d: str) -> str:\n", + " if not d:\n", + " return \"### Preferred domain\\n\\n*– not chosen yet –*\"\n", + " return f\"### Preferred domain\\n\\n`{d}`\"\n", + "\n", + "def build_context_msg(known_avail: 
Optional[List[str]], preferred_now: Optional[str]) -> str:\n", + " \"\"\"\n", + " Create a short 'state so far' block that we prepend to the next user turn\n", + " so the model always sees the preferred and cumulative available list.\n", + " \"\"\"\n", + " lines = []\n", + " if (preferred_now or \"\").strip():\n", + " lines.append(f\"Preferred domain so far: {preferred_now.strip().lower()}\")\n", + " if known_avail:\n", + " lines.append(\"Available .com domains discovered so far:\")\n", + " for d in known_avail:\n", + " if d:\n", + " lines.append(f\"- {d.strip().lower()}\")\n", + " if not lines:\n", + " return \"\"\n", + " return \"STATE TO CARRY OVER FROM PREVIOUS TURNS:\\n\" + \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07f079d6", + "metadata": {}, + "outputs": [], + "source": [ + "def run_and_extract(history: List[Dict]) -> Tuple[str, List[str], str, str, str]:\n", + " reply_md, avail_from_tools, dbg_text = run_agent_with_tools(history)\n", + " parsed_avail = parse_available(reply_md)\n", + " new_avail = merge_and_sort(avail_from_tools, parsed_avail)\n", + " preferred = parse_preferred(reply_md)\n", + " audio_text = extract_audio_text(reply_md)\n", + " return reply_md, new_avail, preferred, audio_text, dbg_text\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cd5d8ef", + "metadata": {}, + "outputs": [], + "source": [ + "def initial_submit(industry: str, customers: str, desc: str,\n", + " history: List[Dict], known_avail: List[str], preferred_now: str):\n", + " if CLEAR_LOG_ON_RUN:\n", + " clear_log_buffer()\n", + "\n", + " logger.info(\"Initial submit | industry=%r | customers=%r | desc_len=%d\",\n", + " industry, customers, len(desc or \"\"))\n", + "\n", + " # Build context (usually empty on the very first run, but future inits also work)\n", + " ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n", + "\n", + " user_msg = (\n", + " \"Please propose .com domains based 
on:\\n\"\n", + " f\"Industry: {industry}\\n\"\n", + " f\"Target Customers: {customers}\\n\"\n", + " f\"Description: {desc}\"\n", + " )\n", + "\n", + " # Single user turn that includes state + prompt so the model always sees memory\n", + " full_content = (ctx + \"\\n\\n\" if ctx else \"\") + user_msg\n", + "\n", + " history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n", + " reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n", + " history += [{\"role\": \"assistant\", \"content\": reply_md}]\n", + "\n", + " all_avail = merge_and_sort(known_avail or [], new_avail or [])\n", + " preferred_final = preferred or preferred_now or \"\"\n", + " audio_bytes = synth_audio(audio_text)\n", + "\n", + " return (\n", + " history, # s_history\n", + " all_avail, # s_available (cumulative)\n", + " preferred_final, # s_preferred\n", + " gr.update(value=fmt_preferred_md(preferred_final)),\n", + " gr.update(value=fmt_available_md(all_avail)),\n", + " gr.update(value=\"\", visible=True), # reply_in: show after first run\n", + " gr.update(value=audio_bytes, visible=True), # audio_out\n", + " gr.update(value=dbg_text), # debug_box\n", + " gr.update(value=\"Find Domains (done)\", interactive=False), # NEW: disable Find\n", + " gr.update(visible=True), # NEW: show Send button\n", + " )\n", + "\n", + "def refine_submit(reply: str,\n", + " history: List[Dict], known_avail: List[str], preferred_now: str):\n", + " # If empty, do nothing (keeps UI state untouched)\n", + " if not (reply or \"\").strip():\n", + " return (\"\", history, known_avail, preferred_now,\n", + " gr.update(), gr.update(), gr.update(), gr.update())\n", + "\n", + " if CLEAR_LOG_ON_RUN:\n", + " clear_log_buffer()\n", + " logger.info(\"Refine submit | user_reply_len=%d\", len(reply))\n", + "\n", + " # Always prepend memory + the user's refinement so the model can iterate properly\n", + " ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n", + " 
full_content = (ctx + \"\\n\\n\" if ctx else \"\") + reply.strip()\n", + "\n", + " history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n", + " reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n", + " history += [{\"role\": \"assistant\", \"content\": reply_md}]\n", + "\n", + " all_avail = merge_and_sort(known_avail or [], new_avail or [])\n", + " preferred_final = preferred or preferred_now or \"\"\n", + " audio_bytes = synth_audio(audio_text)\n", + "\n", + " return (\n", + " \"\", # clear Reply box\n", + " history, # s_history\n", + " all_avail, # s_available (cumulative)\n", + " preferred_final, # s_preferred\n", + " gr.update(value=fmt_preferred_md(preferred_final)),\n", + " gr.update(value=fmt_available_md(all_avail)),\n", + " gr.update(value=audio_bytes, visible=True),\n", + " gr.update(value=dbg_text), # debug_box\n", + " )\n", + "\n", + "def clear_debug():\n", + " clear_log_buffer()\n", + " return gr.update(value=\"\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d52ebc02", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks(title=\"AI Domain Finder (.com only)\") as ui:\n", + " gr.Markdown(\"# AI Domain Finder (.com only)\")\n", + " gr.Markdown(\"Agent proposes .com domains, verifies via RDAP, picks a preferred choice, and explains briefly.\")\n", + "\n", + " # App state\n", + " s_history = gr.State([])\n", + " s_available = gr.State([])\n", + " s_preferred = gr.State(\"\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column(scale=7): # LEFT 70%\n", + " with gr.Group():\n", + " industry_in = gr.Textbox(label=\"Industry\")\n", + " customers_in = gr.Textbox(label=\"Target Customers\")\n", + " desc_in = gr.Textbox(label=\"Description\", lines=3)\n", + " find_btn = gr.Button(\"Find Domains\", variant=\"primary\")\n", + "\n", + " audio_out = gr.Audio(label=\"Audio explanation\", autoplay=True, visible=False)\n", + "\n", + " with gr.Row():\n", + " reply_in = 
gr.Textbox(\n", + " label=\"Reply\",\n", + " placeholder=\"Chat with the agent to refine the outputs\",\n", + " lines=2,\n", + " visible=False, # hidden for the first input\n", + " )\n", + " send_btn = gr.Button(\"Send\", variant=\"primary\", visible=False)\n", + "\n", + " with gr.Column(scale=3): # RIGHT 30%\n", + " preferred_md = gr.Markdown(fmt_preferred_md(\"\"))\n", + " available_md = gr.Markdown(fmt_available_md([]))\n", + "\n", + " with gr.Accordion(\"Debug log\", open=False):\n", + " debug_box = gr.Textbox(label=\"Log\", value=\"\", lines=16, interactive=False)\n", + " clear_btn = gr.Button(\"Clear log\", size=\"sm\")\n", + "\n", + " # Events\n", + " # Initial run: also disables Find and shows Send\n", + " find_btn.click(\n", + " initial_submit,\n", + " inputs=[industry_in, customers_in, desc_in, s_history, s_available, s_preferred],\n", + " outputs=[\n", + " s_history, s_available, s_preferred,\n", + " preferred_md, available_md,\n", + " reply_in, # visible after first run\n", + " audio_out, # visible after first run\n", + " debug_box,\n", + " find_btn, # NEW: disable + relabel\n", + " send_btn, # NEW: show the Send button\n", + " ],\n", + " )\n", + "\n", + " # Multi-turn submit via Enter in the textbox\n", + " reply_in.submit(\n", + " refine_submit,\n", + " inputs=[reply_in, s_history, s_available, s_preferred],\n", + " outputs=[\n", + " reply_in, s_history, s_available, s_preferred,\n", + " preferred_md, available_md, audio_out, debug_box\n", + " ],\n", + " )\n", + "\n", + " # Multi-turn submit via explicit Send button\n", + " send_btn.click(\n", + " refine_submit,\n", + " inputs=[reply_in, s_history, s_available, s_preferred],\n", + " outputs=[\n", + " reply_in, s_history, s_available, s_preferred,\n", + " preferred_md, available_md, audio_out, debug_box\n", + " ],\n", + " )\n", + "\n", + " clear_btn.click(clear_debug, inputs=[], outputs=[debug_box])\n", + "\n", + "ui.launch(inbrowser=True, show_error=True)\n" + ] + } + ], + "metadata": { + 
"kernelspec": { + "display_name": "llm-engineering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week2/community-contributions/hopeogbons/README.md b/week2/community-contributions/hopeogbons/README.md new file mode 100644 index 0000000..953a7e5 --- /dev/null +++ b/week2/community-contributions/hopeogbons/README.md @@ -0,0 +1,355 @@ +# 🏥 RoboCare AI Assistant + +> Born from a real problem at MyWoosah Inc—now solving caregiver matching through AI. + +## 📋 The Story Behind This Project + +While working on a caregiver matching platform for **MyWoosah Inc** in the US, I faced a real challenge: how do you efficiently match families with the right caregivers when everyone has different needs? + +Families would ask things like: + +- _"I need someone for my mom on Monday mornings who speaks Spanish"_ +- _"Can you find elder care in Boston under $30/hour with CPR certification?"_ + +Writing individual SQL queries for every combination of filters was exhausting and error-prone. I knew there had to be a better way. + +That's when I discovered the **Andela LLM Engineering program**. I saw an opportunity to transform this problem into a solution using AI. Instead of rigid queries, what if families could just... talk? And the AI would understand, search, and recommend? + +This project is my answer. It's not just an exercise—it's solving a real problem I encountered in the field. 
+ +## What It Does + +RoboCare helps families find caregivers through natural conversation: + +- 🔍 Searches the database intelligently +- 🎯 Finds the best matches +- 💬 Explains pros/cons in plain English +- 🔊 Speaks the results back to you + +## ✨ Features + +### 🤖 AI-Powered Matching + +- Natural language conversation interface +- Intelligent requirement gathering +- Multi-criteria search optimization +- Personalized recommendations with pros/cons analysis + +### 🔍 Advanced Search Capabilities + +- **Location-based filtering**: City, state, and country +- **Service type matching**: Elder care, child care, companionship, dementia care, hospice support, and more +- **Availability scheduling**: Day and time-based matching +- **Budget optimization**: Maximum hourly rate filtering +- **Language preferences**: Multi-language support +- **Certification requirements**: CPR, CNA, BLS, and specialized certifications +- **Experience filtering**: Minimum years of experience + +### 🎙️ Multi-Modal Interface + +- Text-based chat interface +- Voice response generation (Text-to-Speech) +- Multiple voice options (coral, alloy, echo, fable, onyx, nova, shimmer) +- Clean, modern UI built with Gradio + +### 🛡️ Defensive Architecture + +- Comprehensive error handling +- Token overflow protection +- Tool call validation +- Graceful degradation + +## 🚀 Getting Started + +### Prerequisites + +- Python 3.8+ +- OpenAI API key +- Virtual environment (recommended) + +### Installation + +1. **Clone the repository** + + ```bash + cd week2 + ``` + +2. **Create and activate virtual environment** + + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + ``` + +3. **Install dependencies** + + ```bash + pip install -r requirements.txt + ``` + +4. **Set up environment variables** + + Create a `.env` file in the project root: + + ```env + OPENAI_API_KEY=your_openai_api_key_here + ``` + +5. 
**Run the application** + + ```bash + jupyter notebook "week2 EXERCISE.ipynb" + ``` + + Or run all cells sequentially in your Jupyter environment. + +## 📊 Database Schema + +### Tables + +#### `caregivers` + +Primary caregiver information including: + +- Personal details (name, gender) +- Experience level +- Hourly rate and currency +- Location (city, state, country, coordinates) +- Live-in availability + +#### `caregiver_services` + +Care types offered by each caregiver: + +- Elder care +- Child care +- Companionship +- Post-op support +- Special needs +- Respite care +- Dementia care +- Hospice support + +#### `availability` + +Time slots when caregivers are available: + +- Day of week (Mon-Sun) +- Start and end times (24-hour format) + +#### `languages` + +Languages spoken by caregivers + +#### `certifications` + +Professional certifications (CPR, CNA, BLS, etc.) + +#### `traits` + +Personality and professional traits + +## 🔧 Architecture + +### Tool Registry Pattern + +```python +TOOL_REGISTRY = { + "search_caregivers": search_caregivers, + "get_caregiver_profile": get_caregiver_profile, + # ... more tools +} +``` + +All database functions are registered and can be called by the AI dynamically. 
+ +### Search Functions + +#### `search_caregivers()` + +Multi-filter search with parameters: + +- `city`, `state_province`, `country` - Location filters +- `care_type` - Type of care needed +- `min_experience` - Minimum years of experience +- `max_hourly_rate` - Budget constraint +- `live_in` - Live-in caregiver requirement +- `language` - Language preference +- `certification` - Required certification +- `day` - Day of week availability +- `time_between` - Time window availability +- `limit`, `offset` - Pagination + +#### `get_caregiver_profile(caregiver_id)` + +Returns complete profile including: + +- Basic information +- Services offered +- Languages spoken +- Certifications +- Personality traits +- Availability schedule + +## 🎨 UI Components + +### Main Interface + +- **Chat History**: Message-based conversation display +- **Voice Response**: Auto-playing audio output +- **Settings Sidebar**: + - AI Model selector + - Voice selection + - Audio toggle + - Clear conversation button + +### User Experience + +- Professional gradient header +- Collapsible instructions +- Helpful placeholder text +- Custom CSS styling +- Responsive layout + +## 📝 Usage Examples + +### Example 1: Basic Search + +```python +results = search_caregivers( + city="New York", + care_type="elder care", + max_hourly_rate=30.0, + limit=5 +) +``` + +### Example 2: Language Filter + +```python +results = search_caregivers( + care_type="child care", + language="Spanish", + limit=3 +) +``` + +### Example 3: Availability Search + +```python +results = search_caregivers( + day="Mon", + time_between=("09:00", "17:00"), + city="Boston" +) +``` + +### Example 4: Get Full Profile + +```python +profile = get_caregiver_profile(caregiver_id=1) +print(profile['services']) +print(profile['availability']) +``` + +## 🔐 Security & Best Practices + +### Current Implementation + +- ✅ Environment variable management for API keys +- ✅ SQL injection prevention (parameterized queries) +- ✅ Error handling and 
graceful degradation +- ✅ Input validation through tool schemas + +### Important Disclaimers + +⚠️ **This is a demonstration application** + +- Credentials and background checks are NOT verified +- Families should independently verify all caregiver information +- Not intended for production use without additional security measures + +## 🛠️ Tech Stack + +- **AI/ML**: OpenAI GPT-4o-mini, Text-to-Speech API +- **Database**: SQLite with normalized schema +- **UI Framework**: Gradio +- **Language**: Python 3.8+ +- **Key Libraries**: + - `openai` - OpenAI API client + - `gradio` - Web interface + - `python-dotenv` - Environment management + - `sqlite3` - Database operations + +## 📈 What's Next + +### Immediate Plans + +- [ ] Add speech input (families could call and talk) +- [ ] Connect to actual MyWoosah database +- [ ] Background check API integration +- [ ] Deploy for real users + +### Future Enhancements + +- [ ] Streaming responses for real-time interaction +- [ ] Dynamic model switching +- [ ] User authentication and profiles +- [ ] Review and rating system +- [ ] Payment integration +- [ ] Calendar integration for scheduling + +## 💡 Key Learnings + +Through building this project, I learned: + +1. **Prompt engineering is critical** - Small keyword mismatches = zero results. Mapping "Monday" → "Mon" matters. +2. **Function calling is powerful** - Eliminated the need for custom queries. The AI figures it out. +3. **Defensive programming saves headaches** - Things break. This code expects it and handles it elegantly. +4. **AI makes databases accessible** - Good database design + AI = natural language interface + +## 🌍 The Bigger Picture + +This isn't just about caregiving. 
The same pattern works for: + +- Healthcare appointment booking +- Legal service matching +- Tutoring and education platforms +- Real estate agent matching +- Any matching problem where natural language beats forms + +**AI doesn't replace good database design—it makes it accessible to everyone.** + +--- + +## 🤝 Contributing + +This project was created as part of the **Andela LLM Engineering Week 2 Exercise**. + +Feedback and contributions are welcome! Feel free to: + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Run all cells to test +5. Submit a pull request + +## 🙏 Acknowledgments + +- **MyWoosah Inc** - For the real-world problem that inspired this solution +- **Andela LLM Engineering Program** - Educational framework and guidance +- **OpenAI** - GPT-4o and TTS API +- **Gradio** - Making beautiful UIs accessible + +--- + +
+ +**For MyWoosah Inc and beyond:** This is proof that AI can transform how we connect people with the care they need. + +_Built with ❤️ during Week 2 of the Andela LLM Engineering Program_ + +**RoboOffice Ltd** + +
diff --git a/week2/community-contributions/hopeogbons/care_app.db b/week2/community-contributions/hopeogbons/care_app.db new file mode 100644 index 0000000..93f8fdb Binary files /dev/null and b/week2/community-contributions/hopeogbons/care_app.db differ diff --git a/week2/community-contributions/hopeogbons/week2 EXERCISE.ipynb b/week2/community-contributions/hopeogbons/week2 EXERCISE.ipynb new file mode 100644 index 0000000..6915f24 --- /dev/null +++ b/week2/community-contributions/hopeogbons/week2 EXERCISE.ipynb @@ -0,0 +1,1525 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd", + "metadata": {}, + "source": [ + "# 🏥 RoboCare AI Assistant\n", + "\n", + "## Why I Built This\n", + "\n", + "While working on a caregiver matching platform for **MyWoosah Inc** in the US, I faced a real challenge: how do you efficiently match families with the right caregivers when everyone has different needs?\n", + "\n", + "Families would ask things like:\n", + "- *\"I need someone for my mom on Monday mornings who speaks Spanish\"*\n", + "- *\"Can you find elder care in Boston under $30/hour with CPR certification?\"*\n", + "\n", + "Writing individual SQL queries for every combination of filters was exhausting and error-prone. I knew there had to be a better way.\n", + "\n", + "That's when I discovered the **Andela LLM Engineering program**. I saw an opportunity to transform this problem into a solution using AI. Instead of rigid queries, what if families could just... talk? And the AI would understand, search, and recommend?\n", + "\n", + "This project is my answer. It's not just an exercise—it's solving a real problem I encountered in the field.\n", + "\n", + "---\n", + "\n", + "## What This Does\n", + "\n", + "RoboCare helps families find caregivers through natural conversation. 
You tell it what you need, and it:\n", + "- 🔍 Searches the database intelligently\n", + "- 🎯 Finds the best matches\n", + "- 💬 Explains pros/cons in plain English \n", + "- 🔊 Speaks the results back to you\n", + "\n", + "**Tech:** OpenAI GPT-4o + Voice • Gradio UI • SQLite Database • Function Calling\n", + "\n", + "---\n", + "\n", + "**Note:** This is a demonstration. Always verify credentials independently." + ] + }, + { + "cell_type": "markdown", + "id": "4381c40c", + "metadata": {}, + "source": [ + "## Step 1: Libraries\n", + "\n", + "The essentials: OpenAI for the AI brain, Gradio for the interface, SQLite for data storage.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "185c6841", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "import sqlite3\n", + "import sqlite3\n", + "from textwrap import dedent\n", + "from contextlib import contextmanager\n", + "from typing import Optional, List, Dict, Any, Tuple" + ] + }, + { + "cell_type": "markdown", + "id": "2a366c15", + "metadata": {}, + "source": [ + "## Step 2: Setup\n", + "\n", + "Loading API keys securely (never hardcode them!), setting up the OpenAI client, and pointing to our database.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "0e731b96", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n" + ] + } + ], + "source": [ + "# Initialization\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "MODEL = \"gpt-4o-mini\"\n", + "openai = OpenAI()\n", + "\n", + "DB_PATH = \"care_app.db\"" + ] + }, + { + "cell_type": 
"markdown", + "id": "686fa27a", + "metadata": {}, + "source": [ + "## Step 3: The Database\n", + "\n", + "20 sample caregivers across major US cities with:\n", + "- Services they offer (elder care, child care, etc.)\n", + "- Languages, certifications, availability\n", + "- Personality traits\n", + "- Realistic pricing and schedules\n", + "\n", + "This mirrors the kind of data MyWoosah Inc would manage—except here, AI does the matching work.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "965d273d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Seeded: care_app.db\n" + ] + } + ], + "source": [ + "# Table creation and seeding\n", + "\n", + "SQL = '''\n", + "\n", + " CREATE TABLE IF NOT EXISTS caregivers (\n", + " id INTEGER PRIMARY KEY,\n", + " name TEXT NOT NULL,\n", + " gender TEXT,\n", + " years_experience INTEGER,\n", + " live_in INTEGER, -- 0/1\n", + " hourly_rate REAL,\n", + " currency TEXT,\n", + " city TEXT,\n", + " state_province TEXT,\n", + " country TEXT,\n", + " postal_code TEXT,\n", + " lat REAL,\n", + " lon REAL\n", + " );\n", + "\n", + " CREATE TABLE IF NOT EXISTS caregiver_services (\n", + " caregiver_id INTEGER,\n", + " care_type TEXT,\n", + " FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + " );\n", + "\n", + " CREATE TABLE IF NOT EXISTS availability (\n", + " caregiver_id INTEGER,\n", + " day TEXT, -- e.g., 'Mon'\n", + " time_start TEXT, -- 'HH:MM'\n", + " time_end TEXT, -- 'HH:MM'\n", + " FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + " );\n", + "\n", + " CREATE TABLE IF NOT EXISTS languages (\n", + " caregiver_id INTEGER,\n", + " language TEXT,\n", + " FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + " );\n", + "\n", + " CREATE TABLE IF NOT EXISTS certifications (\n", + " caregiver_id INTEGER,\n", + " cert TEXT,\n", + " FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + " );\n", + "\n", + " CREATE TABLE IF NOT EXISTS traits (\n", 
+ " caregiver_id INTEGER,\n", + " trait TEXT,\n", + " FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + " );\n", + "\n", + " ----------------------------------------------------------\n", + "\n", + " -- Clear old data (optional)\n", + "\n", + " DELETE FROM traits;\n", + " DELETE FROM certifications;\n", + " DELETE FROM languages;\n", + " DELETE FROM availability;\n", + " DELETE FROM caregiver_services;\n", + " DELETE FROM caregivers;\n", + "\n", + " -- Seed caregivers (20 examples, all USA)\n", + "\n", + " INSERT INTO caregivers\n", + " (id, name, gender, years_experience, live_in, hourly_rate, currency, city, state_province, country, postal_code, lat, lon)\n", + " VALUES\n", + " (1, 'Grace Williams', 'female', 6, 0, 28, 'USD', 'New York', 'NY', 'USA', '10001', 40.7128, -74.0060),\n", + " (2, 'Miguel Alvarez', 'male', 9, 1, 30, 'USD', 'Los Angeles', 'CA', 'USA', '90012', 34.0522, -118.2437),\n", + " (3, 'Ava Johnson', 'female', 4, 0, 24, 'USD', 'Chicago', 'IL', 'USA', '60601', 41.8781, -87.6298),\n", + " (4, 'Noah Robinson', 'male', 12, 0, 27, 'USD', 'Houston', 'TX', 'USA', '77002', 29.7604, -95.3698),\n", + " (5, 'Sophia Martinez', 'female', 8, 0, 29, 'USD', 'Phoenix', 'AZ', 'USA', '85004', 33.4484, -112.0740),\n", + " (6, 'Daniel Carter', 'male', 10, 1, 31, 'USD', 'Philadelphia', 'PA', 'USA', '19103', 39.9526, -75.1652),\n", + " (7, 'Emily Nguyen', 'female', 7, 0, 26, 'USD', 'San Antonio', 'TX', 'USA', '78205', 29.4241, -98.4936),\n", + " (8, 'Olivia Kim', 'female', 5, 0, 27, 'USD', 'San Diego', 'CA', 'USA', '92101', 32.7157, -117.1611),\n", + " (9, 'James Thompson', 'male', 15, 1, 34, 'USD', 'Dallas', 'TX', 'USA', '75201', 32.7767, -96.7970),\n", + " (10, 'Isabella Garcia', 'female', 3, 0, 22, 'USD', 'San Jose', 'CA', 'USA', '95113', 37.3382, -121.8863),\n", + " (11, 'Ethan Patel', 'male', 11, 1, 33, 'USD', 'Austin', 'TX', 'USA', '78701', 30.2672, -97.7431),\n", + " (12, 'Harper Brooks', 'female', 2, 0, 20, 'USD', 'Jacksonville', 'FL', 'USA', '32202', 
30.3322, -81.6557),\n", + " (13, 'Logan White', 'male', 6, 0, 25, 'USD', 'Fort Worth', 'TX', 'USA', '76102', 32.7555, -97.3308),\n", + " (14, 'Amelia Davis', 'female', 9, 0, 28, 'USD', 'Columbus', 'OH', 'USA', '43215', 39.9612, -82.9988),\n", + " (15, 'Charlotte Reed', 'female', 14, 1, 32, 'USD', 'Charlotte', 'NC', 'USA', '28202', 35.2271, -80.8431),\n", + " (16, 'Jackson Lee', 'male', 5, 0, 26, 'USD', 'San Francisco', 'CA', 'USA', '94102', 37.7749, -122.4194),\n", + " (17, 'Avery Chen', 'female', 7, 0, 27, 'USD', 'Seattle', 'WA', 'USA', '98101', 47.6062, -122.3321),\n", + " (18, 'William Turner', 'male', 13, 1, 35, 'USD', 'Denver', 'CO', 'USA', '80202', 39.7392, -104.9903),\n", + " (19, 'Natalie O''Brien', 'female', 16, 0, 36, 'USD', 'Boston', 'MA', 'USA', '02108', 42.3601, -71.0589),\n", + " (20, 'Maya Robinson', 'female', 3, 0, 23, 'USD', 'Atlanta', 'GA', 'USA', '30303', 33.7488, -84.3880);\n", + "\n", + " -- Seed caregiver services\n", + "\n", + " INSERT INTO caregiver_services (caregiver_id, care_type) VALUES\n", + " (1, 'elder care'), (1, 'companionship'),\n", + " (2, 'post-op support'), (2, 'elder care'),\n", + " (3, 'child care'), (3, 'special needs'),\n", + " (4, 'respite care'), (4, 'elder care'),\n", + " (5, 'dementia care'), (5, 'companionship'),\n", + " (6, 'elder care'), (6, 'hospice support'),\n", + " (7, 'child care'), (7, 'respite care'),\n", + " (8, 'post-op support'), (8, 'companionship'),\n", + " (9, 'special needs'), (9, 'elder care'),\n", + " (10, 'child care'), (10, 'companionship'),\n", + " (11, 'dementia care'), (11, 'post-op support'),\n", + " (12, 'child care'), (12, 'special needs'),\n", + " (13, 'respite care'), (13, 'companionship'),\n", + " (14, 'elder care'), (14, 'post-op support'),\n", + " (15, 'hospice support'), (15, 'dementia care'),\n", + " (16, 'elder care'), (16, 'respite care'),\n", + " (17, 'special needs'), (17, 'companionship'),\n", + " (18, 'post-op support'), (18, 'elder care'),\n", + " (19, 'dementia care'), (19, 
'hospice support'),\n", + " (20, 'child care'), (20, 'companionship');\n", + "\n", + " -- Seed availability (Mon-Sun samples)\n", + "\n", + " INSERT INTO availability (caregiver_id, day, time_start, time_end) VALUES\n", + " -- 1 Grace (NY): evenings + Sun\n", + " (1, 'Mon', '17:30', '22:00'),\n", + " (1, 'Thu', '17:30', '22:00'),\n", + " (1, 'Sun', '10:00', '16:00'),\n", + " -- 2 Miguel (LA): live-in, long blocks\n", + " (2, 'Tue', '08:00', '20:00'),\n", + " (2, 'Thu', '08:00', '20:00'),\n", + " (2, 'Sat', '09:00', '18:00'),\n", + " -- 3 Ava (CHI): weekdays 09-17\n", + " (3, 'Mon', '09:00', '17:00'),\n", + " (3, 'Wed', '09:00', '17:00'),\n", + " (3, 'Fri', '09:00', '17:00'),\n", + " -- 4 Noah (HOU): Tue-Fri 08-16\n", + " (4, 'Tue', '08:00', '16:00'),\n", + " (4, 'Wed', '08:00', '16:00'),\n", + " (4, 'Thu', '08:00', '16:00'),\n", + " -- 5 Sophia (PHX): Thu-Sun 10-18\n", + " (5, 'Thu', '10:00', '18:00'),\n", + " (5, 'Fri', '10:00', '18:00'),\n", + " (5, 'Sat', '10:00', '18:00'),\n", + " -- 6 Daniel (PHL): Mon-Thu 07-15\n", + " (6, 'Mon', '07:00', '15:00'),\n", + " (6, 'Tue', '07:00', '15:00'),\n", + " (6, 'Thu', '07:00', '15:00'),\n", + " -- 7 Emily (SAT): weekends\n", + " (7, 'Sat', '08:00', '17:00'),\n", + " (7, 'Sun', '09:00', '17:00'),\n", + " (7, 'Fri', '17:00', '21:00'),\n", + " -- 8 Olivia (SD): Mon, Wed evenings\n", + " (8, 'Mon', '16:00', '21:00'),\n", + " (8, 'Wed', '16:00', '21:00'),\n", + " (8, 'Sat', '10:00', '14:00'),\n", + " -- 9 James (DAL): live-in wide\n", + " (9, 'Mon', '07:00', '19:00'),\n", + " (9, 'Wed', '07:00', '19:00'),\n", + " (9, 'Sun', '09:00', '17:00'),\n", + " -- 10 Isabella (SJ): Tue-Thu 12-20\n", + " (10, 'Tue', '12:00', '20:00'),\n", + " (10, 'Wed', '12:00', '20:00'),\n", + " (10, 'Thu', '12:00', '20:00'),\n", + " -- 11 Ethan (ATX): nights\n", + " (11, 'Mon', '18:00', '23:00'),\n", + " (11, 'Tue', '18:00', '23:00'),\n", + " (11, 'Fri', '18:00', '23:00'),\n", + " -- 12 Harper (JAX): school hours\n", + " (12, 'Mon', '09:00', 
'14:00'),\n", + " (12, 'Wed', '09:00', '14:00'),\n", + " (12, 'Fri', '09:00', '14:00'),\n", + " -- 13 Logan (FTW): Thu-Sat\n", + " (13, 'Thu', '10:00', '18:00'),\n", + " (13, 'Fri', '10:00', '18:00'),\n", + " (13, 'Sat', '10:00', '18:00'),\n", + " -- 14 Amelia (CMH): Mon-Fri 08-16\n", + " (14, 'Mon', '08:00', '16:00'),\n", + " (14, 'Tue', '08:00', '16:00'),\n", + " (14, 'Thu', '08:00', '16:00'),\n", + " -- 15 Charlotte (CLT): live-in style\n", + " (15, 'Tue', '07:00', '19:00'),\n", + " (15, 'Thu', '07:00', '19:00'),\n", + " (15, 'Sat', '08:00', '16:00'),\n", + " -- 16 Jackson (SF): split shifts\n", + " (16, 'Mon', '07:00', '11:00'),\n", + " (16, 'Mon', '17:00', '21:00'),\n", + " (16, 'Sat', '12:00', '18:00'),\n", + " -- 17 Avery (SEA): Tue/Thu + Sun\n", + " (17, 'Tue', '10:00', '18:00'),\n", + " (17, 'Thu', '10:00', '18:00'),\n", + " (17, 'Sun', '11:00', '17:00'),\n", + " -- 18 William (DEN): Mon-Wed 06-14\n", + " (18, 'Mon', '06:00', '14:00'),\n", + " (18, 'Tue', '06:00', '14:00'),\n", + " (18, 'Wed', '06:00', '14:00'),\n", + " -- 19 Natalie (BOS): Tue-Fri 09-17\n", + " (19, 'Tue', '09:00', '17:00'),\n", + " (19, 'Wed', '09:00', '17:00'),\n", + " (19, 'Fri', '09:00', '17:00'),\n", + " -- 20 Maya (ATL): after-school + Sat\n", + " (20, 'Mon', '15:00', '20:00'),\n", + " (20, 'Wed', '15:00', '20:00'),\n", + " (20, 'Sat', '09:00', '15:00');\n", + "\n", + " -- Seed languages\n", + "\n", + " INSERT INTO languages (caregiver_id, language) VALUES\n", + " (1, 'English'), (1, 'Spanish'),\n", + " (2, 'English'), (2, 'Spanish'),\n", + " (3, 'English'),\n", + " (4, 'English'),\n", + " (5, 'English'), (5, 'Spanish'),\n", + " (6, 'English'),\n", + " (7, 'English'), (7, 'Vietnamese'),\n", + " (8, 'English'), (8, 'Korean'),\n", + " (9, 'English'),\n", + " (10,'English'), (10,'Spanish'),\n", + " (11,'English'), (11,'Hindi'),\n", + " (12,'English'),\n", + " (13,'English'),\n", + " (14,'English'), (14,'French'),\n", + " (15,'English'),\n", + " (16,'English'), (16,'Tagalog'),\n", + " 
(17,'English'), (17,'Mandarin'),\n", + " (18,'English'),\n", + " (19,'English'), (19,'Portuguese'),\n", + " (20,'English'), (20,'ASL');\n", + "\n", + " -- Seed certifications\n", + "\n", + " INSERT INTO certifications (caregiver_id, cert) VALUES\n", + " (1, 'CPR'), (1, 'First Aid'),\n", + " (2, 'CPR'), (2, 'BLS'),\n", + " (3, 'CPR'),\n", + " (4, 'First Aid'), (4, 'CNA'),\n", + " (5, 'CPR'), (5, 'Dementia Care'),\n", + " (6, 'HHA'), (6, 'CPR'),\n", + " (7, 'First Aid'),\n", + " (8, 'CPR'), (8, 'AED'),\n", + " (9, 'CNA'), (9, 'BLS'),\n", + " (10,'First Aid'),\n", + " (11,'CPR'), (11,'Medication Technician'),\n", + " (12,'CPR'),\n", + " (13,'First Aid'),\n", + " (14,'CPR'), (14,'CNA'),\n", + " (15,'Hospice Training'), (15,'CPR'),\n", + " (16,'First Aid'),\n", + " (17,'CPR'), (17,'Special Needs Training'),\n", + " (18,'BLS'), (18,'CPR'),\n", + " (19,'Dementia Care'), (19,'First Aid'),\n", + " (20,'CPR'), (20,'Childcare Safety');\n", + "\n", + " -- Seed traits\n", + "\n", + " INSERT INTO traits (caregiver_id, trait) VALUES\n", + " (1, 'empathetic'), (1, 'detail-oriented'),\n", + " (2, 'patient'), (2, 'communicative'),\n", + " (3, 'cheerful'), (3, 'reliable'),\n", + " (4, 'organized'), (4, 'professional'),\n", + " (5, 'compassionate'), (5, 'trustworthy'),\n", + " (6, 'calm under pressure'), (6, 'punctual'),\n", + " (7, 'adaptable'), (7, 'energetic'),\n", + " (8, 'friendly'), (8, 'respectful'),\n", + " (9, 'thorough'), (9, 'dependable'),\n", + " (10,'gentle'), (10,'attentive'),\n", + " (11,'proactive'), (11,'communicative'),\n", + " (12,'patient'), (12,'kind'),\n", + " (13,'flexible'), (13,'tidy'),\n", + " (14,'reliable'), (14,'punctual'),\n", + " (15,'compassionate'), (15,'detail-oriented'),\n", + " (16,'discreet'), (16,'organized'),\n", + " (17,'empathetic'), (17,'calm under pressure'),\n", + " (18,'professional'), (18,'thorough'),\n", + " (19,'trustworthy'), (19,'proactive'),\n", + " (20,'cheerful'), (20,'attentive');\n", + "\n", + "'''\n", + "\n", + "# Insert the data 
# Insert the seed data into the database.
# Use try/finally so the connection is closed even if executescript raises
# (the original leaked the connection on any SQL error).
sql = dedent(SQL)
con = sqlite3.connect(DB_PATH)
try:
    con.executescript(sql)  # runs the whole DDL + seed script in one call
    con.commit()
finally:
    con.close()
print("Seeded:", DB_PATH)
# Tool definition schema for OpenAI function calling.
#
# NOTE: the example values in these descriptions must match what the database
# actually stores — care types contain SPACES ('elder care', not 'elder_care')
# and days are 3-letter codes ('Mon', not 'Monday'). search_caregivers() does
# exact (case-insensitive) matching, so a mismatched keyword from the model
# silently yields zero results.

tools = [{
    "type": "function",
    "function": {
        "name": "search_caregivers",
        "description": (
            "Flexible multi-filter caregiver search. Any filter can be omitted. "
            "Supports location, service type, experience, pricing, live-in, language, "
            "certifications, day/time availability, and pagination."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name to filter by (optional)."
                },
                "state_province": {
                    "type": "string",
                    "description": "State or province to filter by (optional)."
                },
                "country": {
                    "type": "string",
                    "description": "Country to filter by (optional)."
                },
                "care_type": {
                    "type": "string",
                    "description": (
                        "Service category (optional). Use the stored values, which "
                        "contain spaces: 'elder care', 'child care', 'dementia care', "
                        "'respite care', 'companionship', 'post-op support', "
                        "'hospice support', 'special needs'."
                    )
                },
                "min_experience": {
                    "type": "integer",
                    "minimum": 0,
                    "description": "Minimum years of experience (optional)."
                },
                "max_hourly_rate": {
                    "type": "number",
                    "minimum": 0,
                    "description": "Maximum hourly rate in local currency (optional)."
                },
                "live_in": {
                    "type": "boolean",
                    "description": "Require live-in caregivers (optional)."
                },
                "language": {
                    "type": "string",
                    "description": "Required spoken language, e.g., 'English', 'Spanish' (optional)."
                },
                "certification": {
                    "type": "string",
                    "description": "Required certification, e.g., 'CPR', 'CNA' (optional)."
                },
                "day": {
                    "type": "string",
                    "description": (
                        "Day of week to match availability (optional), as the 3-letter "
                        "code stored in the database: 'Mon', 'Tue', 'Wed', 'Thu', "
                        "'Fri', 'Sat', 'Sun'."
                    )
                },
                "time_between": {
                    "type": "array",
                    "description": (
                        "Required availability window as ['HH:MM','HH:MM'] in 24h time. "
                        "Matches caregivers whose availability window fully covers this range."
                    ),
                    "items": {
                        "type": "string",
                        "pattern": "^\\d{2}:\\d{2}$",
                        "description": "Time in 'HH:MM' 24-hour format."
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "limit": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": 1000,
                    "default": 50,
                    "description": "Max number of results to return (default 50)."
                },
                "offset": {
                    "type": "integer",
                    "minimum": 0,
                    "default": 0,
                    "description": "Number of results to skip for pagination (default 0)."
                }
            },
            "required": [],
            "additionalProperties": False
        }
    }
}]

tools
These are what the AI calls behind the scenes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "2f50cc15", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Convert text to speech using OpenAI's TTS API\n", + "def announcements(message):\n", + " response = openai.audio.speech.create(\n", + " model=\"gpt-4o-mini-tts\",\n", + " voice=\"coral\", # Also, try replacing onyx with alloy or coral\n", + " input=message\n", + " )\n", + " return response.content\n", + "\n", + "# Context manager for database connection\n", + "@contextmanager\n", + "def _conn(dict_rows: bool = True):\n", + " conn = sqlite3.connect(DB_PATH)\n", + " if dict_rows:\n", + " conn.row_factory = _dict_factory\n", + " try:\n", + " yield conn\n", + " conn.commit()\n", + " finally:\n", + " conn.close()\n", + "\n", + "####################\n", + "# Helper functions #\n", + "####################\n", + "\n", + "# Converts SQLite query results from tuples into dictionaries\n", + "def _dict_factory(cursor, row):\n", + " return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}\n", + "# A debug/logging function that prints database tool activity\n", + "def _print(msg: str):\n", + " print(f\"DATABASE TOOL CALLED: {msg}\", flush=True)\n", + "\n", + "################################\n", + "# Caregiver database functions #\n", + "################################\n", + "\n", + "# Counts the number of caregivers in the database\n", + "def get_caregiver_count() -> int:\n", + " _print(\"Counting caregivers\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"SELECT COUNT(*) AS n FROM caregivers\")\n", + " return cur.fetchone()[\"n\"]\n", + "\n", + "# Fetches a caregiver's profile by their ID\n", + "def get_caregiver(caregiver_id: int) -> Optional[Dict[str, Any]]:\n", + " _print(f\"Fetching caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"SELECT * FROM caregivers WHERE id = ?\", 
(caregiver_id,))\n", + " return cur.fetchone()\n", + "\n", + "# Lists caregivers with pagination\n", + "def list_caregivers(limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:\n", + " _print(f\"Listing caregivers (limit={limit}, offset={offset})\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT * FROM caregivers\n", + " ORDER BY id\n", + " LIMIT ? OFFSET ?\n", + " \"\"\", (limit, offset))\n", + " return cur.fetchall()\n", + "\n", + "# Fetches the services a caregiver offers\n", + "def get_services(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching services for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT care_type FROM caregiver_services WHERE caregiver_id = ?\n", + " ORDER BY care_type\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"care_type\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the languages a caregiver speaks\n", + "def get_languages(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching languages for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT language FROM languages WHERE caregiver_id = ?\n", + " ORDER BY language\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"language\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the certifications a caregiver has\n", + "def get_certifications(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching certifications for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT cert FROM certifications WHERE caregiver_id = ?\n", + " ORDER BY cert\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"cert\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the traits a caregiver has\n", + "def get_traits(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching traits for caregiver 
#{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT trait FROM traits WHERE caregiver_id = ?\n", + " ORDER BY trait\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"trait\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the availability of a caregiver\n", + "def get_availability(caregiver_id: int) -> List[Dict[str, str]]:\n", + " _print(f\"Fetching availability for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT day, time_start, time_end\n", + " FROM availability\n", + " WHERE caregiver_id = ?\n", + " ORDER BY\n", + " CASE day\n", + " WHEN 'Mon' THEN 1 WHEN 'Tue' THEN 2 WHEN 'Wed' THEN 3\n", + " WHEN 'Thu' THEN 4 WHEN 'Fri' THEN 5 WHEN 'Sat' THEN 6\n", + " WHEN 'Sun' THEN 7 ELSE 8\n", + " END, time_start\n", + " \"\"\", (caregiver_id,))\n", + " return cur.fetchall()\n", + "\n", + "# Fetches a caregiver's full profile\n", + "def get_caregiver_profile(caregiver_id: int) -> Optional[Dict[str, Any]]:\n", + " _print(f\"Fetching full profile for caregiver #{caregiver_id}\")\n", + " base = get_caregiver(caregiver_id)\n", + " if not base:\n", + " return None\n", + " base[\"services\"] = get_services(caregiver_id)\n", + " base[\"languages\"] = get_languages(caregiver_id)\n", + " base[\"certifications\"] = get_certifications(caregiver_id)\n", + " base[\"traits\"] = get_traits(caregiver_id)\n", + " base[\"availability\"] = get_availability(caregiver_id)\n", + " return base\n", + "\n", + "###########################################\n", + "# Search caregivers with multiple filters #\n", + "###########################################\n", + "\n", + "def search_caregivers(\n", + " city: Optional[str] = None,\n", + " state_province: Optional[str] = None,\n", + " country: Optional[str] = None,\n", + " care_type: Optional[str] = None,\n", + " min_experience: Optional[int] = None,\n", + " max_hourly_rate: Optional[float] = 
None,\n", + " live_in: Optional[bool] = None,\n", + " language: Optional[str] = None,\n", + " certification: Optional[str] = None,\n", + " day: Optional[str] = None,\n", + " time_between: Optional[Tuple[str, str]] = None, # ('HH:MM', 'HH:MM')\n", + " limit: int = 50,\n", + " offset: int = 0\n", + ") -> List[Dict[str, Any]]:\n", + " \"\"\"\n", + " Flexible multi-filter search. Any filter can be omitted.\n", + " \"\"\"\n", + " _print(\"Searching caregivers with multiple filters\")\n", + "\n", + " # base + optional joins\n", + " join_clauses = []\n", + " where = [\"1=1\"]\n", + " params: List[Any] = []\n", + "\n", + " if care_type:\n", + " join_clauses.append(\"JOIN caregiver_services s ON s.caregiver_id = c.id\")\n", + " where.append(\"LOWER(s.care_type) = LOWER(?)\")\n", + " params.append(care_type)\n", + "\n", + " if language:\n", + " join_clauses.append(\"JOIN languages l ON l.caregiver_id = c.id\")\n", + " where.append(\"LOWER(l.language) = LOWER(?)\")\n", + " params.append(language)\n", + "\n", + " if certification:\n", + " join_clauses.append(\"JOIN certifications cert ON cert.caregiver_id = c.id\")\n", + " where.append(\"LOWER(cert.cert) = LOWER(?)\")\n", + " params.append(certification)\n", + "\n", + " if day or time_between:\n", + " join_clauses.append(\"JOIN availability a ON a.caregiver_id = c.id\")\n", + " if day:\n", + " where.append(\"a.day = ?\")\n", + " params.append(day)\n", + " if time_between:\n", + " t0, t1 = time_between\n", + " # overlap check: caregiver window [start,end] must include [t0,t1]\n", + " where.append(\"a.time_start <= ? 
AND a.time_end >= ?\")\n", + " params.extend([t0, t1])\n", + "\n", + " if city:\n", + " where.append(\"LOWER(c.city) = LOWER(?)\")\n", + " params.append(city)\n", + " if state_province:\n", + " where.append(\"LOWER(c.state_province) = LOWER(?)\")\n", + " params.append(state_province)\n", + " if country:\n", + " where.append(\"LOWER(c.country) = LOWER(?)\")\n", + " params.append(country)\n", + " if min_experience is not None:\n", + " where.append(\"c.years_experience >= ?\")\n", + " params.append(min_experience)\n", + " if max_hourly_rate is not None:\n", + " where.append(\"c.hourly_rate <= ?\")\n", + " params.append(max_hourly_rate)\n", + " if live_in is not None:\n", + " where.append(\"c.live_in = ?\")\n", + " params.append(1 if live_in else 0)\n", + "\n", + " sql = f\"\"\"\n", + " SELECT DISTINCT c.*\n", + " FROM caregivers c\n", + " {' '.join(join_clauses)}\n", + " WHERE {' AND '.join(where)}\n", + " ORDER BY c.hourly_rate ASC, c.years_experience DESC, c.id\n", + " LIMIT ? OFFSET ?\n", + " \"\"\"\n", + " params.extend([limit, offset])\n", + "\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(sql, tuple(params))\n", + " return cur.fetchall()" + ] + }, + { + "cell_type": "markdown", + "id": "6c526d05", + "metadata": {}, + "source": [ + "## Step 6: Quick Test\n", + "\n", + "Before connecting everything to the AI, let's make sure the database works. 
Run these examples to see sample caregivers and their profiles.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "98165a21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DATABASE TOOL CALLED: Searching caregivers with multiple filters\n", + "Found 1 elder care providers in New York:\n", + "- Grace Williams: $28.0/hr, 6 years experience\n", + "\n", + "============================================================\n", + "\n", + "DATABASE TOOL CALLED: Searching caregivers with multiple filters\n", + "Found 1 Spanish-speaking child care providers:\n", + "- Isabella Garcia in San Jose, CA\n", + "\n", + "============================================================\n", + "\n", + "DATABASE TOOL CALLED: Fetching full profile for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching caregiver #1\n", + "DATABASE TOOL CALLED: Fetching services for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching languages for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching certifications for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching traits for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching availability for caregiver #1\n", + "Detailed profile for Grace Williams:\n", + " Services: companionship, elder care\n", + " Languages: English, Spanish\n", + " Certifications: CPR, First Aid\n", + " Traits: detail-oriented, empathetic\n", + " Availability: 3 time slots\n" + ] + } + ], + "source": [ + "# Example 1: Search for elder care providers in New York\n", + "results = search_caregivers(\n", + " city=\"New York\",\n", + " care_type=\"elder care\",\n", + " max_hourly_rate=30.0,\n", + " limit=5\n", + ")\n", + "\n", + "print(f\"Found {len(results)} elder care providers in New York:\")\n", + "for caregiver in results:\n", + " print(f\"- {caregiver['name']}: ${caregiver['hourly_rate']}/hr, {caregiver['years_experience']} years experience\")\n", + "\n", + "print(\"\\n\" + \"=\"*60 + \"\\n\")\n", + "\n", + "# Example 2: 
Search for Spanish-speaking child care providers\n", + "results2 = search_caregivers(\n", + " care_type=\"child care\",\n", + " language=\"Spanish\",\n", + " limit=3\n", + ")\n", + "\n", + "print(f\"Found {len(results2)} Spanish-speaking child care providers:\")\n", + "for caregiver in results2:\n", + " print(f\"- {caregiver['name']} in {caregiver['city']}, {caregiver['state_province']}\")\n", + "\n", + "print(\"\\n\" + \"=\"*60 + \"\\n\")\n", + "\n", + "# Example 3: Get detailed profile of a specific caregiver\n", + "if results:\n", + " caregiver_id = results[0]['id']\n", + " profile = get_caregiver_profile(caregiver_id)\n", + " print(f\"Detailed profile for {profile['name']}:\")\n", + " print(f\" Services: {', '.join(profile['services'])}\")\n", + " print(f\" Languages: {', '.join(profile['languages'])}\")\n", + " print(f\" Certifications: {', '.join(profile['certifications'])}\")\n", + " print(f\" Traits: {', '.join(profile['traits'])}\")\n", + " print(f\" Availability: {len(profile['availability'])} time slots\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "abfa81e6", + "metadata": {}, + "source": [ + "## Step 7: The AI's Instructions\n", + "\n", + "Here's where I learned prompt engineering matters *a lot*.\n", + "\n", + "The AI needs to know:\n", + "- What exact keywords to use (\"elder care\" not \"elderly care\", \"Mon\" not \"Monday\")\n", + "- How to map natural language to database values\n", + "- That it should give 2-3 recommendations with pros/cons\n", + "- To remind families to verify credentials independently\n", + "\n", + "**The lesson from MyWoosah:** Small keyword mismatches = zero results. 
This prompt prevents that.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "7bbe36e3", + "metadata": {}, + "outputs": [], + "source": [ + "# System prompt\n", + "\n", + "system_prompt = '''\n", + " You are a compassionate Caregiver Assistant that helps families quickly identify the most\n", + " suitable care provider by gathering requirements (care needs, schedule, budget, location,\n", + " language/cultural fit) and matching them to available profiles. Provide 2-3 best-fit options\n", + " with pros/cons, estimated costs, and next steps, and clearly state that credentials/background\n", + " checks are not verified by this sample app and should be confirmed by the family.\n", + "\n", + " CRITICAL: When searching the database, you MUST use these EXACT terms:\n", + "\n", + " CARE TYPES (use exactly as shown):\n", + " - \"elder care\" (for elderly, senior, old age, geriatric care)\n", + " - \"companionship\" (for companion, friendship, social support)\n", + " - \"post-op support\" (for post-surgery, post-operative, recovery care)\n", + " - \"child care\" (for children, kids, babysitting, nanny)\n", + " - \"special needs\" (for disabilities, autism, developmental needs)\n", + " - \"respite care\" (for temporary relief, break for family caregivers)\n", + " - \"dementia care\" (for Alzheimer's, memory care, cognitive decline)\n", + " - \"hospice support\" (for end-of-life, palliative, terminal care)\n", + "\n", + " If a user mentions any variation, map it to the closest match above. 
If unclear, ask clarifying questions.\n", + "\n", + " DAYS OF WEEK (use exactly as shown):\n", + " - \"Mon\" (for Monday)\n", + " - \"Tue\" (for Tuesday)\n", + " - \"Wed\" (for Wednesday)\n", + " - \"Thu\" (for Thursday)\n", + " - \"Fri\" (for Friday)\n", + " - \"Sat\" (for Saturday)\n", + " - \"Sun\" (for Sunday)\n", + "\n", + " STATES/PROVINCES (use 2-letter codes):\n", + " - Use standard US state abbreviations: \"NY\", \"CA\", \"TX\", \"FL\", \"MA\", etc.\n", + " - Convert full state names to abbreviations before searching\n", + "\n", + " COMMON LANGUAGES:\n", + " - \"English\", \"Spanish\", \"French\", \"Vietnamese\", \"Korean\", \"Hindi\", \"Mandarin\", \"Portuguese\", \"Tagalog\", \"ASL\"\n", + " - Capitalize properly (e.g., user says \"spanish\" → use \"Spanish\")\n", + "\n", + " CERTIFICATIONS:\n", + " - \"CPR\", \"First Aid\", \"CNA\", \"BLS\", \"HHA\", \"AED\", \"Medication Technician\", \"Hospice Training\", \n", + " \"Dementia Care\", \"Special Needs Training\", \"Childcare Safety\"\n", + " - Use exact capitalization and full names\n", + "\n", + " TRAITS:\n", + " - \"empathetic\", \"patient\", \"cheerful\", \"organized\", \"compassionate\", \"calm under pressure\", \n", + " \"adaptable\", \"friendly\", \"thorough\", \"gentle\", \"proactive\", \"flexible\", \"reliable\", \n", + " \"detail-oriented\", \"communicative\", \"energetic\", \"respectful\", \"dependable\", \"attentive\", \n", + " \"kind\", \"tidy\", \"punctual\", \"discreet\", \"professional\", \"trustworthy\"\n", + " - Use lowercase for all traits\n", + "\n", + " SEARCH STRATEGY:\n", + " 1. Listen carefully to user requirements\n", + " 2. Map their natural language to database terms above\n", + " 3. Use search_caregivers() with exact keyword matches\n", + " 4. If no results, suggest alternatives or broader searches\n", + " 5. 
After getting results, use get_caregiver_profile() for detailed information on top matches\n", + "\n", + " Always confirm your understanding by restating requirements using the exact database terms before searching.\n", + "'''" + ] + }, + { + "cell_type": "markdown", + "id": "0b8ae902", + "metadata": {}, + "source": [ + "## Step 8: Making it Work (and Not Crash)\n", + "\n", + "This is the engine room. When the AI wants to search, this code:\n", + "1. Validates the request\n", + "2. Calls the right database function\n", + "3. Handles errors gracefully (no crashes!)\n", + "4. Limits results to prevent overwhelming the AI\n", + "5. Generates the voice response\n", + "\n", + "**Defensive programming:** I learned the hard way that things break. This code expects problems and handles them elegantly.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "0d8accbc", + "metadata": {}, + "outputs": [], + "source": [ + "# Function registry: Maps tool names to actual Python functions\n", + "TOOL_REGISTRY = {\n", + " \"search_caregivers\": search_caregivers,\n", + " \"get_caregiver_count\": get_caregiver_count,\n", + " \"get_caregiver\": get_caregiver,\n", + " \"list_caregivers\": list_caregivers,\n", + " \"get_services\": get_services,\n", + " \"get_languages\": get_languages,\n", + " \"get_certifications\": get_certifications,\n", + " \"get_traits\": get_traits,\n", + " \"get_availability\": get_availability,\n", + " \"get_caregiver_profile\": get_caregiver_profile,\n", + "}\n", + "\n", + "def execute_tool_call(tool_call):\n", + " \"\"\"\n", + " Safely execute a single tool call with error handling.\n", + " Returns a properly formatted tool response.\n", + " \"\"\"\n", + " import json\n", + " \n", + " function_name = tool_call.function.name\n", + " \n", + " # Defensive check: Ensure function exists in registry\n", + " if function_name not in TOOL_REGISTRY:\n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " 
\"content\": json.dumps({\n", + " \"error\": f\"Unknown function: {function_name}\",\n", + " \"available_functions\": list(TOOL_REGISTRY.keys())\n", + " })\n", + " }\n", + " \n", + " try:\n", + " # Parse arguments\n", + " args = json.loads(tool_call.function.arguments)\n", + " \n", + " # Execute the function\n", + " func = TOOL_REGISTRY[function_name]\n", + " result = func(**args)\n", + " \n", + " # Format response based on result type with limit to prevent token overflow\n", + " if isinstance(result, list):\n", + " content = json.dumps({\n", + " \"count\": len(result),\n", + " \"results\": result[:10] if len(result) > 10 else result,\n", + " \"truncated\": len(result) > 10\n", + " })\n", + " elif isinstance(result, dict):\n", + " content = json.dumps(result)\n", + " elif isinstance(result, (int, float, str)):\n", + " content = json.dumps({\"result\": result})\n", + " else:\n", + " content = str(result)\n", + " \n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " \"content\": content\n", + " }\n", + " \n", + " except Exception as e:\n", + " # Defensive error handling\n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " \"content\": json.dumps({\n", + " \"error\": str(e),\n", + " \"function\": function_name,\n", + " \"args\": tool_call.function.arguments\n", + " })\n", + " }\n", + "\n", + "def process_tool_calls(message):\n", + " \"\"\"\n", + " Process all tool calls from the AI response.\n", + " Returns tool responses and extracted metadata.\n", + " \"\"\"\n", + " responses = []\n", + " metadata = {\n", + " \"cities\": set(),\n", + " \"caregiver_ids\": set(),\n", + " \"total_results\": 0\n", + " }\n", + " \n", + " if not message.tool_calls:\n", + " return responses, metadata\n", + " \n", + " for tool_call in message.tool_calls:\n", + " # Execute the tool call\n", + " response = execute_tool_call(tool_call)\n", + " responses.append(response)\n", + " \n", + " # Extract metadata for UI 
enhancements\n", + " try:\n", + " import json\n", + " content = json.loads(response[\"content\"])\n", + " \n", + " # Extract cities from search results\n", + " if \"results\" in content and isinstance(content[\"results\"], list):\n", + " for item in content[\"results\"]:\n", + " if isinstance(item, dict) and \"city\" in item:\n", + " metadata[\"cities\"].add(item[\"city\"])\n", + " if isinstance(item, dict) and \"id\" in item:\n", + " metadata[\"caregiver_ids\"].add(item[\"id\"])\n", + " \n", + " if \"count\" in content:\n", + " metadata[\"total_results\"] += content[\"count\"]\n", + " \n", + " except:\n", + " pass # Silently ignore metadata extraction errors\n", + " \n", + " return responses, metadata\n", + "\n", + "def generate_city_image(city):\n", + " \"\"\"\n", + " Generate or retrieve a city image (placeholder for future enhancement).\n", + " Could integrate with DALL-E, Unsplash API, or local image database.\n", + " \"\"\"\n", + " # Placeholder - can be enhanced with actual image generation\n", + " return None\n", + "\n", + "def chat(history):\n", + " \"\"\"\n", + " Main chat handler with multi-modal support and defensive error handling.\n", + " Handles conversation flow, tool calls, and response generation.\n", + " \"\"\"\n", + " # Normalize history format\n", + " history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n", + " \n", + " # Initialize conversation with system prompt\n", + " messages = [{\"role\": \"system\", \"content\": system_prompt}] + history\n", + " \n", + " # Initialize metadata\n", + " image = None\n", + " selected_city = None\n", + " \n", + " try:\n", + " # Initial API call\n", + " response = openai.chat.completions.create(\n", + " model=MODEL,\n", + " messages=messages,\n", + " tools=tools\n", + " )\n", + " \n", + " # Tool calling loop (with safety limit)\n", + " max_iterations = 5\n", + " iteration = 0\n", + " \n", + " while response.choices[0].finish_reason == \"tool_calls\" and iteration < 
max_iterations:\n", + " iteration += 1\n", + " message = response.choices[0].message\n", + " \n", + " # Process all tool calls\n", + " tool_responses, metadata = process_tool_calls(message)\n", + " \n", + " # Track city for image generation\n", + " if metadata[\"cities\"]:\n", + " selected_city = list(metadata[\"cities\"])[0]\n", + " \n", + " # Add assistant message and tool responses to conversation\n", + " messages.append(message)\n", + " messages.extend(tool_responses)\n", + " \n", + " # Continue conversation\n", + " response = openai.chat.completions.create(\n", + " model=MODEL,\n", + " messages=messages,\n", + " tools=tools\n", + " )\n", + " \n", + " # Extract final reply\n", + " reply = response.choices[0].message.content\n", + " history.append({\"role\": \"assistant\", \"content\": reply})\n", + " \n", + " # Generate voice response\n", + " voice = announcements(reply)\n", + " \n", + " # Generate city image if applicable\n", + " if selected_city:\n", + " image = generate_city_image(selected_city)\n", + " \n", + " return history, voice, image\n", + " \n", + " except Exception as e:\n", + " # Defensive error handling for the entire chat flow\n", + " error_message = f\"I apologize, but I encountered an error: {str(e)}. Please try again.\"\n", + " history.append({\"role\": \"assistant\", \"content\": error_message})\n", + " return history, None, None" + ] + }, + { + "cell_type": "markdown", + "id": "451ed2e5", + "metadata": {}, + "source": [ + "## Step 9: The Interface\n", + "\n", + "A clean, professional web UI built with Gradio.\n", + "\n", + "Features:\n", + "- Chat interface with conversation history\n", + "- Voice responses that auto-play\n", + "- Settings sidebar (model selection, voice options)\n", + "- Clear instructions for families\n", + "\n", + "**Why Gradio?** At MyWoosah, I needed something non-technical staff could use immediately. 
Gradio made that possible without weeks of frontend work.\n", + "\n", + "**Run this cell to launch!** 🚀\n" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "a07e7793-b8f5-44f4-aded-5562f633271a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7871\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gradio as gr\n", + "\n", + "# Gradio UI Setup\n", + "\n", + "def put_message_in_chatbot(message, history):\n", + " \"\"\"Add user message to chat history\"\"\"\n", + " return \"\", history + [{\"role\": \"user\", \"content\": message}]\n", + "\n", + "# Custom CSS for better styling\n", + "custom_css = \"\"\"\n", + "#chatbot {\n", + " border-radius: 10px;\n", + " box-shadow: 0 2px 8px rgba(0,0,0,0.1);\n", + "}\n", + "#message_box {\n", + " border-radius: 8px;\n", + "}\n", + ".header {\n", + " text-align: center;\n", + " padding: 20px;\n", + " background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);\n", + " color: white;\n", + " border-radius: 10px;\n", + " margin-bottom: 20px;\n", + "}\n", + "\"\"\"\n", + "\n", + "with gr.Blocks(title=\"CareGiver AI Assistant\", css=custom_css, theme=gr.themes.Soft()) as ui:\n", + " \n", + " # Header\n", + " gr.Markdown(\"\"\"\n", + "
\n", + "

🏥 RoboCare AI Assistant

\n", + "

Find the perfect caregiver for your loved ones

\n", + "
\n", + " \"\"\")\n", + " \n", + " # Instructions\n", + " with gr.Accordion(\"ℹ️ Click here to learn more on how to use this AI\", open=False):\n", + " gr.Markdown(\"\"\"\n", + " **Tell me what you need:**\n", + " - Type of care (elder care, child care, companionship, etc.)\n", + " - Location (city, state)\n", + " - Schedule requirements (days/times)\n", + " - Budget constraints\n", + " - Language or certification needs\n", + " \n", + " **Example:** \"I need an elder care provider in Boston for Monday mornings who speaks Spanish and has CPR certification.\"\n", + " \n", + " ⚠️ **Note:** This is a demo app. Always verify credentials and conduct background checks independently.\n", + " \"\"\")\n", + " \n", + " # Main chat interface\n", + " with gr.Row():\n", + " with gr.Column(scale=2):\n", + " chatbot = gr.Chatbot(\n", + " height=500, \n", + " type=\"messages\",\n", + " elem_id=\"chatbot\",\n", + " label=\"Chat History\",\n", + " avatar_images=(None, \"🤖\")\n", + " )\n", + " \n", + " # Audio output\n", + " audio_output = gr.Audio(\n", + " label=\"Voice Response\",\n", + " autoplay=True,\n", + " visible=True,\n", + " interactive=False\n", + " )\n", + " \n", + " # Settings sidebar\n", + " with gr.Column(scale=1):\n", + " gr.Markdown(\"### ⚙️ Settings\")\n", + " \n", + " # Model selector (for future enhancement)\n", + " model_select = gr.Dropdown(\n", + " choices=[\"gpt-4o-mini\", \"gpt-4o\", \"gpt-4-turbo\"],\n", + " value=\"gpt-4o-mini\",\n", + " label=\"AI Model\",\n", + " interactive=True\n", + " )\n", + " \n", + " # Voice selector\n", + " voice_select = gr.Dropdown(\n", + " choices=[\"coral\", \"alloy\", \"echo\", \"fable\", \"onyx\", \"nova\", \"shimmer\"],\n", + " value=\"coral\",\n", + " label=\"Voice\",\n", + " interactive=True\n", + " )\n", + " \n", + " # Audio toggle\n", + " audio_enabled = gr.Checkbox(\n", + " label=\"Enable Voice Responses\",\n", + " value=True\n", + " )\n", + " \n", + " # Clear button\n", + " clear_btn = gr.Button(\"🗑️ Clear 
Conversation\", variant=\"secondary\")\n", + " \n", + " # Input section\n", + " with gr.Row():\n", + " message = gr.Textbox(\n", + " label=\"Your Message\",\n", + " placeholder=\"Type your question here... (e.g., 'I need elder care in Boston')\",\n", + " lines=2,\n", + " elem_id=\"message_box\",\n", + " scale=4\n", + " )\n", + " send_btn = gr.Button(\"Send\", variant=\"primary\", scale=1)\n", + " \n", + " # Event handlers\n", + " def chat_wrapper(history):\n", + " \"\"\"Wrapper to handle chat and extract only needed outputs\"\"\"\n", + " history_out, voice, image = chat(history)\n", + " return history_out, voice\n", + " \n", + " # Submit on enter or button click\n", + " submit_event = message.submit(\n", + " put_message_in_chatbot,\n", + " inputs=[message, chatbot],\n", + " outputs=[message, chatbot]\n", + " ).then(\n", + " chat_wrapper,\n", + " inputs=chatbot,\n", + " outputs=[chatbot, audio_output]\n", + " )\n", + " \n", + " send_btn.click(\n", + " put_message_in_chatbot,\n", + " inputs=[message, chatbot],\n", + " outputs=[message, chatbot]\n", + " ).then(\n", + " chat_wrapper,\n", + " inputs=chatbot,\n", + " outputs=[chatbot, audio_output]\n", + " )\n", + " \n", + " # Clear conversation\n", + " clear_btn.click(\n", + " lambda: ([], None),\n", + " outputs=[chatbot, audio_output]\n", + " )\n", + " \n", + " # Footer\n", + " gr.Markdown(\"\"\"\n", + " ---\n", + "
\n", + " Powered by OpenAI & Gradio | Built by RoboOffice Ltd\n", + "
\n", + " \"\"\")\n", + "\n", + "# Launch with better configuration\n", + "ui.launch(\n", + " inbrowser=True,\n", + " share=False,\n", + " show_error=True,\n", + " quiet=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "97d87d95", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Reflections\n", + "\n", + "This project started from frustration: *\"There has to be a better way to match families with caregivers.\"*\n", + "\n", + "Through the Andela program, I learned that AI + thoughtful engineering = solutions to real problems.\n", + "\n", + "### What Worked:\n", + "- **Function calling** eliminated the need for custom queries\n", + "- **Prompt engineering** prevented keyword mismatches\n", + "- **Defensive coding** made it robust\n", + "- **Gradio** made it accessible\n", + "\n", + "### What I'd Do Next:\n", + "- Add speech input (families could call and talk)\n", + "- Connect to actual MyWoosah database\n", + "- Add background check API integration\n", + "- Deploy for real users\n", + "\n", + "### The Bigger Picture:\n", + "\n", + "This isn't just about caregiving. 
The same pattern works for:\n", + "- Healthcare appointments\n", + "- Legal services\n", + "- Tutoring platforms\n", + "- Any matching problem where natural language beats forms\n", + "\n", + "AI doesn't replace good database design—it makes it accessible to everyone.\n", + "\n", + "---\n", + "\n", + "**For MyWoosah Inc and beyond:** This is proof that AI can transform how we connect people with the care they need.\n", + "\n", + "*Built during Week 2 of the Andela LLM Engineering Program*\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week2/community-contributions/kwabena/week2_solution_.ipynb b/week2/community-contributions/kwabena/week2_solution_.ipynb new file mode 100644 index 0000000..9b1f22e --- /dev/null +++ b/week2/community-contributions/kwabena/week2_solution_.ipynb @@ -0,0 +1,173 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fd1cdd6e", + "metadata": {}, + "source": [ + "## Week 2 - Full Prototype for Technical Questions Answerer" + ] + }, + { + "cell_type": "markdown", + "id": "70db9a0b", + "metadata": {}, + "source": [ + " This notebook will implement a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df46689d", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import os\n", + "import json\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7416a2a", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialization\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "MODEL = \"gpt-4.1-mini\"\n", + "openai = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86966749", + "metadata": {}, + "outputs": [], + "source": [ + "system_message = \"\"\"\n", + "You are an expert technical question answerer specializing in data science, programming, \n", + "and software engineering. Your goal is to provide clear, accurate, and practical answers \n", + "to technical questions.\n", + "\n", + "When answering:\n", + "- Break down complex concepts into understandable explanations\n", + "- Provide code examples when relevant, with comments explaining key parts\n", + "- Mention common pitfalls or best practices\n", + "- If a question is ambiguous, state your assumptions or ask for clarification\n", + "- For debugging questions, explain both the fix and why the error occurred\n", + "- Cite specific documentation or resources when helpful\n", + "\n", + "Always prioritize accuracy and clarity over speed. 
If you're unsure about something, \n", + "acknowledge the uncertainty rather than guessing.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d34e5b81", + "metadata": {}, + "outputs": [], + "source": [ + "# Streaming chat funcion\n", + "def chat(model, history):\n", + " messages = [{\"role\": \"system\", \"content\": system_message}]\n", + " for h in history:\n", + " messages.append({\"role\": h[\"role\"], \"content\": h[\"content\"]})\n", + "\n", + " stream = openai.chat.completions.create(\n", + " model=model, \n", + " messages=messages,\n", + " stream=True\n", + " )\n", + "\n", + " response = \"\"\n", + " for chunk in stream:\n", + " if chunk.choices[0].delta.content is not None:\n", + " response += chunk.choices[0].delta.content\n", + " yield history + [{\"role\": \"assistant\", \"content\": response}]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32350869", + "metadata": {}, + "outputs": [], + "source": [ + "#Gradio Interface\n", + "with gr.Blocks() as ui:\n", + " with gr.Row():\n", + " chatbot = gr.Chatbot(height=500, type=\"messages\")\n", + " with gr.Row():\n", + " message = gr.Textbox(label=\"Chat with AI Assistant: \")\n", + " model_dropdown = gr.Dropdown(\n", + " choices=[\"gpt-4.1-mini\",\"gpt-4o-mini\", \"gpt-4o\", \"gpt-4-turbo\"], \n", + " value=\"gpt-4.1-mini\", \n", + " label=\"Select Model\"\n", + " ) \n", + "\n", + " def handle_submit(user_message, chat_history):\n", + " # Add user message to history\n", + " chat_history = chat_history + [{\"role\": \"user\", \"content\": user_message}]\n", + " return \"\", chat_history\n", + "\n", + " message.submit(\n", + " handle_submit, \n", + " inputs=[message, chatbot], \n", + " outputs=[message, chatbot]\n", + " ).then(\n", + " chat, \n", + " inputs=[model_dropdown, chatbot],\n", + " outputs=[chatbot]\n", + " )\n", + "\n", + "ui.launch(inbrowser=True)" + ] + }, + { + "cell_type": "markdown", + "id": "cf2b29e1", + "metadata": {}, + "source": [ 
+ "### Concluding Remarks\n", + "In this exercise, we successfully built a working AI chatbot with Gradio that includes streaming responses and the ability to switch between different models. The implementation demonstrates how to create an interactive interface for LLM applications." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week2/community-contributions/week2_exercise_solution-Stephen.ipynb b/week2/community-contributions/week2_exercise_solution-Stephen.ipynb new file mode 100644 index 0000000..21de7d8 --- /dev/null +++ b/week2/community-contributions/week2_exercise_solution-Stephen.ipynb @@ -0,0 +1,296 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd", + "metadata": {}, + "source": [ + "# End of week 2 Exercise - Bookstore Assistant\n", + "\n", + "Now use everything you've learned from Week 2 to build a full prototype for the technical question/answerer you built in Week 1 Exercise.\n", + "\n", + "This should include a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models. Bonus points if you can demonstrate use of a tool!\n", + "\n", + "If you feel bold, see if you can add audio input so you can talk to it, and have it respond with audio. ChatGPT or Claude can help you, or email me if you have questions.\n", + "\n", + "I will publish a full solution here soon - unless someone beats me to it...\n", + "\n", + "There are so many commercial applications for this, from a language tutor, to a company onboarding solution, to a companion AI to a course (like this one!) 
I can't wait to see your results." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a07e7793-b8f5-44f4-aded-5562f633271a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n", + "Google API Key exists and begins AIzaSyCL\n" + ] + } + ], + "source": [ + "import os\n", + "import json\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "google_api_key = os.getenv('GOOGLE_API_KEY')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + "\n", + "if google_api_key:\n", + " print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n", + "else:\n", + " print(\"Google API Key not set\")\n", + " \n", + "MODEL_GPT = \"gpt-4.1-mini\"\n", + "MODEL_GEMINI = \"gemini-2.5-pro\"\n", + "\n", + "\n", + "openai = OpenAI()\n", + "\n", + "gemini_url = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "gemini = OpenAI(api_key=google_api_key, base_url=gemini_url)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a3aa8bf", + "metadata": {}, + "outputs": [], + "source": [ + "# Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models\n", + "\n", + "system_message= \"\"\"\n", + " You are an assistant in a software engineering bookstore that analyzes the content of technical books and generates concise, informative summaries for readers.\n", + " Your goal is to help customers quickly understand what each book covers, its practical value, and who would benefit most from reading it.\n", + " Respond in markdown without code blocks.\n", + " Each summary should include:\n", + " Overview: The book’s main topic, scope, and focus 
area (e.g., software architecture, DevOps, system design).\n",
+    "    Key Insights: The most important lessons, principles, or methodologies discussed.\n",
+    "    Recommended For: The type of reader who would benefit most (e.g., junior developers, engineering managers, backend specialists).\n",
+    "    Related Reads: Suggest one or two similar or complementary titles available in the store.\n",
+    "    Maintain a professional and knowledgeable tone that reflects expertise in software engineering literature. \n",
+    "\"\"\"\n",
+    "\n",
+    "def stream_gpt(prompt):\n",
+    "    messages = [\n",
+    "        {\"role\": \"system\", \"content\": system_message},\n",
+    "        {\"role\": \"user\", \"content\": prompt}\n",
+    "    ]\n",
+    "    stream = openai.chat.completions.create(\n",
+    "        model=MODEL_GPT,\n",
+    "        messages=messages,\n",
+    "        stream=True\n",
+    "    )\n",
+    "    result = \"\"\n",
+    "    for chunk in stream:\n",
+    "        result += chunk.choices[0].delta.content or \"\"\n",
+    "        yield result\n",
+    "\n",
+    "def stream_gemini(prompt):\n",
+    "    messages = [\n",
+    "        {\"role\": \"system\", \"content\": system_message},\n",
+    "        {\"role\": \"user\", \"content\": prompt}\n",
+    "    ]\n",
+    "    stream = gemini.chat.completions.create(\n",
+    "        model=MODEL_GEMINI,\n",
+    "        messages=messages,\n",
+    "        stream=True\n",
+    "    )\n",
+    "    result = \"\"\n",
+    "    for chunk in stream:\n",
+    "        result += chunk.choices[0].delta.content or \"\"\n",
+    "        yield result\n",
+    "\n",
+    "def stream_model(prompt, model):\n",
+    "    if model==\"GPT\":\n",
+    "        result = stream_gpt(prompt)\n",
+    "    elif model==\"Gemini\":\n",
+    "        result = stream_gemini(prompt)\n",
+    "    else:\n",
+    "        raise ValueError(\"Unknown model\")\n",
+    "    yield from result\n",
+    "\n",
+    "\n",
+    "message_input = gr.Textbox(label=\"Your message:\", info=\"Enter a software engineering book title for the LLM\", lines=4)\n",
+    "model_selector = gr.Dropdown([\"GPT\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
+    "message_output = gr.Markdown(label=\"Response:\")\n",
+    "\n",
+    "view = 
import sqlite3

DB = "books.db"

# Ensure the prices table exists before any reads or writes.
with sqlite3.connect(DB) as conn:
    cursor = conn.cursor()
    cursor.execute('CREATE TABLE IF NOT EXISTS prices (title TEXT PRIMARY KEY, price REAL)')
    conn.commit()

def get_book_price(title):
    """Look up a book's price by case-insensitive title.

    Returns a display string for the chat tool: either the price line or a
    'no data' message when the title is unknown.
    """
    print(f"DATABASE TOOL CALLED: Getting price for {title}", flush=True)
    with sqlite3.connect(DB) as conn:
        cursor = conn.cursor()
        # Titles are stored lowercased by set_book_price, so normalize here too.
        cursor.execute('SELECT price FROM prices WHERE title = ?', (title.lower(),))
        result = cursor.fetchone()
    return f"Book -> {title} price is ${result[0]}" if result else "No price data available for this title"

def set_book_price(title, price):
    """Insert or update a book's price; the title key is stored lowercased."""
    with sqlite3.connect(DB) as conn:
        cursor = conn.cursor()
        # Idiomatic SQLite upsert: reference excluded.price instead of binding
        # the same value twice positionally.
        cursor.execute(
            'INSERT INTO prices (title, price) VALUES (?, ?) '
            'ON CONFLICT(title) DO UPDATE SET price = excluded.price',
            (title.lower(), price),
        )
        conn.commit()

# Seed a few known titles so the assistant's pricing tool has data to return.
book_prices = {"Clean code": 20, "Clean architecture": 30, "System design": 40, "Design patterns": 50}
for title, price in book_prices.items():
    set_book_price(title, price)
responses\n", + "\n", + "def chat(history):\n", + " history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n", + " messages = [{\"role\": \"system\", \"content\": system_message}] + history\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n", + "\n", + " while response.choices[0].finish_reason==\"tool_calls\":\n", + " message = response.choices[0].message\n", + " responses = handle_tool_calls(message)\n", + " messages.append(message)\n", + " messages.extend(responses)\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n", + "\n", + " reply = response.choices[0].message.content\n", + " history += [{\"role\":\"assistant\", \"content\":reply}]\n", + "\n", + " voice = talker(reply)\n", + " \n", + " return history, voice\n", + "\n", + "def put_message_in_chatbot(message, history):\n", + " return \"\", history + [{\"role\":\"user\", \"content\":message}]\n", + "with gr.Blocks() as ui:\n", + " with gr.Row():\n", + " chatbot = gr.Chatbot(height=300, type=\"messages\")\n", + " audio_output = gr.Audio(autoplay=True)\n", + " \n", + " with gr.Row():\n", + " message = gr.Textbox(label=\"Chat with our AI Assistant:\")\n", + "\n", + " message.submit(put_message_in_chatbot, inputs=[message, chatbot], outputs=[message, chatbot]).then(\n", + " chat, inputs=chatbot, outputs=[chatbot, audio_output]\n", + " )\n", + "\n", + "ui.launch(inbrowser=True, auth=(\"ted\", \"mowsb\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
def extract_strict_json(text: str):
    """Parse model output into a list of rows, tolerating common defects.

    Strategies, in order: strip Markdown code fences, direct JSON parse
    (accepting a bare array, a dict with a rows-like key, or a digit-keyed
    dict), then carve out the outermost [...] span and repair trailing
    commas, NaN/Infinity tokens, and invisible characters.

    Raises ValueError when no JSON array can be recovered.
    """
    if text is None:
        raise ValueError("Empty model output.")

    t = text.strip()

    # Strip code fences FIRST so every later strategy sees bare JSON.
    # (BUG FIX: the original stripped fences only after the direct parse, so
    # a fenced JSON object could never reach the dict-handling strategies
    # below and raised "No JSON array found" despite valid content.)
    if t.startswith("```"):
        t = re.sub(r"^```(?:json)?\s*|\s*```$", "", t, flags=re.IGNORECASE | re.MULTILINE).strip()

    # Strategy 1: direct JSON parsing.
    try:
        obj = json.loads(t)
        if isinstance(obj, list):
            return obj
        elif isinstance(obj, dict):
            for key in ("rows", "data", "items", "records", "results"):
                if key in obj and isinstance(obj[key], list):
                    return obj[key]
            # {"0": {...}, "1": {...}} style: return values in numeric order.
            if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):
                return [obj[k] for k in sorted(obj.keys(), key=int)]
            # No recognizable row container: fall through to the array scan.
    except json.JSONDecodeError:
        pass

    # Strategy 2: find the outermost JSON array embedded in the text.
    start, end = t.find('['), t.rfind(']')
    if start == -1 or end == -1 or end <= start:
        raise ValueError("No JSON array found in model output.")

    t = t[start:end + 1]

    # Strategy 3: fix common JSON issues before the final parse.
    t = re.sub(r",\s*([\]}])", r"\1", t)  # remove trailing commas
    t = re.sub(r"\bNaN\b|\bInfinity\b|\b-Infinity\b", "null", t)  # non-JSON literals
    t = t.replace("\u00a0", " ").replace("\u200b", "")  # invisible characters

    try:
        return json.loads(t)
    except json.JSONDecodeError as e:
        raise ValueError(f"Could not parse JSON: {str(e)}. Text: {t[:200]}...")
def sample_enum(values, probs=None, size=None):
    """Sample from a categorical field; uniform when probs is omitted."""
    values = list(values)
    if probs is None:
        probs = [1.0 / len(values)] * len(values)
    return np.random.choice(values, p=probs, size=size)

def sample_numeric(field_cfg, size=1):
    """Sample int/float values for a field config.

    Honors 'min'/'max' bounds and an optional 'distribution' of
    uniform (default), normal, or (floats only) lognormal; out-of-range
    draws from the non-uniform distributions are clipped to the bounds.
    """
    t = field_cfg["type"]
    if t == "int":
        lo, hi = int(field_cfg["min"]), int(field_cfg["max"])
        dist = field_cfg.get("distribution", "uniform")
        if dist == "normal":
            # Center the bell on the range; ~99.7% of mass inside [lo, hi].
            mu = (lo + hi) / 2.0
            sigma = (hi - lo) / 6.0
            out = np.random.normal(mu, sigma, size=size)
            return np.clip(out, lo, hi).astype(int)
        # uniform, or any unrecognized distribution name.
        return np.random.randint(lo, hi + 1, size=size)
    elif t == "float":
        lo, hi = float(field_cfg["min"]), float(field_cfg["max"])
        dist = field_cfg.get("distribution", "uniform")
        if dist == "normal":
            mu = (lo + hi) / 2.0
            sigma = (hi - lo) / 6.0
            return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)
        elif dist == "lognormal":
            # Right-skewed draw centered near the range midpoint.
            mu = math.log(max(1e-3, (lo + hi) / 2.0))
            sigma = 0.75
            out = np.random.lognormal(mu, sigma, size=size)
            return np.clip(out, lo, hi)
        return np.random.uniform(lo, hi, size=size)
    else:
        raise ValueError("Unsupported numeric type")

def sample_datetime(start: str, end: str, size=1, fmt="%Y-%m-%d %H:%M:%S"):
    """Sample uniformly-distributed timestamps in [start, end) as formatted strings."""
    s = datetime.fromisoformat(start)
    e = datetime.fromisoformat(end)
    total = int((e - s).total_seconds())
    r = np.random.randint(0, total, size=size)
    return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]

def generate_rule_based(CFG: Dict[str, Any], rows=None) -> pd.DataFrame:
    """Generate a synthetic survey DataFrame from the field specs in CFG.

    rows: optional row-count override; defaults to CFG["rows"].  (Added,
    backward-compatibly, so callers can request a partial batch without
    cloning the whole config dict.)
    """
    n = CFG["rows"] if rows is None else int(rows)
    dt_cfg = CFG.get("datetime_range", {"start":"2024-01-01","end":"2025-10-01","fmt":"%Y-%m-%d %H:%M:%S"})
    data = {}
    for f in CFG["fields"]:
        name, t = f["name"], f["type"]
        if t == "uuid4":
            data[name] = [str(uuid.uuid4()) for _ in range(n)]
        elif t in ("int", "float"):
            data[name] = sample_numeric(f, size=n)
        elif t == "enum":
            data[name] = sample_enum(f["values"], f.get("probs"), size=n)
        elif t == "datetime":
            data[name] = sample_datetime(dt_cfg["start"], dt_cfg["end"], size=n, fmt=dt_cfg["fmt"])
        elif t == "bool":
            data[name] = np.random.rand(n) < 0.9  # 90% True
        else:
            # Unknown field type: emit nulls rather than failing the batch.
            data[name] = [None] * n
    df = pd.DataFrame(data)

    # Derive NPS roughly from the likert questions when all four are present,
    # mapping the 1-5 average onto 0-10 with a little gaussian noise.
    if set(["q_quality","q_value","q_ease","q_support"]).issubset(df.columns):
        likert_avg = df[["q_quality","q_value","q_ease","q_support"]].mean(axis=1)
        df["nps"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(int)

    # Heuristic target: detractors are more likely when completion was slow
    # (>900s) or the attention check failed.
    if "is_detractor" in df.columns:
        base = 0.25
        comp = df.get("completion_seconds", pd.Series(np.zeros(n)))
        attn = pd.Series(df.get("attention_passed", np.ones(n))).astype(bool)
        boost = (comp > 900).astype(int) + (~attn).astype(int)
        p = np.clip(base + 0.15*boost, 0.01, 0.95)
        df["is_detractor"] = np.random.rand(n) < p

    return df
cols[name] = pa.Column(object)\n", + " elif t == \"datetime\": cols[name] = pa.Column(object)\n", + " elif t == \"uuid4\": cols[name] = pa.Column(object)\n", + " elif t == \"bool\": cols[name] = pa.Column(bool)\n", + " else: cols[name] = pa.Column(object)\n", + " return pa.DataFrameSchema(cols) if pa is not None else None\n", + "\n", + "def validate_df(df, CFG):\n", + " schema = build_pandera_schema(CFG)\n", + " if schema is None:\n", + " return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n", + " try:\n", + " v = schema.validate(df, lazy=True)\n", + " return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n", + " except Exception as e:\n", + " print(\"Validation error:\", e)\n", + " return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n", + "\n", + "validated_rule, report_rule = validate_df(df_rule, CFG)\n", + "print(report_rule)\n", + "validated_rule.head()\n" + ] + }, + { + "cell_type": "markdown", + "id": "d5f1d93a", + "metadata": {}, + "source": [ + "## 5) Save" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73626b4c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from pathlib import Path\n", + "out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + "ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + "csv_path = out / f\"survey_rule_{ts}.csv\"\n", + "validated_rule.to_csv(csv_path, index=False)\n", + "print(\"Saved:\", csv_path.as_posix())\n" + ] + }, + { + "cell_type": "markdown", + "id": "87c89b51", + "metadata": {}, + "source": [ + "## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24e94771", + "metadata": {}, + "outputs": [], + "source": [ + "# Fixed LLM Generation Functions\n", + "def create_survey_prompt(CFG, n_rows=50):\n", + " \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n", + " 
fields_desc = []\n", + " for field in CFG['fields']:\n", + " name = field['name']\n", + " field_type = field['type']\n", + " \n", + " if field_type == 'int':\n", + " min_val = field.get('min', 0)\n", + " max_val = field.get('max', 100)\n", + " fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n", + " elif field_type == 'float':\n", + " min_val = field.get('min', 0.0)\n", + " max_val = field.get('max', 100.0)\n", + " fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n", + " elif field_type == 'enum':\n", + " values = field.get('values', [])\n", + " fields_desc.append(f\" - {name}: one of {values}\")\n", + " elif field_type == 'bool':\n", + " fields_desc.append(f\" - {name}: boolean (true/false)\")\n", + " elif field_type == 'uuid4':\n", + " fields_desc.append(f\" - {name}: UUID string\")\n", + " elif field_type == 'datetime':\n", + " fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n", + " fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n", + " else:\n", + " fields_desc.append(f\" - {name}: {field_type}\")\n", + " \n", + " prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n", + "\n", + "Schema:\n", + "{chr(10).join(fields_desc)}\n", + "\n", + "CRITICAL REQUIREMENTS:\n", + "- Return a JSON object with a \"responses\" key containing an array\n", + "- Each object in the array must have all required fields\n", + "- Use realistic, diverse values for survey responses\n", + "- No trailing commas\n", + "- No comments or explanations\n", + "\n", + "Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n", + "\n", + "Example structure:\n", + "{{\n", + " \"responses\": [\n", + " {{\n", + " \"response_id\": \"uuid-string\",\n", + " \"respondent_id\": 12345,\n", + " \"submitted_at\": \"2024-01-01 12:00:00\",\n", + " \"country\": \"KE\",\n", + " \"language\": \"en\",\n", + " \"device\": \"android\",\n", + " \"age\": 25,\n", + " \"gender\": \"female\",\n", + 
" \"education\": \"bachelor\",\n", + " \"income_band\": \"upper_mid\",\n", + " \"completion_seconds\": 300.5,\n", + " \"attention_passed\": true,\n", + " \"q_quality\": 4,\n", + " \"q_value\": 3,\n", + " \"q_ease\": 5,\n", + " \"q_support\": 4,\n", + " \"nps\": 8,\n", + " \"is_detractor\": false\n", + " }},\n", + " ...\n", + " ]\n", + "}}\n", + "\n", + "IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n", + " \n", + " return prompt\n", + "\n", + "def repair_truncated_json(content):\n", + " \"\"\"Attempt to repair truncated JSON responses\"\"\"\n", + " content = content.strip()\n", + " \n", + " # If it starts with { but doesn't end with }, try to close it\n", + " if content.startswith('{') and not content.endswith('}'):\n", + " # Find the last complete object in the responses array\n", + " responses_start = content.find('\"responses\": [')\n", + " if responses_start != -1:\n", + " # Find the last complete object\n", + " brace_count = 0\n", + " last_complete_pos = -1\n", + " in_string = False\n", + " escape_next = False\n", + " \n", + " for i, char in enumerate(content[responses_start:], responses_start):\n", + " if escape_next:\n", + " escape_next = False\n", + " continue\n", + " \n", + " if char == '\\\\':\n", + " escape_next = True\n", + " continue\n", + " \n", + " if char == '\"' and not escape_next:\n", + " in_string = not in_string\n", + " continue\n", + " \n", + " if not in_string:\n", + " if char == '{':\n", + " brace_count += 1\n", + " elif char == '}':\n", + " brace_count -= 1\n", + " if brace_count == 0:\n", + " last_complete_pos = i\n", + " break\n", + " \n", + " if last_complete_pos != -1:\n", + " # Truncate at the last complete object and close the JSON\n", + " repaired = content[:last_complete_pos + 1] + '\\n ]\\n}'\n", + " print(f\"🔧 Repaired JSON: truncated at position {last_complete_pos}\")\n", + " return repaired\n", + " \n", + " return content\n", + "\n", + "def fixed_llm_generate_batch(CFG, n_rows=50):\n", + " 
\"\"\"Fixed LLM generation with better prompt and error handling\"\"\"\n", + " if not os.getenv('OPENAI_API_KEY'):\n", + " print(\"No OpenAI API key, using rule-based fallback\")\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + " \n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " \n", + " prompt = create_survey_prompt(CFG, n_rows)\n", + " \n", + " print(f\"🔄 Generating {n_rows} survey responses with LLM...\")\n", + " \n", + " # Calculate appropriate max_tokens based on batch size\n", + " # Roughly 200-300 tokens per row, with some buffer\n", + " estimated_tokens = n_rows * 300 + 500 # Buffer for JSON structure\n", + " max_tokens = min(max(estimated_tokens, 2000), 8000) # Between 2k-8k tokens\n", + " \n", + " print(f\"📊 Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n", + " \n", + " response = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " messages=[\n", + " {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. 
Always return complete, valid JSON.'},\n", + " {'role': 'user', 'content': prompt}\n", + " ],\n", + " temperature=0.3,\n", + " max_tokens=max_tokens,\n", + " response_format={'type': 'json_object'}\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " print(f\"📝 Raw response length: {len(content)} characters\")\n", + " \n", + " # Check if response appears truncated\n", + " if not content.strip().endswith('}') and not content.strip().endswith(']'):\n", + " print(\"⚠️ Response appears truncated, attempting repair...\")\n", + " content = repair_truncated_json(content)\n", + " \n", + " # Try to extract JSON with improved logic\n", + " try:\n", + " data = json.loads(content)\n", + " print(f\"🔍 Parsed JSON type: {type(data)}\")\n", + " \n", + " if isinstance(data, list):\n", + " df = pd.DataFrame(data)\n", + " print(f\"📊 Direct array: {len(df)} rows\")\n", + " elif isinstance(data, dict):\n", + " # Check for common keys that might contain the data\n", + " for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n", + " if key in data and isinstance(data[key], list):\n", + " df = pd.DataFrame(data[key])\n", + " print(f\"📊 Found data in '{key}': {len(df)} rows\")\n", + " break\n", + " else:\n", + " # If no standard key found, check if all values are lists/objects\n", + " list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n", + " if list_keys:\n", + " # Use the first list key found\n", + " key = list_keys[0]\n", + " df = pd.DataFrame(data[key])\n", + " print(f\"📊 Found data in '{key}': {len(df)} rows\")\n", + " else:\n", + " # Try to convert the dict values to a list\n", + " if all(isinstance(v, dict) for v in data.values()):\n", + " df = pd.DataFrame(list(data.values()))\n", + " print(f\"📊 Converted dict values: {len(df)} rows\")\n", + " else:\n", + " raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n", + " else:\n", + " raise ValueError(f\"Unexpected JSON type: 
def fixed_generate_llm(CFG, total_rows=200, batch_size=50):
    """Generate survey rows with the LLM in adaptive batches.

    Retries a failed batch once at half size, then falls back to the
    rule-based generator for whatever is still outstanding.  Returns a
    concatenated DataFrame (possibly empty if everything failed).
    """
    print(f"🚀 Generating {total_rows} survey responses with adaptive batching")

    # Smaller batches for bigger jobs keeps each response under token limits.
    if total_rows <= 20:
        optimal_batch_size = min(batch_size, total_rows)
    elif total_rows <= 50:
        optimal_batch_size = min(15, batch_size)
    elif total_rows <= 100:
        optimal_batch_size = min(10, batch_size)
    else:
        optimal_batch_size = min(8, batch_size)

    print(f"📊 Using optimal batch size: {optimal_batch_size}")

    all_dataframes = []
    remaining = total_rows

    while remaining > 0:
        current_batch_size = min(optimal_batch_size, remaining)
        print(f"\n📦 Processing batch: {current_batch_size} rows (remaining: {remaining})")

        try:
            batch_df = fixed_llm_generate_batch(CFG, current_batch_size)
            all_dataframes.append(batch_df)
            remaining -= len(batch_df)

            # Small delay between batches to avoid rate limits.
            if remaining > 0:
                time.sleep(1.5)

        except Exception as e:
            print(f"❌ Batch failed: {str(e)}")
            print(f"🔄 Retrying with smaller batch size...")

            # One retry at half the batch size before giving up on the LLM.
            smaller_batch = max(1, current_batch_size // 2)
            if smaller_batch < current_batch_size:
                try:
                    print(f"🔄 Retrying with {smaller_batch} rows...")
                    batch_df = fixed_llm_generate_batch(CFG, smaller_batch)
                    all_dataframes.append(batch_df)
                    remaining -= len(batch_df)
                    continue
                except Exception as e2:
                    print(f"❌ Retry also failed: {str(e2)}")

            print(f"Using rule-based fallback for remaining {remaining} rows")
            # BUG FIX: generate_rule_based takes a single config argument; the
            # original passed (CFG, remaining) and raised TypeError exactly
            # when the fallback was needed.  Clone the config with the row
            # override instead (same pattern as fixed_llm_generate_batch).
            tmp = dict(CFG); tmp['rows'] = remaining
            fallback_df = generate_rule_based(tmp)
            all_dataframes.append(fallback_df)
            break

    if all_dataframes:
        result = pd.concat(all_dataframes, ignore_index=True)
        print(f"✅ Generated total: {len(result)} survey responses")
        return result
    else:
        print("❌ No data generated")
        return pd.DataFrame()
debug_llm_response(CFG, n_rows=5):\n", + " \"\"\"Debug function to see raw LLM response\"\"\"\n", + " if not os.getenv('OPENAI_API_KEY'):\n", + " print(\"No OpenAI API key available for debugging\")\n", + " return\n", + " \n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " \n", + " prompt = create_survey_prompt(CFG, n_rows)\n", + " \n", + " print(f\"\\n🔍 DEBUG: Testing with {n_rows} rows\")\n", + " print(f\"📝 Prompt length: {len(prompt)} characters\")\n", + " \n", + " response = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " messages=[\n", + " {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n", + " {'role': 'user', 'content': prompt}\n", + " ],\n", + " temperature=0.3,\n", + " max_tokens=2000,\n", + " response_format={'type': 'json_object'}\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " print(f\"📝 Raw response length: {len(content)} characters\")\n", + " print(f\"🔍 First 200 characters: {content[:200]}\")\n", + " print(f\"🔍 Last 200 characters: {content[-200:]}\")\n", + " \n", + " # Try to parse\n", + " try:\n", + " data = json.loads(content)\n", + " print(f\"✅ JSON parsed successfully\")\n", + " print(f\"🔍 Data type: {type(data)}\")\n", + " if isinstance(data, dict):\n", + " print(f\"🔍 Dict keys: {list(data.keys())}\")\n", + " elif isinstance(data, list):\n", + " print(f\"🔍 List length: {len(data)}\")\n", + " except Exception as e:\n", + " print(f\"❌ JSON parsing failed: {str(e)}\")\n", + " \n", + " except Exception as e:\n", + " print(f\"❌ Debug failed: {str(e)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75c90739", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the fixed implementation\n", + "print(\"🧪 Testing the fixed LLM generation...\")\n", + "\n", + "# Test with small dataset\n", + "test_df = fixed_llm_generate_batch(CFG, 5)\n", + "print(f\"\\n📊 Generated 
dataset shape: {test_df.shape}\")\n", + "print(f\"\\n📋 First few rows:\")\n", + "print(test_df.head())\n", + "print(f\"\\n📈 Data types:\")\n", + "print(test_df.dtypes)\n", + "\n", + "if not test_df.empty:\n", + " print(f\"\\n✅ SUCCESS! LLM generation is now working!\")\n", + " print(f\"📊 Generated {len(test_df)} survey responses using LLM\")\n", + "else:\n", + " print(f\"\\n❌ Still having issues with LLM generation\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83b842", + "metadata": {}, + "outputs": [], + "source": [ + "#Test larger dataset generation \n", + "print(\"🚀 Testing larger dataset generation...\")\n", + "large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n", + "if not large_df.empty:\n", + " print(f\"\\n📊 Large dataset shape: {large_df.shape}\")\n", + " print(f\"\\n📈 Summary statistics:\")\n", + " print(large_df.describe())\n", + " \n", + " # Save the results\n", + " from pathlib import Path\n", + " out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + " csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n", + " large_df.to_csv(csv_path, index=False)\n", + " print(f\"💾 Saved: {csv_path}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6029d3e2", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def build_json_schema(CFG):\n", + " schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n", + " props = schema['items']['properties']; req = schema['items']['required']\n", + " for f in CFG['fields']:\n", + " name, t = f['name'], f['type']\n", + " req.append(name)\n", + " if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n", + " elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n", + " elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n", + " elif t == 'bool': props[name] = {'type':'boolean'}\n", + " else: props[name] = 
{'type':'string'}\n", + " return schema\n", + "\n", + "PROMPT_PREAMBLE = (\n", + " \"You are a data generator. Return ONLY JSON. \"\n", + " \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n", + " \"No prose, no code fences, no trailing commas.\"\n", + ")\n", + "\n", + "def render_prompt(CFG, n_rows=100):\n", + " minimal_cfg = {'fields': []}\n", + " for f in CFG['fields']:\n", + " base = {k: f[k] for k in ['name','type'] if k in f}\n", + " if 'min' in f and 'max' in f: base.update({'min': f['min'], 'max': f['max']})\n", + " if 'values' in f: base.update({'values': f['values']})\n", + " if 'fmt' in f: base.update({'fmt': f['fmt']})\n", + " minimal_cfg['fields'].append(base)\n", + " return {\n", + " 'preamble': PROMPT_PREAMBLE,\n", + " 'n_rows': n_rows,\n", + " 'schema': build_json_schema(CFG),\n", + " 'constraints': minimal_cfg,\n", + " 'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n", + " }\n", + "\n", + "def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n", + " try:\n", + " obj = json.loads(raw)\n", + " if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n", + " return pd.DataFrame(obj['rows'])\n", + " except Exception:\n", + " pass\n", + " data = extract_strict_json(raw)\n", + " return pd.DataFrame(data)\n", + "\n", + "USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n", + "print('LLM available:', USE_LLM)\n", + "\n", + "def llm_generate_batch(CFG, n_rows=50):\n", + " if USE_LLM:\n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " prompt = json.dumps(render_prompt(CFG, n_rows))\n", + " resp = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " response_format={'type': 'json_object'},\n", + " messages=[\n", + " {'role':'system','content':'You output strict JSON only.'},\n", + " {'role':'user','content': prompt}\n", + " ],\n", + " temperature=0.2,\n", + " max_tokens=8192,\n", + " )\n", + " raw = 
resp.choices[0].message.content\n", + " try:\n", + " return parse_llm_json_to_df(raw)\n", + " except Exception:\n", + " stricter = (\n", + " prompt\n", + " + \"\\nReturn ONLY a JSON object structured as: \"\n", + " + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. \"\n", + " + \"No prose, no explanations.\"\n", + " )\n", + " resp2 = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " response_format={'type': 'json_object'},\n", + " messages=[\n", + " {'role':'system','content':'You output strict JSON only.'},\n", + " {'role':'user','content': stricter}\n", + " ],\n", + " temperature=0.2,\n", + " max_tokens=8192,\n", + " )\n", + " raw2 = resp2.choices[0].message.content\n", + " return parse_llm_json_to_df(raw2)\n", + " except Exception as e:\n", + " print('LLM error, fallback to rule-based mock:', e)\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + "\n", + "def generate_llm(CFG, total_rows=200, batch_size=50):\n", + " dfs = []; remaining = total_rows\n", + " while remaining > 0:\n", + " b = min(batch_size, remaining)\n", + " dfs.append(llm_generate_batch(CFG, n_rows=b))\n", + " remaining -= b\n", + " time.sleep(0.2)\n", + " return pd.concat(dfs, ignore_index=True)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e759087", + "metadata": {}, + "outputs": [], + "source": [ + "df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n", + "df_llm.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d4908ad", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the improved LLM generation with adaptive batching\n", + "print(\"🧪 Testing improved LLM generation with adaptive batching...\")\n", + "\n", + "# Test with smaller dataset first\n", + "print(\"\\n📦 Testing small batch (10 rows)...\")\n", + "small_df = fixed_llm_generate_batch(CFG, 10)\n", + "print(f\"✅ Small batch result: {len(small_df)} rows\")\n", + "\n", + "# Test with medium 
dataset using adaptive batching\n", + "print(\"\\n📦 Testing medium dataset (30 rows) with adaptive batching...\")\n", + "medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n", + "print(f\"✅ Medium dataset result: {len(medium_df)} rows\")\n", + "\n", + "if not medium_df.empty:\n", + " print(f\"\\n📊 Dataset shape: {medium_df.shape}\")\n", + " print(f\"\\n📋 First few rows:\")\n", + " print(medium_df.head())\n", + " \n", + " # Save the results\n", + " from pathlib import Path\n", + " out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + " csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n", + " medium_df.to_csv(csv_path, index=False)\n", + " print(f\"💾 Saved: {csv_path}\")\n", + "else:\n", + " print(\"❌ Medium dataset generation failed\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week3/community-contributions/week3_exercise_solution-Stephen.ipynb b/week3/community-contributions/week3_exercise_solution-Stephen.ipynb new file mode 100644 index 0000000..bbc99e7 --- /dev/null +++ b/week3/community-contributions/week3_exercise_solution-Stephen.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c58e628f", + "metadata": {}, + "source": [ + "\n", + "## **Week 3 task.**\n", + "Create your own tool that generates synthetic data/test data. Input the type of dataset or products or job postings, etc. 
and let the tool dream up various data samples.\n", + "\n", + "https://colab.research.google.com/drive/13wR4Blz3Ot_x0GOpflmvvFffm5XU3Kct?usp=sharing" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0ddde9ed", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "import requests\n", + "import torch\n", + "from IPython.display import Markdown, display, update_display\n", + "from openai import OpenAI\n", + "from huggingface_hub import login\n", + "from huggingface_hub import login\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n", + "from dotenv import load_dotenv\n", + "import gradio as gr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbbc6cc8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "llama_api_key = \"ollama\"\n", + "\n", + "# hf_token = userdata.get('HF_TOKEN')\n", + "# login(hf_token, add_to_git_credential=True)\n", + "\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + "\n", + "if llama_api_key:\n", + " print(f\"LLama API Key exists\")\n", + "else:\n", + " print(\"LLama API Key not set\")\n", + " \n", + "GPT_MODEL = \"gpt-4.1-mini\"\n", + "LLAMA_MODEL = \"llama3.1\"\n", + "\n", + "\n", + "openai = OpenAI()\n", + "\n", + "llama_url = \"http://localhost:11434/v1\"\n", + "llama = OpenAI(api_key=llama_api_key, base_url=llama_url)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ef083ec6", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_with_gpt(user_prompt: str, num_samples: int = 5):\n", + " \"\"\"\n", + " Generates synthetic data using OpenAI's GPT.\n", + " Return a JSON string.\n", + " \"\"\"\n", + " if not openai:\n", + " return json.dumps({\"error\": 
\"OpenAI client not initialized. Please check your API key.\"}, indent=2)\n", + "\n", + " try:\n", + " response = openai.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": f\"You are a data generation assistant. Generate a JSON array of exactly {num_samples} objects based on the user's request. The output must be valid JSON only, without any other text or formatting.\"},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " response_format={\"type\": \"json_object\"}\n", + " )\n", + " \n", + " json_text = response.choices[0].message.content\n", + " return json_text\n", + " except APIError as e:\n", + " return json.dumps({\"error\": f\"Error from OpenAI API: {e.body}\"}, indent=2)\n", + " except Exception as e:\n", + " return json.dumps({\"error\": f\"An unexpected error occurred: {e}\"}, indent=2)\n", + "\n", + "def generate_with_gpt(user_prompt: str, num_samples: int = 5):\n", + " \"\"\"\n", + " Generates synthetic data using OpenAI's GPT.\n", + " Return a JSON string.\n", + " \"\"\"\n", + " if not openai:\n", + " return json.dumps({\"error\": \"OpenAI client not initialized. Please check your API key.\"}, indent=2)\n", + "\n", + " try:\n", + " response = openai.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": f\"You are a data generation assistant. Generate a JSON array of exactly {num_samples} objects based on the user's request. 
The output must be valid JSON only, without any other text or formatting.\"},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " response_format={\"type\": \"json_object\"}\n", + " )\n", + " \n", + " json_text = response.choices[0].message.content\n", + " return json_text\n", + " except APIError as e:\n", + " return json.dumps({\"error\": f\"Error from OpenAI API: {e.body}\"}, indent=2)\n", + " except Exception as e:\n", + " return json.dumps({\"error\": f\"An unexpected error occurred: {e}\"}, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b98f84d8", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_data(user_prompt, model_choice):\n", + " \"\"\"\n", + " Wrapper function that calls the appropriate generation function based on model choice.\n", + " \"\"\"\n", + " if not user_prompt:\n", + " return json.dumps({\"error\": \"Please provide a description for the data.\"}, indent=2)\n", + "\n", + " if model_choice == f\"Hugging Face ({LLAMA_MODEL})\":\n", + " return generate_with_llama(user_prompt)\n", + " elif model_choice == f\"OpenAI ({GPT_MODEL})\":\n", + " return generate_with_gpt(user_prompt)\n", + " else:\n", + " return json.dumps({\"error\": \"Invalid model choice.\"}, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adbc19a8", + "metadata": {}, + "outputs": [], + "source": [ + "# Gradio UI\n", + "with gr.Blocks(theme=gr.themes.Glass(), title=\"Synthetic Data Generator\") as ui:\n", + " gr.Markdown(\"# Synthetic Data Generator\")\n", + " gr.Markdown(\"Describe the type of data you need, select a model, and click 'Generate'.\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column(scale=3):\n", + " data_prompt = gr.Textbox(\n", + " lines=5,\n", + " label=\"Data Prompt\",\n", + " placeholder=\"e.g., a list of customer profiles with name, email, and a favorite product\"\n", + " )\n", + " \n", + " with gr.Column(scale=1):\n", + " model_choice = gr.Radio(\n", + " 
[f\"Hugging Face ({LLAMA_MODEL})\", f\"OpenAI ({GPT_MODEL})\"],\n", + " label=\"Choose a Model\",\n", + " value=f\"Hugging Face ({LLAMA_MODEL})\"\n", + " )\n", + " \n", + " generate_btn = gr.Button(\"Generate Data\")\n", + " \n", + " with gr.Row():\n", + " output_json = gr.JSON(label=\"Generated Data\")\n", + " \n", + " generate_btn.click(\n", + " fn=generate_data,\n", + " inputs=[data_prompt, model_choice],\n", + " outputs=output_json\n", + " )\n", + "\n", + "ui.launch(inbrowser=True, debug=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week4/community-contributions/tochi/code_converter.ipynb b/week4/community-contributions/tochi/code_converter.ipynb new file mode 100644 index 0000000..5101d61 --- /dev/null +++ b/week4/community-contributions/tochi/code_converter.ipynb @@ -0,0 +1,569 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c1fcc6e9", + "metadata": {}, + "source": [ + "# Code Converter - Python to TypeScript Code\n", + "\n", + "This implementation, converts python code to optimized TypeScript Code, and runs the function" + ] + }, + { + "cell_type": "markdown", + "id": "16b6b063", + "metadata": {}, + "source": [ + "## Set up and imports\n" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "b3dc394c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import os\n", + "import io\n", + "import sys\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import subprocess\n", + "from IPython.display import Markdown, display, display_markdown\n", + "from system_info import retrieve_system_info\n", + "import 
gradio as gr" + ] + }, + { + "cell_type": "markdown", + "id": "1c9a0936", + "metadata": {}, + "source": [ + "# Initializing the access keys" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "fac104ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n" + ] + } + ], + "source": [ + "load_dotenv(override=True)\n", + "openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set. Check your engironment variables and try again\")" + ] + }, + { + "cell_type": "markdown", + "id": "5932182f", + "metadata": {}, + "source": [ + "# Connecting to client libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "4000f231", + "metadata": {}, + "outputs": [], + "source": [ + "openai = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 118, + "id": "51c67ac0", + "metadata": {}, + "outputs": [], + "source": [ + "# contants\n", + "OPENAI_MODEL= \"gpt-5-nano\"" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "ab4342bf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'os': {'system': 'Darwin',\n", + " 'arch': 'arm64',\n", + " 'release': '24.5.0',\n", + " 'version': 'Darwin Kernel Version 24.5.0: Tue Apr 22 19:48:46 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8103',\n", + " 'kernel': '24.5.0',\n", + " 'distro': None,\n", + " 'wsl': False,\n", + " 'rosetta2_translated': False,\n", + " 'target_triple': 'arm64-apple-darwin24.5.0'},\n", + " 'package_managers': ['xcode-select (CLT)', 'brew'],\n", + " 'cpu': {'brand': 'Apple M1',\n", + " 'cores_logical': 8,\n", + " 'cores_physical': 8,\n", + " 'simd': []},\n", + " 'toolchain': {'compilers': {'gcc': 'Apple clang version 17.0.0 (clang-1700.0.13.3)',\n", + " 'g++': 'Apple clang version 17.0.0 
(clang-1700.0.13.3)',\n", + " 'clang': 'Apple clang version 17.0.0 (clang-1700.0.13.3)',\n", + " 'msvc_cl': ''},\n", + " 'build_tools': {'cmake': '', 'ninja': '', 'make': 'GNU Make 3.81'},\n", + " 'linkers': {'ld_lld': ''}}}" + ] + }, + "execution_count": 119, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "system_info = retrieve_system_info()\n", + "system_info" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "1a1c1324", + "metadata": {}, + "outputs": [], + "source": [ + "message = f\"\"\"\n", + "Here is a report of the system information for my computer.\n", + "I want to run a TypeScript compiler to compile a single TypeScript file called main.cpp and then execute it in the simplest way possible.\n", + "Please reply with whether I need to install any TypeScript compiler to do this. If so, please provide the simplest step by step instructions to do so.\n", + "\n", + "If I'm already set up to compile TypeScript code, then I'd like to run something like this in Python to compile and execute the code:\n", + "```python\n", + "compile_command = # something here - to achieve the fastest possible runtime performance\n", + "compile_result = subprocess.run(compile_command, check=True, text=True, capture_output=True)\n", + "run_command = # something here\n", + "run_result = subprocess.run(run_command, check=True, text=True, capture_output=True)\n", + "return run_result.stdout\n", + "```\n", + "Please tell me exactly what I should use for the compile_command and run_command.\n", + "\n", + "System information:\n", + "{system_info}\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "id": "439015c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "Short answer:\n", + "- Yes, to compile TypeScript you need a TypeScript compiler (tsc). On macOS you’ll typically install Node.js first, then install TypeScript.\n", + "- Important: main.cpp sounds like a C++ file. 
The TypeScript compiler (tsc) cannot compile .cpp. If you want to use TypeScript, rename the file to main.ts (and ensure its contents are TypeScript). If you actually meant C++, use a C++ compiler instead (clang/g++).\n", + "\n", + "Step-by-step to set up TypeScript (simplest path on your system):\n", + "1) Install Node.js (which also installs npm)\n", + "- brew update\n", + "- brew install node\n", + "\n", + "2) Install the TypeScript compiler globally\n", + "- npm install -g typescript\n", + "\n", + "3) Verify installations\n", + "- node -v\n", + "- npm -v\n", + "- tsc -v\n", + "\n", + "4) Compile and run a TypeScript file (assuming your file is main.ts)\n", + "- tsc main.ts\n", + "- node main.js\n", + "\n", + "Notes:\n", + "- If your file is indeed C++ (main.cpp), you cannot compile it with tsc. To compile C++, use clang++ (on macOS) or g++:\n", + " - clang++ -std=c++17 main.cpp -o main\n", + " - ./main\n", + "\n", + "Python integration (fill-in for your example)\n", + "- If you have a TypeScript file named main.ts and you want to compile it to JavaScript and then run it with Node, use:\n", + " compile_command = [\"tsc\", \"main.ts\"]\n", + " run_command = [\"node\", \"main.js\"]\n", + "\n", + "- If you want to show a single command in Python that compiles and runs in one go (still two steps because TS compiles to JS first):\n", + " compile_command = [\"tsc\", \"main.ts\"]\n", + " run_command = [\"node\", \"main.js\"]\n", + "\n", + "- If you truly want to bypass TypeScript and run C++ instead (not TypeScript):\n", + " compile_command = [\"clang++\", \"-std=c++17\", \"main.cpp\", \"-o\", \"main\"]\n", + " run_command = [\"./main\"]\n", + "\n", + "If you’d like, tell me whether main.cpp is meant to be C++ or you actually have a TypeScript file named main.ts, and I can tailor the exact commands." 
+ ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "response = openai.chat.completions.create(model=OPENAI_MODEL, messages=[{\"role\":\"user\", \"content\":message}])\n", + "display(Markdown(response.choices[0].message.content))" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "id": "576cb5fa", + "metadata": {}, + "outputs": [], + "source": [ + "compile_command = [\"tsc\", \"main.ts\", \"--target\", \"ES2020\", \"--module\", \"commonjs\"]\n", + "run_command = [\"ts-node\", \"main.ts\"]" + ] + }, + { + "cell_type": "markdown", + "id": "01b03700", + "metadata": {}, + "source": [ + "## System and user prompts for the code converter" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "id": "255e318b", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + "Your task is to convert Python code into high performance TypeScript code.\n", + "Respond only with TypeScript code. Do not provide any explanation other than occasional comments.\n", + "The TypeScript response needs to produce an identical output in the fastest possible time.\n", + "\"\"\"\n", + "\n", + "\n", + "def user_prompt_for(python):\n", + " return f\"\"\" \n", + " port this Python code to TypeScript with the fastest possible implementation that produces identical output in the least time.\n", + "\n", + " The system information is \n", + "\n", + " {system_info}\n", + "\n", + " Your response will be written to a file called main.ts and then compile and ecexted; the compilation command is:\n", + "\n", + " {compile_command}\n", + "\n", + " Respond only with C++ code.\n", + " Python code to port:\n", + "\n", + " ```python\n", + " {python}\n", + " ```\n", + "\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "id": "09da7cb1", + "metadata": {}, + "outputs": [], + "source": [ + "def messages_for(python):\n", + " return [\n", + " {\"role\": \"system\", \"content\": 
system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt_for(python)},\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "id": "abcdb617", + "metadata": {}, + "outputs": [], + "source": [ + "def write_output(code):\n", + " with open(\"main.ts\", \"w\", encoding=\"utf-8\") as f:\n", + " f.write(code)" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "id": "c7a32d5f", + "metadata": {}, + "outputs": [], + "source": [ + "def convert(python):\n", + " reasoning_effort = \"high\"\n", + " response = openai.chat.completions.create(\n", + " model=OPENAI_MODEL,\n", + " messages=messages_for(python),\n", + " reasoning_effort=reasoning_effort,\n", + " )\n", + " reply = response.choices[0].message.content\n", + " reply = reply.replace(\"```ts\", \"\").replace(\"```\", \"\")\n", + " return reply" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "id": "59a7ec1f", + "metadata": {}, + "outputs": [], + "source": [ + "pi = \"\"\"\n", + "import time\n", + "\n", + "def calculate(iterations, param1, param2):\n", + " result = 1.0\n", + " for i in range(1, iterations+1):\n", + " j = i * param1 - param2\n", + " result -= (1/j)\n", + " j = i * param1 + param2\n", + " result += (1/j)\n", + " return result\n", + "\n", + "start_time = time.time()\n", + "result = calculate(200_000_000, 4, 1) * 4\n", + "end_time = time.time()\n", + "\n", + "print(f\"Result: {result:.12f}\")\n", + "print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": "6856393b", + "metadata": {}, + "outputs": [], + "source": [ + "def run_python(code):\n", + " globals_dict = {\"__builtins__\": __builtins__}\n", + "\n", + " buffer = io.StringIO()\n", + " old_stdout = sys.stdout\n", + " sys.stdout = buffer\n", + "\n", + " try:\n", + " exec(code, globals_dict)\n", + " output = buffer.getvalue()\n", + " except Exception as e:\n", + " output = f\"Error: {e}\"\n", + " 
finally:\n", + " sys.stdout = old_stdout\n", + "\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "c51fa5ea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Result: 3.141592656089\\nExecution Time: 19.478347 seconds\\n'" + ] + }, + "execution_count": 129, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "run_python(pi)" + ] + }, + { + "cell_type": "code", + "execution_count": 130, + "id": "69eb2304", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"import { performance } from 'perf_hooks';\\n\\nfunction digamma(z: number): number {\\n let acc = 0;\\n while (z < 7) {\\n acc -= 1 / z;\\n z += 1;\\n }\\n const z2 = z * z;\\n const z4 = z2 * z2;\\n const z6 = z4 * z2;\\n const z8 = z4 * z4;\\n const z10 = z8 * z2;\\n const z12 = z10 * z2;\\n const series =\\n Math.log(z)\\n - 1 / (2 * z)\\n - 1 / (12 * z2)\\n + 1 / (120 * z4)\\n - 1 / (252 * z6)\\n + 1 / (240 * z8)\\n - 5 / (660 * z10)\\n + 691 / (32760 * z12);\\n return acc + series;\\n}\\n\\nconst N = 200_000_000;\\n\\nconst t0 = performance.now();\\nconst result =\\n 4 - digamma(N + 0.75) + digamma(0.75) + digamma(N + 1.25) - digamma(1.25);\\nconst t1 = performance.now();\\n\\nconsole.log(`Result: ${result.toFixed(12)}`);\\nconsole.log(`Execution Time: ${((t1 - t0) / 1000).toFixed(6)} seconds`);\"" + ] + }, + "execution_count": 130, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "convert(pi)" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "id": "2ea56d95", + "metadata": {}, + "outputs": [], + "source": [ + " \n", + "def run_typescript(code):\n", + " write_output(code)\n", + " try:\n", + " subprocess.run(compile_command, check=True, text=True, capture_output=True)\n", + " run_result = subprocess.run(run_command, check=True, text=True, capture_output=True)\n", + " return run_result.stdout\n", + " except subprocess.CalledProcessError as e:\n", + " return 
f\"An error occurred:\\n{e.stderr}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "id": "79d6bd87", + "metadata": {}, + "outputs": [], + "source": [ + "# run_typescript()" + ] + }, + { + "cell_type": "markdown", + "id": "b4799b88", + "metadata": {}, + "source": [ + "## User Interface" + ] + }, + { + "cell_type": "code", + "execution_count": 133, + "id": "8486ce70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7864\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 133, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with gr.Blocks(\n", + " theme=gr.themes.Monochrome(), title=\"Port from Python to TypeScript\"\n", + ") as ui:\n", + " with gr.Row(equal_height=True):\n", + " with gr.Column(scale=6):\n", + " python = gr.Code(\n", + " label=\"Python Original Code\",\n", + " value=pi,\n", + " language=\"python\",\n", + " lines=30,\n", + " )\n", + " with gr.Column(scale=6):\n", + " ts = gr.Code(\n", + " label=\"TypeScript (generated)\", value=\"\", language=\"cpp\", lines=26\n", + " )\n", + " with gr.Row(elem_classes=[\"controls\"]):\n", + " python_run = gr.Button(\"Run Python\", elem_classes=[\"run-btn\", \"py\"])\n", + " port = gr.Button(\"Convert to TS\", elem_classes=[\"convert-btn\"])\n", + " ts_run = gr.Button(\"Run TS\", elem_classes=[\"run-btn\", \"ts\"])\n", + "\n", + " with gr.Row(equal_height=True):\n", + " with gr.Column(scale=6):\n", + " python_out = gr.TextArea(label=\"Python Result\", lines=10)\n", + " with gr.Column(scale=6):\n", + " ts_out = gr.TextArea(label=\"TS output\", lines=10)\n", + "\n", + " port.click(fn=convert, inputs=[python], outputs=[ts])\n", + " python_run.click(fn=run_python, inputs=[python], outputs=[python_out])\n", + " ts_run.click(fn=run_typescript, inputs=[ts], outputs=[ts_out])\n", + " \n", + " \n", + "ui.launch(inbrowser=True)" + ] + }, + { + "cell_type": "markdown", + "id": "4663a174", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9033e421", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week4/community-contributions/week4_exercise_solution-Stephen.ipynb b/week4/community-contributions/week4_exercise_solution-Stephen.ipynb new file mode 100644 index 0000000..07d5155 --- /dev/null +++ b/week4/community-contributions/week4_exercise_solution-Stephen.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ed8c52b6", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "ollama_api_key = os.getenv('OLLAMA_API_KEY')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + "\n", + "if ollama_api_key:\n", + " print(f\"OLLAMA API Key exists and begins {ollama_api_key[:2]}\")\n", + "else:\n", + " print(\"OLLAMA API Key not set (and this is optional)\")\n", + "\n", + "ollama_url = \"http://localhost:11434/v1\"\n", + "\n", + "openai = OpenAI()\n", + "ollama = OpenAI(api_key=ollama_api_key, base_url=ollama_url)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "c628f95e", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt_doc = \"\"\"You are an expert Python developer and code reviewer.\n", + "Your job is to read the user's provided function, and return:\n", + "1. A concise, PEP-257-compliant docstring summarizing what the function does, clarifying types, parameters, return values, and side effects.\n", + "2. 
Helpful inline comments that improve both readability and maintainability, without restating what the code obviously does.\n", + "\n", + "Only output the function, not explanations or additional text. \n", + "Do not modify variable names or refactor the function logic.\n", + "Your response should improve the code's clarity and documentation, making it easier for others to understand and maintain.\n", + "Don't be extremely verbose.\n", + "Your answer should be at a senior level of expertise.\n", + "\"\"\"\n", + "\n", + "system_prompt_tests = \"\"\"You are a seasoned Python developer and testing expert.\n", + "Your task is to read the user's provided function, and generate:\n", + "1. A concise set of meaningful unit tests that thoroughly validate the function's correctness, including typical, edge, and error cases.\n", + "2. The tests should be written for pytest (or unittest if pytest is not appropriate), use clear, descriptive names, and avoid unnecessary complexity.\n", + "3. If dependencies or mocking are needed, include minimal necessary setup code (but avoid over-mocking).\n", + "\n", + "Only output the relevant test code, not explanations or extra text.\n", + "Do not change the original function; focus solely on comprehensive, maintainable test coverage that other developers can easily understand and extend.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "4bb84e6c", + "metadata": {}, + "outputs": [], + "source": [ + "models = [\"gpt-4.1-mini\", \"llama3.1\"]\n", + "clients = {\"gpt-4.1-mini\": openai, \"llama3.1\": ollama}\n", + "\n", + "def generate_documentation(code, model):\n", + " response = clients[model].chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt_doc},\n", + " {\"role\": \"user\", \"content\": code}\n", + " ],\n", + " stream=True\n", + " )\n", + " output = \"\"\n", + " for chunk in response:\n", + " output += 
chunk.choices[0].delta.content or \"\"\n", + " yield output.replace(\"```python\", \"\").replace(\"```\", \"\")\n", + "\n", + "def generate_tests(code, model):\n", + " response = clients[model].chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt_tests},\n", + " {\"role\": \"user\", \"content\": code}\n", + " ],\n", + " stream=True\n", + " )\n", + " output = \"\"\n", + " for chunk in response:\n", + " output += chunk.choices[0].delta.content or \"\"\n", + " yield output.replace(\"```python\", \"\").replace(\"```\", \"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e65b26", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks(theme=gr.themes.Soft(spacing_size=gr.themes.sizes.spacing_sm, radius_size=gr.themes.sizes.radius_none)) as ui:\n", + " gr.Markdown(\"# Python Toolbox\", elem_id=\"app-title\")\n", + " \n", + " with gr.Tab(\"Docstring Generator\") as tab1:\n", + " gr.Markdown(\"## Docstring & Comment Generator\")\n", + " gr.Markdown(\"Paste your function below to generate helpful docstrings and inline comments!\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column():\n", + " code_input = gr.Code(label=\"Your Python function here\", lines=20, language=\"python\")\n", + " model_dropdown = gr.Dropdown(choices=models, value=models[0], label=\"Select model\")\n", + " submit_doc_btn = gr.Button(\"Generate docstring & comments\")\n", + " with gr.Column():\n", + " code_output = gr.Code(label=\"New function with docstring and comments\", language=\"python\")\n", + "\n", + " submit_doc_btn.click(\n", + " generate_documentation, \n", + " inputs=[code_input, model_dropdown], \n", + " outputs=code_output\n", + " )\n", + "\n", + " with gr.Tab(\"Unit Tests Generator\") as tab2:\n", + " gr.Markdown(\"## Unit Test Generator\")\n", + " gr.Markdown(\"Paste your function below to generate helpful unit tests!\")\n", + "\n", + " with gr.Row():\n", + " with 
gr.Column():\n", + " code_input_2 = gr.Code(label=\"Your Python function here\", lines=20, language=\"python\")\n", + " model_dropdown_2 = gr.Dropdown(choices=models, value=models[0], label=\"Select model\")\n", + " submit_test_btn = gr.Button(\"Generate unit tests\")\n", + " with gr.Column():\n", + " code_output_2 = gr.Code(label=\"Generated unit tests\", language=\"python\")\n", + "\n", + " submit_test_btn.click(\n", + " generate_tests, \n", + " inputs=[code_input_2, model_dropdown_2], \n", + " outputs=code_output_2\n", + " )\n", + " \n", + " \n", + " tab1.select(lambda x: x, inputs=code_input_2, outputs=code_input)\n", + " tab2.select(lambda x: x, inputs=code_input, outputs=code_input_2)\n", + "\n", + "ui.launch(share=False, inbrowser=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}