diff --git a/week4/community-contributions/pytest_generator/pytest_generator.ipynb b/week4/community-contributions/pytest_generator/pytest_generator.ipynb new file mode 100644 index 0000000..7051957 --- /dev/null +++ b/week4/community-contributions/pytest_generator/pytest_generator.ipynb @@ -0,0 +1,498 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b8be8252", + "metadata": {}, + "outputs": [], + "source": [ + "!uv pip install pytest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba193fd5", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import ast\n", + "import sys\n", + "import uuid\n", + "import json\n", + "import textwrap\n", + "import subprocess\n", + "from pathlib import Path\n", + "from dataclasses import dataclass\n", + "from typing import List, Protocol, Tuple, Dict, Optional\n", + "\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "from openai import BadRequestError as _OpenAIBadRequest\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---\n", + "GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "GROQ_BASE = \"https://api.groq.com/openai/v1\"\n", + "\n", + "# --- API Keys (add these in your .env) ---\n", + "openai_api_key = os.getenv(\"OPENAI_API_KEY\") # OpenAI\n", + "google_api_key = os.getenv(\"GOOGLE_API_KEY\") # Gemini\n", + "groq_api_key = os.getenv(\"GROQ_API_KEY\") # Groq\n", + "\n", + "# --- Clients ---\n", + "openai_client = OpenAI() # OpenAI default (reads OPENAI_API_KEY)\n", + "gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None\n", + "groq_client = OpenAI(api_key=groq_api_key, base_url=GROQ_BASE) if groq_api_key else None\n", + "\n", + "# --- Model registry: label -> { client, model } ---\n", + "MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n", + "\n", + "def _register(label: str, client: Optional[OpenAI], model_id: str):\n", + " \"\"\"Add a model to the registry only if its client is configured.\"\"\"\n", + " if client is not None:\n", + " MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n", + "\n", + "# OpenAI\n", + "_register(\"OpenAI • GPT-5\", openai_client, \"gpt-5\")\n", + "_register(\"OpenAI • GPT-5 Nano\", openai_client, \"gpt-5-nano\")\n", + "_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n", + "\n", + "# Gemini (Google)\n", + "_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n", + "_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n", + "\n", + "# Groq\n", + "_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n", + "_register(\"Groq • Llama 3.3 70B\", groq_client, \"llama-3.3-70b-versatile\")\n", + "_register(\"Groq • GPT-OSS 20B\", groq_client, \"openai/gpt-oss-20b\")\n", + "_register(\"Groq • GPT-OSS 120B\", groq_client, \"openai/gpt-oss-120b\")\n", + "\n", + "DEFAULT_MODEL = next(iter(MODEL_REGISTRY.keys()), None)\n", + "\n", + "print(f\"Providers configured → OpenAI:{bool(openai_api_key)} Gemini:{bool(google_api_key)} Groq:{bool(groq_api_key)}\")\n", + "print(\"Models available →\", \", \".join(MODEL_REGISTRY.keys()) or \"None (add API keys in .env)\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5d6b0f2", + "metadata": {}, + "outputs": [], + "source": [ + "class CompletionClient(Protocol):\n", + " \"\"\"Any LLM client provides a .complete() method using a registry label.\"\"\"\n", + " def complete(self, *, model_label: str, system: str, user: str) -> str: ...\n", + "\n", + "\n", + "def _extract_code_or_text(s: str) -> str:\n", + " \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n", + " m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n", + " return m.group(1).strip() if m else s.strip()\n", + "\n", + "\n", + "class MultiModelChatClient:\n", + " \"\"\"Routes requests to the right provider/client based on model label.\"\"\"\n", + " def __init__(self, registry: Dict[str, Dict[str, object]]):\n", + " self._registry = registry\n", + "\n", + " def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:\n", + " params = {\n", + " \"model\": model_id,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system},\n", + " {\"role\": \"user\", \"content\": user},\n", + " ],\n", + " }\n", + " resp = client.chat.completions.create(**params) # do NOT send temperature for strict providers\n", + " text = (resp.choices[0].message.content or \"\").strip()\n", + " return _extract_code_or_text(text)\n", + "\n", + " def complete(self, *, model_label: str, system: str, user: str) -> str:\n", + " if model_label not in self._registry:\n", + " raise ValueError(f\"Unknown model label: {model_label}\")\n", + " info = self._registry[model_label]\n", + " client = info[\"client\"]\n", + " model = info[\"model\"]\n", + " try:\n", + " return self._call(client=client, model_id=str(model), system=system, user=user)\n", + " except _OpenAIBadRequest as e:\n", + " # Providers may reject stray params; we don't send any, but retry anyway.\n", + " if \"temperature\" in str(e).lower():\n", + " return self._call(client=client, model_id=str(model), system=system, user=user)\n", + " raise\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31558bf0", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass(frozen=True)\n", + "class SymbolInfo:\n", + " kind: str # \"function\" | \"class\" | \"method\"\n", + " name: str\n", + " signature: str\n", + " lineno: int\n", + "\n", + "class PublicAPIExtractor:\n", + " \"\"\"Extract concise 'public API' summary from a Python module.\"\"\"\n", + " def extract(self, source: str) -> List[SymbolInfo]:\n", + " tree = ast.parse(source)\n", + " out: List[SymbolInfo] = []\n", + " for node in tree.body:\n", + " if isinstance(node, ast.FunctionDef) and not node.name.startswith(\"_\"):\n", + " out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n", + " elif isinstance(node, ast.ClassDef) and not node.name.startswith(\"_\"):\n", + " out.append(SymbolInfo(\"class\", node.name, node.name, node.lineno))\n", + " for sub in node.body:\n", + " if isinstance(sub, ast.FunctionDef) and not sub.name.startswith(\"_\"):\n", + " out.append(SymbolInfo(\"method\",\n", + " f\"{node.name}.{sub.name}\",\n", + " self._sig(sub),\n", + " sub.lineno))\n", + " return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))\n", + "\n", + " def _sig(self, fn: ast.FunctionDef) -> str:\n", + " args = [a.arg for a in fn.args.args]\n", + " if fn.args.vararg:\n", + " args.append(\"*\" + fn.args.vararg.arg)\n", + " args.extend(a.arg + \"=?\" for a in fn.args.kwonlyargs)\n", + " if fn.args.kwarg:\n", + " args.append(\"**\" + fn.args.kwarg.arg)\n", + " ret = \"\"\n", + " if fn.returns is not None:\n", + " try:\n", + " ret = f\" -> {ast.unparse(fn.returns)}\"\n", + " except Exception:\n", + " pass\n", + " return f\"def {fn.name}({', '.join(args)}){ret}:\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3aeadedc", + "metadata": {}, + "outputs": [], + "source": [ + "class PromptBuilder:\n", + " \"\"\"Builds deterministic prompts for pytest generation.\"\"\"\n", + " SYSTEM = (\n", + " \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n", + " \"Rules:\\n\"\n", + " \"- Output only Python test code (no prose, no markdown fences).\\n\"\n", + " \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n", + " \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n", + " \"- Import the target module by module name only.\\n\"\n", + " \"- Cover every public function and method with at least one tiny test.\\n\"\n", + " \"- Prefer straightforward, fast assertions.\\n\"\n", + " )\n", + "\n", + " def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n", + " summary = \"\\n\".join(f\"- {s.kind:<6} {s.signature}\" for s in symbols) or \"- (no public symbols)\"\n", + " return textwrap.dedent(f\"\"\"\n", + " Create pytest tests for module `{module_name}`.\n", + "\n", + " Public API Summary:\n", + " {summary}\n", + "\n", + " Constraints:\n", + " - Import as: `import {module_name} as mod`\n", + " - Keep tests tiny, fast, and deterministic.\n", + "\n", + " Full module source (for reference):\n", + " # --- BEGIN SOURCE {module_name}.py ---\n", + " {source}\n", + " # --- END SOURCE ---\n", + " \"\"\").strip()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a45ac5be", + "metadata": {}, + "outputs": [], + "source": [ + "def _ensure_header_and_import(code: str, module_name: str) -> str:\n", + " \"\"\"Ensure tests import pytest and the target module as 'mod'.\"\"\"\n", + " code = code.strip()\n", + " needs_pytest = \"import pytest\" not in code\n", + " has_mod = (f\"import {module_name} as mod\" in code) or (f\"from {module_name} import\" in code)\n", + " needs_import = not has_mod\n", + "\n", + " header = []\n", + " if needs_pytest:\n", + " header.append(\"import pytest\")\n", + " if needs_import:\n", + " header.append(f\"import {module_name} as mod\")\n", + "\n", + " return (\"\\n\".join(header) + \"\\n\\n\" + code) if header else code\n", + "\n", + "\n", + "def build_module_name_from_path(path: str) -> str:\n", + " return Path(path).stem\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "787e58b6", + "metadata": {}, + "outputs": [], + "source": [ + "class TestGenerator:\n", + " \"\"\"Extraction → prompt → model → polish.\"\"\"\n", + " def __init__(self, llm: CompletionClient):\n", + " self._llm = llm\n", + " self._extractor = PublicAPIExtractor()\n", + " self._prompts = PromptBuilder()\n", + "\n", + " def generate_tests(self, model_label: str, module_name: str, source: str) -> str:\n", + " symbols = self._extractor.extract(source)\n", + " user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n", + " raw = self._llm.complete(model_label=model_label, system=self._prompts.SYSTEM, user=user)\n", + " return _ensure_header_and_import(raw, module_name)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8402f62f", + "metadata": {}, + "outputs": [], + "source": [ + "def _parse_pytest_summary(output: str) -> Tuple[str, Dict[str, int]]:\n", + " \"\"\"\n", + " Parse the final summary line like:\n", + " '3 passed, 1 failed, 2 skipped in 0.12s'\n", + " Return (summary_line, counts_dict).\n", + " \"\"\"\n", + " summary_line = \"\"\n", + " for line in output.strip().splitlines()[::-1]: # scan from end\n", + " if \" passed\" in line or \" failed\" in line or \" error\" in line or \" skipped\" in line or \" deselected\" in line:\n", + " summary_line = line.strip()\n", + " break\n", + "\n", + " counts = {\"passed\": 0, \"failed\": 0, \"errors\": 0, \"skipped\": 0, \"xfail\": 0, \"xpassed\": 0}\n", + " m = re.findall(r\"(\\d+)\\s+(passed|failed|errors?|skipped|xfailed|xpassed)\", summary_line)\n", + " for num, kind in m:\n", + " if kind.startswith(\"error\"):\n", + " counts[\"errors\"] += int(num)\n", + " elif kind == \"passed\":\n", + " counts[\"passed\"] += int(num)\n", + " elif kind == \"failed\":\n", + " counts[\"failed\"] += int(num)\n", + " elif kind == \"skipped\":\n", + " counts[\"skipped\"] += int(num)\n", + " elif kind == \"xfailed\":\n", + " counts[\"xfail\"] += int(num)\n", + " elif kind == \"xpassed\":\n", + " counts[\"xpassed\"] += int(num)\n", + "\n", + " return summary_line or \"(no summary line found)\", counts\n", + "\n", + "\n", + "def run_pytest_on_snippet(module_name: str, module_code: str, tests_code: str) -> Tuple[str, str]:\n", + " \"\"\"\n", + " Create an isolated temp workspace, write module + tests, run pytest,\n", + " and return (human_summary, full_cli_output).\n", + " \"\"\"\n", + " if not module_name or not module_code.strip() or not tests_code.strip():\n", + " return \"❌ Provide module name, module code, and tests.\", \"\"\n", + "\n", + " run_id = uuid.uuid4().hex[:8]\n", + " base = Path(\".pytest_runs\") / f\"run_{run_id}\"\n", + " tests_dir = base / \"tests\"\n", + " tests_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # Write module and tests\n", + " (base / f\"{module_name}.py\").write_text(module_code, encoding=\"utf-8\")\n", + " (tests_dir / f\"test_{module_name}.py\").write_text(tests_code, encoding=\"utf-8\")\n", + "\n", + " # Run pytest with this temp dir on PYTHONPATH\n", + " env = os.environ.copy()\n", + " env[\"PYTHONPATH\"] = str(base) + os.pathsep + env.get(\"PYTHONPATH\", \"\")\n", + " cmd = [sys.executable, \"-m\", \"pytest\", \"-q\"] # quiet output, but still includes summary\n", + " proc = subprocess.run(cmd, cwd=base, env=env, text=True, capture_output=True)\n", + "\n", + " full_out = (proc.stdout or \"\") + (\"\\n\" + proc.stderr if proc.stderr else \"\")\n", + " summary_line, counts = _parse_pytest_summary(full_out)\n", + "\n", + " badges = []\n", + " for key in (\"passed\", \"failed\", \"errors\", \"skipped\", \"xpassed\", \"xfail\"):\n", + " val = counts.get(key, 0)\n", + " if val:\n", + " badges.append(f\"**{key}: {val}**\")\n", + " badges = \" • \".join(badges) if badges else \"no tests collected?\"\n", + "\n", + " human = f\"{summary_line}\\n\\n{badges}\"\n", + " return human, full_out\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d240ce5", + "metadata": {}, + "outputs": [], + "source": [ + "LLM = MultiModelChatClient(MODEL_REGISTRY)\n", + "SERVICE = TestGenerator(LLM)\n", + "\n", + "def generate_from_code(model_label: str, module_name: str, code: str, save: bool, out_dir: str) -> Tuple[str, str]:\n", + " if not model_label or model_label not in MODEL_REGISTRY:\n", + " return \"\", \"❌ Pick a model (or add API keys for providers in .env).\"\n", + " if not module_name.strip():\n", + " return \"\", \"❌ Please provide a module name.\"\n", + " if not code.strip():\n", + " return \"\", \"❌ Please paste some Python code.\"\n", + "\n", + " tests_code = SERVICE.generate_tests(model_label=model_label, module_name=module_name.strip(), source=code)\n", + " saved = \"\"\n", + " if save:\n", + " out = Path(out_dir or \"tests\")\n", + " out.mkdir(parents=True, exist_ok=True)\n", + " out_path = out / f\"test_{module_name}.py\"\n", + " out_path.write_text(tests_code, encoding=\"utf-8\")\n", + " saved = f\"✅ Saved to {out_path}\"\n", + " return tests_code, saved\n", + "\n", + "\n", + "def generate_from_file(model_label: str, file_obj, save: bool, out_dir: str) -> Tuple[str, str]:\n", + " if file_obj is None:\n", + " return \"\", \"❌ Please upload a .py file.\"\n", + " code = file_obj.decode(\"utf-8\")\n", + " module_name = build_module_name_from_path(\"uploaded_module.py\")\n", + " return generate_from_code(model_label, module_name, code, save, out_dir)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3e1401a", + "metadata": {}, + "outputs": [], + "source": [ + "EXAMPLE_CODE = \"\"\"\\\n", + "def add(a: int, b: int) -> int:\n", + " return a + b\n", + "\n", + "def divide(a: float, b: float) -> float:\n", + " if b == 0:\n", + " raise ZeroDivisionError(\"b must be non-zero\")\n", + " return a / b\n", + "\n", + "class Counter:\n", + " def __init__(self, start: int = 0):\n", + " self.value = start\n", + "\n", + " def inc(self, by: int = 1):\n", + " self.value += by\n", + " return self.value\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f802450e", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks(title=\"PyTest Generator\") as ui:\n", + " gr.Markdown(\n", + " \"## 🧪 PyTest Generator (Week 4 • Community Contribution)\\n\"\n", + " \"Generate **minimal, deterministic** pytest tests from a Python module using your chosen model/provider.\"\n", + " )\n", + "\n", + " with gr.Row(equal_height=True):\n", + " # LEFT: inputs (module code)\n", + " with gr.Column(scale=6):\n", + " with gr.Row():\n", + " model_dd = gr.Dropdown(\n", + " list(MODEL_REGISTRY.keys()),\n", + " value=DEFAULT_MODEL,\n", + " label=\"Model (OpenAI, Gemini, Groq)\"\n", + " )\n", + " module_name_tb = gr.Textbox(\n", + " label=\"Module name (used in `import as mod`)\",\n", + " value=\"mymodule\"\n", + " )\n", + " code_in = gr.Code(\n", + " label=\"Python module code\",\n", + " language=\"python\",\n", + " lines=24,\n", + " value=EXAMPLE_CODE\n", + " )\n", + " with gr.Row():\n", + " save_cb = gr.Checkbox(label=\"Also save generated tests to /tests\", value=True)\n", + " out_dir_tb = gr.Textbox(label=\"Output folder\", value=\"tests\")\n", + " gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n", + "\n", + " # RIGHT: outputs (generated tests + pytest run)\n", + " with gr.Column(scale=6):\n", + " tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=24)\n", + " with gr.Row():\n", + " run_btn = gr.Button(\"Run PyTest\", variant=\"secondary\")\n", + " summary_md = gr.Markdown()\n", + " full_out = gr.Textbox(label=\"Full PyTest output\", lines=12)\n", + "\n", + " # --- events ---\n", + "\n", + " def _on_gen(model_label, name, code, save, outdir):\n", + " tests, msg = generate_from_code(model_label, name, code, save, outdir)\n", + " status = msg or \"✅ Done\"\n", + " return tests, status\n", + "\n", + " gen_btn.click(\n", + " _on_gen,\n", + " inputs=[model_dd, module_name_tb, code_in, save_cb, out_dir_tb],\n", + " outputs=[tests_out, summary_md],\n", + " )\n", + "\n", + " def _on_run(name, code, tests):\n", + " summary, details = run_pytest_on_snippet(name, code, tests)\n", + " return summary, details\n", + "\n", + " run_btn.click(\n", + " _on_run,\n", + " inputs=[module_name_tb, code_in, tests_out],\n", + " outputs=[summary_md, full_out],\n", + " )\n", + "\n", + "ui.launch(inbrowser=True)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llm-engineering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}