Create pytest generator tool.

This commit is contained in:
Nik
2025-10-26 20:01:49 +05:30
parent 48076f9d39
commit 74e48bba2c

View File

@@ -0,0 +1,321 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "ba193fd5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import ast\n",
"import textwrap\n",
"from pathlib import Path\n",
"from dataclasses import dataclass\n",
"from typing import List, Protocol, Optional\n",
"from openai import BadRequestError as _OpenAIBadRequest\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"TESTGEN_MODEL = os.getenv(\"TESTGEN_MODEL\", \"gpt-5-nano\")\n",
"OPENAI_CLIENT = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5d6b0f2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"class CompletionClient(Protocol):\n",
" \"\"\"Any LLM client provides a .complete() method.\"\"\"\n",
" def complete(self, *, model: str, system: str, user: str) -> str: ...\n",
"\n",
"\n",
"class OpenAIChatClient:\n",
" \"\"\"Adapter over OpenAI chat completions that omits unsupported params.\"\"\"\n",
" def __init__(self, client: OpenAI):\n",
" self._client = client\n",
"\n",
" def _call(self, *, model: str, system: str, user: str, include_temperature: bool = False) -> str:\n",
" params = {\n",
" \"model\": model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": system},\n",
" {\"role\": \"user\", \"content\": user},\n",
" ],\n",
" }\n",
" if include_temperature: \n",
" params[\"temperature\"] = 0.7 \n",
"\n",
" resp = self._client.chat.completions.create(**params)\n",
" text = (resp.choices[0].message.content or \"\").strip()\n",
" return _extract_code_or_text(text)\n",
"\n",
" def complete(self, *, model: str, system: str, user: str) -> str:\n",
" try:\n",
" return self._call(model=model, system=system, user=user, include_temperature=False)\n",
" except _OpenAIBadRequest as e:\n",
" # Extra safety: if some lib auto-injected temperature, retry without it\n",
" if \"temperature\" in str(e).lower():\n",
" return self._call(model=model, system=system, user=user, include_temperature=False)\n",
" raise\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31558bf0",
"metadata": {},
"outputs": [],
"source": [
"def _extract_code_or_text(s: str) -> str:\n",
" \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
" m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
" return m.group(1).strip() if m else s.strip()\n",
"\n",
"\n",
"def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
" \"\"\"Make sure tests import the module and pytest; keep output minimal.\"\"\"\n",
" code = code.strip()\n",
" needs_pytest = \"import pytest\" not in code\n",
" needs_import = f\"import {module_name}\" not in code and f\"import {module_name} as mod\" not in code\n",
"\n",
" header_lines = []\n",
" if needs_pytest:\n",
" header_lines.append(\"import pytest\")\n",
" if needs_import:\n",
" header_lines.append(f\"import {module_name} as mod\")\n",
"\n",
" if header_lines:\n",
" code = \"\\n\".join(header_lines) + \"\\n\\n\" + code\n",
" return code\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aeadedc",
"metadata": {},
"outputs": [],
"source": [
"@dataclass(frozen=True)\n",
"class SymbolInfo:\n",
" kind: str # \"function\" | \"class\" | \"method\"\n",
" name: str\n",
" signature: str\n",
" lineno: int\n",
"\n",
"\n",
"class PublicAPIExtractor:\n",
" \"\"\"Extract a small 'public API' summary from a Python module.\"\"\"\n",
" def extract(self, source: str) -> List[SymbolInfo]:\n",
" tree = ast.parse(source)\n",
" out: List[SymbolInfo] = []\n",
"\n",
" for node in tree.body:\n",
" if isinstance(node, ast.FunctionDef) and not node.name.startswith(\"_\"):\n",
" out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n",
" elif isinstance(node, ast.ClassDef) and not node.name.startswith(\"_\"):\n",
" out.append(SymbolInfo(\"class\", node.name, node.name, node.lineno))\n",
" for sub in node.body:\n",
" if isinstance(sub, ast.FunctionDef) and not sub.name.startswith(\"_\"):\n",
" out.append(SymbolInfo(\"method\",\n",
" f\"{node.name}.{sub.name}\",\n",
" self._sig(sub),\n",
" sub.lineno))\n",
" return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))\n",
"\n",
" def _sig(self, fn: ast.FunctionDef) -> str:\n",
" args = [a.arg for a in fn.args.args]\n",
" if fn.args.vararg:\n",
" args.append(\"*\" + fn.args.vararg.arg)\n",
" args.extend(a.arg + \"=?\" for a in fn.args.kwonlyargs)\n",
" if fn.args.kwarg:\n",
" args.append(\"**\" + fn.args.kwarg.arg)\n",
" ret = \"\"\n",
" if fn.returns is not None:\n",
" try:\n",
" ret = f\" -> {ast.unparse(fn.returns)}\"\n",
" except Exception:\n",
" pass\n",
" return f\"def {fn.name}({', '.join(args)}){ret}:\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a45ac5be",
"metadata": {},
"outputs": [],
"source": [
"class PromptBuilder:\n",
" \"\"\"Builds concise, deterministic prompts for pytest generation.\"\"\"\n",
" SYSTEM = (\n",
" \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n",
" \"Rules:\\n\"\n",
" \"- Output only Python test code (no prose, no markdown fences).\\n\"\n",
" \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n",
" \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n",
" \"- Import the target module by module name.\\n\"\n",
" \"- Create a minimal test covering every public function and method.\\n\"\n",
" \"- Prefer straightforward, fast assertions over exhaustive checks.\\n\"\n",
" )\n",
"\n",
" def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n",
" summary = \"\\n\".join(f\"- {s.kind:<6} {s.signature}\" for s in symbols) or \"- (no public symbols)\"\n",
" return textwrap.dedent(f\"\"\"\n",
" Create pytest tests for module `{module_name}`.\n",
"\n",
" Public API Summary:\n",
" {summary}\n",
"\n",
" Constraints:\n",
" - Import as: `import {module_name} as mod`\n",
" - Keep tests tiny, fast, and deterministic.\n",
"\n",
" Full module source (for reference):\n",
" # --- BEGIN SOURCE {module_name}.py ---\n",
" {source}\n",
" # --- END SOURCE ---\n",
" \"\"\").strip()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "787e58b6",
"metadata": {},
"outputs": [],
"source": [
"class TestGenerator:\n",
" \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
" def __init__(self, llm: CompletionClient, model: str):\n",
" self._llm = llm\n",
" self._model = model\n",
" self._extractor = PublicAPIExtractor()\n",
" self._prompts = PromptBuilder()\n",
"\n",
" def generate_tests(self, module_name: str, source: str) -> str:\n",
" symbols = self._extractor.extract(source)\n",
" user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
" raw = self._llm.complete(model=self._model, system=self._prompts.SYSTEM, user=user)\n",
" return _ensure_header_and_import(raw, module_name)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8402f62f",
"metadata": {},
"outputs": [],
"source": [
"LLM = OpenAIChatClient(OPENAI_CLIENT)\n",
"SERVICE = TestGenerator(LLM, TESTGEN_MODEL)\n",
"\n",
"def build_module_name_from_path(path: str) -> str:\n",
" p = Path(path)\n",
" return p.stem\n",
"\n",
"def generate_from_code(module_name: str, code: str, save: bool, out_dir: str) -> tuple[str, str]:\n",
" if not module_name.strip():\n",
" return \"\", \"❌ Please provide a module name.\"\n",
" if not code.strip():\n",
" return \"\", \"❌ Please paste some Python code.\"\n",
"\n",
" tests_code = SERVICE.generate_tests(module_name=module_name.strip(), source=code)\n",
" saved = \"\"\n",
" if save:\n",
" out = Path(out_dir or \"tests\")\n",
" out.mkdir(parents=True, exist_ok=True)\n",
" out_path = out / f\"test_{module_name}.py\"\n",
" out_path.write_text(tests_code, encoding=\"utf-8\")\n",
" saved = f\"✅ Saved to {out_path}\"\n",
" return tests_code, saved\n",
"\n",
"\n",
"def generate_from_file(file_obj, save: bool, out_dir: str) -> tuple[str, str]:\n",
" if file_obj is None:\n",
" return \"\", \"❌ Please upload a .py file.\"\n",
" code = file_obj.decode(\"utf-8\")\n",
" module_name = build_module_name_from_path(\"uploaded_module.py\")\n",
" return generate_from_code(module_name, code, save, out_dir)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d240ce5",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks(title=\"Simple PyTest Generator\") as ui:\n",
" gr.Markdown(\"## 🧪 Simple PyTest Generator (Week 4 • Community Contribution)\\n\"\n",
" \"Generate **minimal, deterministic** pytest tests from a Python module using a Frontier model.\")\n",
"\n",
" with gr.Tab(\"Paste Code\"):\n",
" with gr.Row():\n",
" module_name = gr.Textbox(label=\"Module name (used in `import <name> as mod`)\", value=\"mymodule\")\n",
" code_in = gr.Code(label=\"Python module code\", language=\"python\", lines=22)\n",
" with gr.Row():\n",
" save_cb = gr.Checkbox(label=\"Save to /tests\", value=True)\n",
" out_dir = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
" gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n",
" with gr.Row():\n",
" tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n",
" status = gr.Markdown()\n",
"\n",
" def _on_gen(name, code, save, outdir):\n",
" tests, msg = generate_from_code(name, code, save, outdir)\n",
" return tests, (msg or \"✅ Done\")\n",
"\n",
" gen_btn.click(_on_gen, inputs=[module_name, code_in, save_cb, out_dir], outputs=[tests_out, status])\n",
"\n",
" with gr.Tab(\"Upload .py\"):\n",
" upload = gr.File(file_types=[\".py\"], label=\"Upload a Python module\")\n",
" with gr.Row():\n",
" save_cb2 = gr.Checkbox(label=\"Save to /tests\", value=True)\n",
" out_dir2 = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
" gen_btn2 = gr.Button(\"Generate tests from file\")\n",
" tests_out2 = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n",
" status2 = gr.Markdown()\n",
"\n",
" def _on_gen_file(f, save, outdir):\n",
" tests, msg = generate_from_file(f.read() if f else None, save, outdir)\n",
" return tests, (msg or \"✅ Done\")\n",
"\n",
" gen_btn2.click(_on_gen_file, inputs=[upload, save_cb2, out_dir2], outputs=[tests_out2, status2])\n",
"\n",
"ui.launch(inbrowser=True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llm-engineering",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}