Create pytest generator tool.
@@ -0,0 +1,321 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba193fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import ast\n",
    "import textwrap\n",
    "from pathlib import Path\n",
    "from dataclasses import dataclass\n",
    "from typing import List, Protocol, Optional\n",
    "from openai import BadRequestError as _OpenAIBadRequest\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "TESTGEN_MODEL = os.getenv(\"TESTGEN_MODEL\", \"gpt-5-nano\")\n",
    "OPENAI_CLIENT = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d6b0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class CompletionClient(Protocol):\n",
    "    \"\"\"Any LLM client provides a .complete() method.\"\"\"\n",
    "    def complete(self, *, model: str, system: str, user: str) -> str: ...\n",
    "\n",
    "\n",
    "class OpenAIChatClient:\n",
    "    \"\"\"Adapter over OpenAI chat completions that omits unsupported params.\"\"\"\n",
    "    def __init__(self, client: OpenAI):\n",
    "        self._client = client\n",
    "\n",
    "    def _call(self, *, model: str, system: str, user: str, include_temperature: bool = False) -> str:\n",
    "        params = {\n",
    "            \"model\": model,\n",
    "            \"messages\": [\n",
    "                {\"role\": \"system\", \"content\": system},\n",
    "                {\"role\": \"user\", \"content\": user},\n",
    "            ],\n",
    "        }\n",
    "        if include_temperature:\n",
    "            params[\"temperature\"] = 0.7\n",
    "\n",
    "        resp = self._client.chat.completions.create(**params)\n",
    "        text = (resp.choices[0].message.content or \"\").strip()\n",
    "        return _extract_code_or_text(text)\n",
    "\n",
    "    def complete(self, *, model: str, system: str, user: str) -> str:\n",
    "        try:\n",
    "            return self._call(model=model, system=system, user=user, include_temperature=False)\n",
    "        except _OpenAIBadRequest as e:\n",
    "            # Extra safety: if some lib auto-injected temperature, retry without it\n",
    "            if \"temperature\" in str(e).lower():\n",
    "                return self._call(model=model, system=system, user=user, include_temperature=False)\n",
    "            raise\n"
   ]
  },
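  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0de0001",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick illustration: because CompletionClient is a Protocol, any object with a matching\n",
    "# .complete() can stand in for the OpenAI adapter (useful for offline experiments).\n",
    "# The canned reply below is just an example, not real model output.\n",
    "class CannedClient:\n",
    "    \"\"\"Returns a fixed pytest snippet without calling any API.\"\"\"\n",
    "    def complete(self, *, model: str, system: str, user: str) -> str:\n",
    "        return \"def test_placeholder():\\n    assert True\\n\"\n",
    "\n",
    "print(CannedClient().complete(model=\"none\", system=\"ignored\", user=\"ignored\"))"
   ]
  },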
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31558bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _extract_code_or_text(s: str) -> str:\n",
    "    \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
    "    m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
    "    return m.group(1).strip() if m else s.strip()\n",
    "\n",
    "\n",
    "def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
    "    \"\"\"Make sure tests import the module and pytest; keep output minimal.\"\"\"\n",
    "    code = code.strip()\n",
    "    needs_pytest = \"import pytest\" not in code\n",
    "    needs_import = f\"import {module_name}\" not in code and f\"import {module_name} as mod\" not in code\n",
    "\n",
    "    header_lines = []\n",
    "    if needs_pytest:\n",
    "        header_lines.append(\"import pytest\")\n",
    "    if needs_import:\n",
    "        header_lines.append(f\"import {module_name} as mod\")\n",
    "\n",
    "    if header_lines:\n",
    "        code = \"\\n\".join(header_lines) + \"\\n\\n\" + code\n",
    "    return code\n"
   ]
  },
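  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0de0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check of the two helpers above; the fenced snippet and the module name\n",
    "# are made-up examples.\n",
    "fenced = \"```python\\nassert 1 + 1 == 2\\n```\"\n",
    "print(_extract_code_or_text(fenced))\n",
    "print(_ensure_header_and_import(\"def test_x():\\n    assert True\", \"mymodule\"))"
   ]
  },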
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aeadedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class SymbolInfo:\n",
    "    kind: str  # \"function\" | \"class\" | \"method\"\n",
    "    name: str\n",
    "    signature: str\n",
    "    lineno: int\n",
    "\n",
    "\n",
    "class PublicAPIExtractor:\n",
    "    \"\"\"Extract a small 'public API' summary from a Python module.\"\"\"\n",
    "    def extract(self, source: str) -> List[SymbolInfo]:\n",
    "        tree = ast.parse(source)\n",
    "        out: List[SymbolInfo] = []\n",
    "\n",
    "        for node in tree.body:\n",
    "            if isinstance(node, ast.FunctionDef) and not node.name.startswith(\"_\"):\n",
    "                out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n",
    "            elif isinstance(node, ast.ClassDef) and not node.name.startswith(\"_\"):\n",
    "                out.append(SymbolInfo(\"class\", node.name, node.name, node.lineno))\n",
    "                for sub in node.body:\n",
    "                    if isinstance(sub, ast.FunctionDef) and not sub.name.startswith(\"_\"):\n",
    "                        out.append(SymbolInfo(\"method\",\n",
    "                                              f\"{node.name}.{sub.name}\",\n",
    "                                              self._sig(sub),\n",
    "                                              sub.lineno))\n",
    "        return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))\n",
    "\n",
    "    def _sig(self, fn: ast.FunctionDef) -> str:\n",
    "        args = [a.arg for a in fn.args.args]\n",
    "        if fn.args.vararg:\n",
    "            args.append(\"*\" + fn.args.vararg.arg)\n",
    "        args.extend(a.arg + \"=?\" for a in fn.args.kwonlyargs)\n",
    "        if fn.args.kwarg:\n",
    "            args.append(\"**\" + fn.args.kwarg.arg)\n",
    "        ret = \"\"\n",
    "        if fn.returns is not None:\n",
    "            try:\n",
    "                ret = f\" -> {ast.unparse(fn.returns)}\"\n",
    "            except Exception:\n",
    "                pass\n",
    "        return f\"def {fn.name}({', '.join(args)}){ret}:\"\n"
   ]
  },
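  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0de0003",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Offline example of the extractor; the toy module below is made up for illustration.\n",
    "_toy_source = '''\n",
    "def add(a, b):\n",
    "    return a + b\n",
    "\n",
    "class Greeter:\n",
    "    def hello(self, name: str) -> str:\n",
    "        return f\"hi {name}\"\n",
    "'''\n",
    "for sym in PublicAPIExtractor().extract(_toy_source):\n",
    "    print(f\"{sym.kind:<8} {sym.signature}\")"
   ]
  },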
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45ac5be",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PromptBuilder:\n",
    "    \"\"\"Builds concise, deterministic prompts for pytest generation.\"\"\"\n",
    "    SYSTEM = (\n",
    "        \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n",
    "        \"Rules:\\n\"\n",
    "        \"- Output only Python test code (no prose, no markdown fences).\\n\"\n",
    "        \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n",
    "        \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n",
    "        \"- Import the target module by module name.\\n\"\n",
    "        \"- Create a minimal test covering every public function and method.\\n\"\n",
    "        \"- Prefer straightforward, fast assertions over exhaustive checks.\\n\"\n",
    "    )\n",
    "\n",
    "    def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n",
    "        summary = \"\\n\".join(f\"- {s.kind:<6} {s.signature}\" for s in symbols) or \"- (no public symbols)\"\n",
    "        return textwrap.dedent(f\"\"\"\n",
    "            Create pytest tests for module `{module_name}`.\n",
    "\n",
    "            Public API Summary:\n",
    "            {summary}\n",
    "\n",
    "            Constraints:\n",
    "            - Import as: `import {module_name} as mod`\n",
    "            - Keep tests tiny, fast, and deterministic.\n",
    "\n",
    "            Full module source (for reference):\n",
    "            # --- BEGIN SOURCE {module_name}.py ---\n",
    "            {source}\n",
    "            # --- END SOURCE ---\n",
    "        \"\"\").strip()\n"
   ]
  },
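  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0de0004",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview of what the builder sends to the model, reusing the toy module defined a\n",
    "# couple of cells above (illustration only; no API call).\n",
    "_pb = PromptBuilder()\n",
    "_symbols = PublicAPIExtractor().extract(_toy_source)\n",
    "print(_pb.SYSTEM)\n",
    "print(_pb.build_user(module_name=\"toy\", source=_toy_source, symbols=_symbols)[:400])"
   ]
  },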
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787e58b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestGenerator:\n",
    "    \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
    "    def __init__(self, llm: CompletionClient, model: str):\n",
    "        self._llm = llm\n",
    "        self._model = model\n",
    "        self._extractor = PublicAPIExtractor()\n",
    "        self._prompts = PromptBuilder()\n",
    "\n",
    "    def generate_tests(self, module_name: str, source: str) -> str:\n",
    "        symbols = self._extractor.extract(source)\n",
    "        user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
    "        raw = self._llm.complete(model=self._model, system=self._prompts.SYSTEM, user=user)\n",
    "        return _ensure_header_and_import(raw, module_name)\n"
   ]
  },
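  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0de0005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end sketch of the pipeline against the toy module above (illustration only).\n",
    "# This makes a live API call, so it assumes a valid OPENAI_API_KEY in the environment;\n",
    "# flip RUN_DEMO to True to try it.\n",
    "RUN_DEMO = False\n",
    "if RUN_DEMO:\n",
    "    _gen = TestGenerator(OpenAIChatClient(OPENAI_CLIENT), TESTGEN_MODEL)\n",
    "    print(_gen.generate_tests(\"toy\", _toy_source))"
   ]
  },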
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8402f62f",
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM = OpenAIChatClient(OPENAI_CLIENT)\n",
    "SERVICE = TestGenerator(LLM, TESTGEN_MODEL)\n",
    "\n",
    "def build_module_name_from_path(path: str) -> str:\n",
    "    p = Path(path)\n",
    "    return p.stem\n",
    "\n",
    "def generate_from_code(module_name: str, code: str, save: bool, out_dir: str) -> tuple[str, str]:\n",
    "    if not module_name.strip():\n",
    "        return \"\", \"❌ Please provide a module name.\"\n",
    "    if not code.strip():\n",
    "        return \"\", \"❌ Please paste some Python code.\"\n",
    "\n",
    "    tests_code = SERVICE.generate_tests(module_name=module_name.strip(), source=code)\n",
    "    saved = \"\"\n",
    "    if save:\n",
    "        out = Path(out_dir or \"tests\")\n",
    "        out.mkdir(parents=True, exist_ok=True)\n",
    "        out_path = out / f\"test_{module_name}.py\"\n",
    "        out_path.write_text(tests_code, encoding=\"utf-8\")\n",
    "        saved = f\"✅ Saved to {out_path}\"\n",
    "    return tests_code, saved\n",
    "\n",
    "\n",
    "def generate_from_file(file_obj, save: bool, out_dir: str) -> tuple[str, str]:\n",
    "    # Expects raw bytes from gr.File(type=\"binary\").\n",
    "    if file_obj is None:\n",
    "        return \"\", \"❌ Please upload a .py file.\"\n",
    "    code = file_obj.decode(\"utf-8\")\n",
    "    # The uploaded filename is not preserved in binary mode, so tests use a fixed module name.\n",
    "    module_name = build_module_name_from_path(\"uploaded_module.py\")\n",
    "    return generate_from_code(module_name, code, save, out_dir)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d240ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks(title=\"Simple PyTest Generator\") as ui:\n",
    "    gr.Markdown(\"## 🧪 Simple PyTest Generator (Week 4 • Community Contribution)\\n\"\n",
    "                \"Generate **minimal, deterministic** pytest tests from a Python module using a Frontier model.\")\n",
    "\n",
    "    with gr.Tab(\"Paste Code\"):\n",
    "        with gr.Row():\n",
    "            module_name = gr.Textbox(label=\"Module name (used in `import <name> as mod`)\", value=\"mymodule\")\n",
    "            code_in = gr.Code(label=\"Python module code\", language=\"python\", lines=22)\n",
    "        with gr.Row():\n",
    "            save_cb = gr.Checkbox(label=\"Save to /tests\", value=True)\n",
    "            out_dir = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
    "        gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n",
    "        with gr.Row():\n",
    "            tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n",
    "            status = gr.Markdown()\n",
    "\n",
    "        def _on_gen(name, code, save, outdir):\n",
    "            tests, msg = generate_from_code(name, code, save, outdir)\n",
    "            return tests, (msg or \"✅ Done\")\n",
    "\n",
    "        gen_btn.click(_on_gen, inputs=[module_name, code_in, save_cb, out_dir], outputs=[tests_out, status])\n",
    "\n",
    "    with gr.Tab(\"Upload .py\"):\n",
    "        upload = gr.File(file_types=[\".py\"], type=\"binary\", label=\"Upload a Python module\")\n",
    "        with gr.Row():\n",
    "            save_cb2 = gr.Checkbox(label=\"Save to /tests\", value=True)\n",
    "            out_dir2 = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
    "        gen_btn2 = gr.Button(\"Generate tests from file\")\n",
    "        tests_out2 = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n",
    "        status2 = gr.Markdown()\n",
    "\n",
    "        def _on_gen_file(f, save, outdir):\n",
    "            # With type=\"binary\", f is already raw bytes (or None), so pass it straight through.\n",
    "            tests, msg = generate_from_file(f, save, outdir)\n",
    "            return tests, (msg or \"✅ Done\")\n",
    "\n",
    "        gen_btn2.click(_on_gen_file, inputs=[upload, save_cb2, out_dir2], outputs=[tests_out2, status2])\n",
    "\n",
    "ui.launch(inbrowser=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-engineering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}