From 74e48bba2cdfa0ada01c90737937c458a5e1553d Mon Sep 17 00:00:00 2001
From: Nik
Date: Sun, 26 Oct 2025 20:01:49 +0530
Subject: [PATCH] Create pytest generator tool.

---
 .../pytest_generator/pytest_generator.ipynb | 321 ++++++++++++++++++
 1 file changed, 321 insertions(+)
 create mode 100644 week4/community-contributions/pytest_generator/pytest_generator.ipynb

diff --git a/week4/community-contributions/pytest_generator/pytest_generator.ipynb b/week4/community-contributions/pytest_generator/pytest_generator.ipynb
new file mode 100644
index 0000000..84d568b
--- /dev/null
+++ b/week4/community-contributions/pytest_generator/pytest_generator.ipynb
@@ -0,0 +1,321 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba193fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import ast\n",
    "import textwrap\n",
    "from pathlib import Path\n",
    "from dataclasses import dataclass\n",
    "from typing import List, Protocol\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI, BadRequestError\n",
    "import gradio as gr\n",
    "\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "TESTGEN_MODEL = os.getenv(\"TESTGEN_MODEL\", \"gpt-5-nano\")\n",
    "OPENAI_CLIENT = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d6b0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CompletionClient(Protocol):\n",
    "    \"\"\"Any LLM client provides a .complete() method.\"\"\"\n",
    "    def complete(self, *, model: str, system: str, user: str) -> str: ...\n",
    "\n",
    "\n",
    "class OpenAIChatClient:\n",
    "    \"\"\"Adapter over OpenAI chat completions that omits unsupported params.\"\"\"\n",
    "    def __init__(self, client: OpenAI):\n",
    "        self._client = client\n",
    "\n",
    "    def _call(self, *, model: str, system: str, user: str, include_temperature: bool) -> str:\n",
    "        params = {\n",
    "            \"model\": model,\n",
    "            \"messages\": [\n",
    "                {\"role\": \"system\", \"content\": system},\n",
    "                {\"role\": \"user\", \"content\": user},\n",
    "            ],\n",
    "        }\n",
    "        if include_temperature:\n",
    "            params[\"temperature\"] = 0.7\n",
    "\n",
    "        resp = self._client.chat.completions.create(**params)\n",
    "        text = (resp.choices[0].message.content or \"\").strip()\n",
    "        return _extract_code_or_text(text)\n",
    "\n",
    "    def complete(self, *, model: str, system: str, user: str) -> str:\n",
    "        try:\n",
    "            return self._call(model=model, system=system, user=user, include_temperature=True)\n",
    "        except BadRequestError as e:\n",
    "            # Some models (e.g. the gpt-5 family) reject a non-default\n",
    "            # temperature; retry once with the parameter omitted.\n",
    "            if \"temperature\" in str(e).lower():\n",
    "                return self._call(model=model, system=system, user=user, include_temperature=False)\n",
    "            raise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31558bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _extract_code_or_text(s: str) -> str:\n",
    "    \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
    "    m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
    "    return m.group(1).strip() if m else s.strip()\n",
    "\n",
    "\n",
    "def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
    "    \"\"\"Make sure tests import the module and pytest; keep output minimal.\"\"\"\n",
    "    code = code.strip()\n",
    "    needs_pytest = \"import pytest\" not in code\n",
    "    # \"import X as mod\" always contains \"import X\", so one check covers both spellings.\n",
    "    needs_import = f\"import {module_name}\" not in code\n",
    "\n",
    "    header_lines = []\n",
    "    if needs_pytest:\n",
    "        header_lines.append(\"import pytest\")\n",
    "    if needs_import:\n",
    "        header_lines.append(f\"import {module_name} as mod\")\n",
    "\n",
    "    if header_lines:\n",
    "        code = \"\\n\".join(header_lines) + \"\\n\\n\" + code\n",
    "    return code"
   ]
  },
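  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a11d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not part of the tool): a quick offline look at the\n",
    "# two helpers above. The fenced reply and the module name are made up.\n",
    "demo_reply = \"Here you go:\\n```python\\nassert 1 + 1 == 2\\n```\"\n",
    "print(_extract_code_or_text(demo_reply))  # -> assert 1 + 1 == 2\n",
    "print(_ensure_header_and_import(\"def test_x():\\n    assert True\", \"mymodule\"))"
   ]
  },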
{module_name}\" not in code and f\"import {module_name} as mod\" not in code\n", + "\n", + " header_lines = []\n", + " if needs_pytest:\n", + " header_lines.append(\"import pytest\")\n", + " if needs_import:\n", + " header_lines.append(f\"import {module_name} as mod\")\n", + "\n", + " if header_lines:\n", + " code = \"\\n\".join(header_lines) + \"\\n\\n\" + code\n", + " return code\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3aeadedc", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass(frozen=True)\n", + "class SymbolInfo:\n", + " kind: str # \"function\" | \"class\" | \"method\"\n", + " name: str\n", + " signature: str\n", + " lineno: int\n", + "\n", + "\n", + "class PublicAPIExtractor:\n", + " \"\"\"Extract a small 'public API' summary from a Python module.\"\"\"\n", + " def extract(self, source: str) -> List[SymbolInfo]:\n", + " tree = ast.parse(source)\n", + " out: List[SymbolInfo] = []\n", + "\n", + " for node in tree.body:\n", + " if isinstance(node, ast.FunctionDef) and not node.name.startswith(\"_\"):\n", + " out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n", + " elif isinstance(node, ast.ClassDef) and not node.name.startswith(\"_\"):\n", + " out.append(SymbolInfo(\"class\", node.name, node.name, node.lineno))\n", + " for sub in node.body:\n", + " if isinstance(sub, ast.FunctionDef) and not sub.name.startswith(\"_\"):\n", + " out.append(SymbolInfo(\"method\",\n", + " f\"{node.name}.{sub.name}\",\n", + " self._sig(sub),\n", + " sub.lineno))\n", + " return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))\n", + "\n", + " def _sig(self, fn: ast.FunctionDef) -> str:\n", + " args = [a.arg for a in fn.args.args]\n", + " if fn.args.vararg:\n", + " args.append(\"*\" + fn.args.vararg.arg)\n", + " args.extend(a.arg + \"=?\" for a in fn.args.kwonlyargs)\n", + " if fn.args.kwarg:\n", + " args.append(\"**\" + fn.args.kwarg.arg)\n", + " ret = \"\"\n", + " if fn.returns is not None:\n", + " try:\n", + " ret = f\" -> {ast.unparse(fn.returns)}\"\n", + " except Exception:\n", + " pass\n", + " return f\"def {fn.name}({', '.join(args)}){ret}:\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a45ac5be", + "metadata": {}, + "outputs": [], + "source": [ + "class PromptBuilder:\n", + " \"\"\"Builds concise, deterministic prompts for pytest generation.\"\"\"\n", + " SYSTEM = (\n", + " \"You are a senior Python engineer. 
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45ac5be",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PromptBuilder:\n",
    "    \"\"\"Builds concise, deterministic prompts for pytest generation.\"\"\"\n",
    "    SYSTEM = (\n",
    "        \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n",
    "        \"Rules:\\n\"\n",
    "        \"- Output only Python test code (no prose, no markdown fences).\\n\"\n",
    "        \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n",
    "        \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n",
    "        \"- Import the target module by module name.\\n\"\n",
    "        \"- Create a minimal test covering every public function and method.\\n\"\n",
    "        \"- Prefer straightforward, fast assertions over exhaustive checks.\\n\"\n",
    "    )\n",
    "\n",
    "    # Dedent the template *before* substituting values: interpolating\n",
    "    # multi-line source into an indented f-string leaves textwrap.dedent\n",
    "    # no common prefix to strip. str.format only parses braces in the\n",
    "    # template itself, so braces inside the user's source are safe.\n",
    "    _USER_TEMPLATE = textwrap.dedent(\"\"\"\\\n",
    "        Create pytest tests for module `{module_name}`.\n",
    "\n",
    "        Public API Summary:\n",
    "        {summary}\n",
    "\n",
    "        Constraints:\n",
    "        - Import as: `import {module_name} as mod`\n",
    "        - Keep tests tiny, fast, and deterministic.\n",
    "\n",
    "        Full module source (for reference):\n",
    "        # --- BEGIN SOURCE {module_name}.py ---\n",
    "        {source}\n",
    "        # --- END SOURCE ---\"\"\")\n",
    "\n",
    "    def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n",
    "        summary = \"\\n\".join(f\"- {s.kind:<6} {s.signature}\" for s in symbols) or \"- (no public symbols)\"\n",
    "        return self._USER_TEMPLATE.format(\n",
    "            module_name=module_name, summary=summary, source=source\n",
    "        ).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787e58b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestGenerator:\n",
    "    \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
    "    def __init__(self, llm: CompletionClient, model: str):\n",
    "        self._llm = llm\n",
    "        self._model = model\n",
    "        self._extractor = PublicAPIExtractor()\n",
    "        self._prompts = PromptBuilder()\n",
    "\n",
    "    def generate_tests(self, module_name: str, source: str) -> str:\n",
    "        symbols = self._extractor.extract(source)\n",
    "        user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
    "        raw = self._llm.complete(model=self._model, system=self._prompts.SYSTEM, user=user)\n",
    "        return _ensure_header_and_import(raw, module_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8402f62f",
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM = OpenAIChatClient(OPENAI_CLIENT)\n",
    "SERVICE = TestGenerator(LLM, TESTGEN_MODEL)\n",
    "\n",
    "\n",
    "def build_module_name_from_path(path: str) -> str:\n",
    "    return Path(path).stem\n",
    "\n",
    "\n",
    "def generate_from_code(module_name: str, code: str, save: bool, out_dir: str) -> tuple[str, str]:\n",
    "    module_name = module_name.strip()\n",
    "    if not module_name:\n",
    "        return \"\", \"❌ Please provide a module name.\"\n",
    "    if not code.strip():\n",
    "        return \"\", \"❌ Please paste some Python code.\"\n",
    "\n",
    "    tests_code = SERVICE.generate_tests(module_name=module_name, source=code)\n",
    "    saved = \"\"\n",
    "    if save:\n",
    "        out = Path(out_dir or \"tests\")\n",
    "        out.mkdir(parents=True, exist_ok=True)\n",
    "        out_path = out / f\"test_{module_name}.py\"\n",
    "        out_path.write_text(tests_code, encoding=\"utf-8\")\n",
    "        saved = f\"✅ Saved to {out_path}\"\n",
    "    return tests_code, saved\n",
    "\n",
    "\n",
    "def generate_from_file(file_path: str | None, save: bool, out_dir: str) -> tuple[str, str]:\n",
    "    # gr.File (type=\"filepath\") hands callbacks a path string, not bytes.\n",
    "    if not file_path:\n",
    "        return \"\", \"❌ Please upload a .py file.\"\n",
    "    code = Path(file_path).read_text(encoding=\"utf-8\")\n",
    "    module_name = build_module_name_from_path(file_path)\n",
    "    return generate_from_code(module_name, code, save, out_dir)"
   ]
  },
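  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18e5c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional end-to-end smoke test: assumes a valid OPENAI_API_KEY in .env,\n",
    "# makes one real API call, and its output varies by model. The tiny module\n",
    "# below is made up; save=False keeps it from writing to disk.\n",
    "_demo_module = \"def double(x):\\n    return 2 * x\\n\"\n",
    "tests, msg = generate_from_code(\"demo_module\", _demo_module, save=False, out_dir=\"tests\")\n",
    "print(msg or \"✅ Generated (not saved)\")\n",
    "print(tests)"
   ]
  },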
+ "outputs": [], + "source": [ + "with gr.Blocks(title=\"Simple PyTest Generator\") as ui:\n", + " gr.Markdown(\"## ๐Ÿงช Simple PyTest Generator (Week 4 โ€ข Community Contribution)\\n\"\n", + " \"Generate **minimal, deterministic** pytest tests from a Python module using a Frontier model.\")\n", + "\n", + " with gr.Tab(\"Paste Code\"):\n", + " with gr.Row():\n", + " module_name = gr.Textbox(label=\"Module name (used in `import as mod`)\", value=\"mymodule\")\n", + " code_in = gr.Code(label=\"Python module code\", language=\"python\", lines=22)\n", + " with gr.Row():\n", + " save_cb = gr.Checkbox(label=\"Save to /tests\", value=True)\n", + " out_dir = gr.Textbox(label=\"Output folder\", value=\"tests\")\n", + " gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n", + " with gr.Row():\n", + " tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n", + " status = gr.Markdown()\n", + "\n", + " def _on_gen(name, code, save, outdir):\n", + " tests, msg = generate_from_code(name, code, save, outdir)\n", + " return tests, (msg or \"โœ… Done\")\n", + "\n", + " gen_btn.click(_on_gen, inputs=[module_name, code_in, save_cb, out_dir], outputs=[tests_out, status])\n", + "\n", + " with gr.Tab(\"Upload .py\"):\n", + " upload = gr.File(file_types=[\".py\"], label=\"Upload a Python module\")\n", + " with gr.Row():\n", + " save_cb2 = gr.Checkbox(label=\"Save to /tests\", value=True)\n", + " out_dir2 = gr.Textbox(label=\"Output folder\", value=\"tests\")\n", + " gen_btn2 = gr.Button(\"Generate tests from file\")\n", + " tests_out2 = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n", + " status2 = gr.Markdown()\n", + "\n", + " def _on_gen_file(f, save, outdir):\n", + " tests, msg = generate_from_file(f.read() if f else None, save, outdir)\n", + " return tests, (msg or \"โœ… Done\")\n", + "\n", + " gen_btn2.click(_on_gen_file, inputs=[upload, save_cb2, out_dir2], outputs=[tests_out2, status2])\n", + "\n", + "ui.launch(inbrowser=True)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llm-engineering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}