Add multiple LLM provider support; add in-UI pytest run.
@@ -125,36 +125,6 @@
"id": "31558bf0",
"metadata": {},
"outputs": [],
"source": [
"def _extract_code_or_text(s: str) -> str:\n",
"    \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
"    m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
"    return m.group(1).strip() if m else s.strip()\n",
"\n",
"\n",
"def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
"    \"\"\"Make sure tests import the module and pytest; keep output minimal.\"\"\"\n",
"    code = code.strip()\n",
"    needs_pytest = \"import pytest\" not in code\n",
"    needs_import = f\"import {module_name}\" not in code and f\"import {module_name} as mod\" not in code\n",
"\n",
"    header_lines = []\n",
"    if needs_pytest:\n",
"        header_lines.append(\"import pytest\")\n",
"    if needs_import:\n",
"        header_lines.append(f\"import {module_name} as mod\")\n",
"\n",
"    if header_lines:\n",
"        code = \"\\n\".join(header_lines) + \"\\n\\n\" + code\n",
"    return code\n"
]
},
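Since this cell only appears in diff context, here is a standalone sketch of the fence-stripping behavior it implements; the sample model replies are made up.

```python
import re

def _extract_code_or_text(s: str) -> str:
    """Prefer fenced python if present; otherwise return raw text."""
    m = re.search(r"```(?:python)?\s*(.*?)```", s, flags=re.S | re.I)
    return m.group(1).strip() if m else s.strip()

# A fenced model reply is unwrapped to its body...
fenced = "Sure, here are the tests:\n```python\nassert 1 + 1 == 2\n```"
assert _extract_code_or_text(fenced) == "assert 1 + 1 == 2"

# ...while a bare reply is passed through, merely stripped.
assert _extract_code_or_text("  x = 1  ") == "x = 1"
```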
{
"cell_type": "code",
"execution_count": null,
"id": "3aeadedc",
"metadata": {},
"outputs": [],
"source": [
"@dataclass(frozen=True)\n",
"class SymbolInfo:\n",
@@ -163,13 +133,11 @@
"    signature: str\n",
"    lineno: int\n",
"\n",
"\n",
"class PublicAPIExtractor:\n",
"    \"\"\"Extract a small 'public API' summary from a Python module.\"\"\"\n",
" \"\"\"Extract concise 'public API' summary from a Python module.\"\"\"\n",
"    def extract(self, source: str) -> List[SymbolInfo]:\n",
"        tree = ast.parse(source)\n",
"        out: List[SymbolInfo] = []\n",
"\n",
"        for node in tree.body:\n",
"            if isinstance(node, ast.FunctionDef) and not node.name.startswith(\"_\"):\n",
"                out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n",
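A runnable sketch of the AST walk `PublicAPIExtractor.extract` performs. The `_sig` helper is not shown in the hunk, so the signature rendering below is an assumption, as are the names of SymbolInfo's first two fields.

```python
import ast
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class SymbolInfo:
    kind: str       # assumed field name for the first slot ("function", ...)
    name: str       # assumed field name for the second slot
    signature: str
    lineno: int

def extract_public_functions(source: str) -> List[SymbolInfo]:
    """Mirror the top-level-function branch of PublicAPIExtractor.extract."""
    out: List[SymbolInfo] = []
    for node in ast.parse(source).body:
        if isinstance(node, ast.FunctionDef) and not node.name.startswith("_"):
            # Stand-in for the extractor's private _sig helper.
            args = ", ".join(a.arg for a in node.args.args)
            out.append(SymbolInfo("function", node.name, f"{node.name}({args})", node.lineno))
    return out

symbols = extract_public_functions("def add(a, b):\n    return a + b\n\ndef _hidden():\n    pass\n")
assert [s.name for s in symbols] == ["add"]
assert symbols[0].signature == "add(a, b)"
```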
@@ -202,21 +170,21 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a45ac5be",
"id": "3aeadedc",
"metadata": {},
"outputs": [],
"source": [
"class PromptBuilder:\n",
"    \"\"\"Builds concise, deterministic prompts for pytest generation.\"\"\"\n",
"    \"\"\"Builds deterministic prompts for pytest generation.\"\"\"\n",
"    SYSTEM = (\n",
"        \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n",
"        \"Rules:\\n\"\n",
"        \"- Output only Python test code (no prose, no markdown fences).\\n\"\n",
"        \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n",
"        \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n",
"        \"- Import the target module by module name.\\n\"\n",
"        \"- Create a minimal test covering every public function and method.\\n\"\n",
"        \"- Prefer straightforward, fast assertions over exhaustive checks.\\n\"\n",
"        \"- Import the target module by module name only.\\n\"\n",
"        \"- Cover every public function and method with at least one tiny test.\\n\"\n",
"        \"- Prefer straightforward, fast assertions.\\n\"\n",
"    )\n",
"\n",
"    def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n",
@@ -238,6 +206,33 @@
"        \"\"\").strip()\n"
]
},
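`build_user`'s body is elided by the diff, so the sketch below only illustrates the general shape: a deterministic user prompt paired with the SYSTEM rules in an OpenAI-style message list. All names here are illustrative, not the notebook's exact code.

```python
from typing import List

SYSTEM = "You are a senior Python engineer. Produce a single, self-contained pytest file."

def build_user(module_name: str, source: str, symbol_lines: List[str]) -> str:
    # Deterministic layout: identical inputs always produce the identical prompt.
    parts = [
        f"Target module: {module_name}",
        "Public API:",
        *symbol_lines,
        "Source code:",
        source,
    ]
    return "\n".join(parts)

# OpenAI-style chat payload; most chat providers accept this role/content shape.
messages = [
    {"role": "system", "content": SYSTEM},
    {"role": "user", "content": build_user("mymodule", "def add(a, b):\n    return a + b", ["function add(a, b)"])},
]
```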
{
"cell_type": "code",
"execution_count": null,
"id": "a45ac5be",
"metadata": {},
"outputs": [],
"source": [
"def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
"    \"\"\"Ensure tests import pytest and the target module as 'mod'.\"\"\"\n",
"    code = code.strip()\n",
"    needs_pytest = \"import pytest\" not in code\n",
"    has_mod = (f\"import {module_name} as mod\" in code) or (f\"from {module_name} import\" in code)\n",
"    needs_import = not has_mod\n",
"\n",
"    header = []\n",
"    if needs_pytest:\n",
"        header.append(\"import pytest\")\n",
"    if needs_import:\n",
"        header.append(f\"import {module_name} as mod\")\n",
"\n",
"    return (\"\\n\".join(header) + \"\\n\\n\" + code) if header else code\n",
"\n",
"\n",
"def build_module_name_from_path(path: str) -> str:\n",
"    return Path(path).stem\n"
]
},
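A quick self-check of the header-injection logic, copying the function exactly as it appears above; the test snippets are hypothetical.

```python
def _ensure_header_and_import(code: str, module_name: str) -> str:
    """Ensure tests import pytest and the target module as 'mod'."""
    code = code.strip()
    needs_pytest = "import pytest" not in code
    has_mod = (f"import {module_name} as mod" in code) or (f"from {module_name} import" in code)
    needs_import = not has_mod

    header = []
    if needs_pytest:
        header.append("import pytest")
    if needs_import:
        header.append(f"import {module_name} as mod")

    return ("\n".join(header) + "\n\n" + code) if header else code

# Bare assertions get both imports prepended...
out = _ensure_header_and_import("def test_add():\n    assert mod.add(1, 2) == 3", "mymodule")
assert out.startswith("import pytest\nimport mymodule as mod\n\n")

# ...while an already-complete file is returned unchanged.
ready = "import pytest\nimport mymodule as mod\n\ndef test_x():\n    assert True"
assert _ensure_header_and_import(ready, "mymodule") == ready
```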
{
"cell_type": "code",
"execution_count": null,
@@ -246,7 +241,7 @@
"outputs": [],
"source": [
"class TestGenerator:\n",
"    \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
"    \"\"\"Extraction → prompt → model → polish.\"\"\"\n",
"    def __init__(self, llm: CompletionClient):\n",
"        self._llm = llm\n",
"        self._extractor = PublicAPIExtractor()\n",
@@ -266,20 +261,93 @@
"metadata": {},
"outputs": [],
"source": [
"LLM = OpenAIChatClient(OPENAI_CLIENT)\n",
"SERVICE = TestGenerator(LLM, TESTGEN_MODEL)\n",
"def _parse_pytest_summary(output: str) -> Tuple[str, Dict[str, int]]:\n",
"    \"\"\"\n",
"    Parse the final summary line like:\n",
"    '3 passed, 1 failed, 2 skipped in 0.12s'\n",
"    Return (summary_line, counts_dict).\n",
"    \"\"\"\n",
"    summary_line = \"\"\n",
"    for line in output.strip().splitlines()[::-1]:  # scan from end\n",
"        if \" passed\" in line or \" failed\" in line or \" error\" in line or \" skipped\" in line or \" deselected\" in line:\n",
"            summary_line = line.strip()\n",
"            break\n",
"\n",
"def build_module_name_from_path(path: str) -> str:\n",
"    p = Path(path)\n",
"    return p.stem\n",
"    counts = {\"passed\": 0, \"failed\": 0, \"errors\": 0, \"skipped\": 0, \"xfail\": 0, \"xpassed\": 0}\n",
"    m = re.findall(r\"(\\d+)\\s+(passed|failed|errors?|skipped|xfailed|xpassed)\", summary_line)\n",
"    for num, kind in m:\n",
"        if kind.startswith(\"error\"):\n",
"            counts[\"errors\"] += int(num)\n",
"        elif kind == \"passed\":\n",
"            counts[\"passed\"] += int(num)\n",
"        elif kind == \"failed\":\n",
"            counts[\"failed\"] += int(num)\n",
"        elif kind == \"skipped\":\n",
"            counts[\"skipped\"] += int(num)\n",
"        elif kind == \"xfailed\":\n",
"            counts[\"xfail\"] += int(num)\n",
"        elif kind == \"xpassed\":\n",
"            counts[\"xpassed\"] += int(num)\n",
"\n",
"def generate_from_code(module_name: str, code: str, save: bool, out_dir: str) -> tuple[str, str]:\n",
"    return summary_line or \"(no summary line found)\", counts\n",
"\n",
"\n",
"def run_pytest_on_snippet(module_name: str, module_code: str, tests_code: str) -> Tuple[str, str]:\n",
"    \"\"\"\n",
"    Create an isolated temp workspace, write module + tests, run pytest,\n",
"    and return (human_summary, full_cli_output).\n",
"    \"\"\"\n",
"    if not module_name or not module_code.strip() or not tests_code.strip():\n",
"        return \"❌ Provide module name, module code, and tests.\", \"\"\n",
"\n",
"    run_id = uuid.uuid4().hex[:8]\n",
"    base = Path(\".pytest_runs\") / f\"run_{run_id}\"\n",
"    tests_dir = base / \"tests\"\n",
"    tests_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    # Write module and tests\n",
"    (base / f\"{module_name}.py\").write_text(module_code, encoding=\"utf-8\")\n",
"    (tests_dir / f\"test_{module_name}.py\").write_text(tests_code, encoding=\"utf-8\")\n",
"\n",
"    # Run pytest with this temp dir on PYTHONPATH\n",
"    env = os.environ.copy()\n",
"    env[\"PYTHONPATH\"] = str(base) + os.pathsep + env.get(\"PYTHONPATH\", \"\")\n",
"    cmd = [sys.executable, \"-m\", \"pytest\", \"-q\"]  # quiet output, but still includes summary\n",
"    proc = subprocess.run(cmd, cwd=base, env=env, text=True, capture_output=True)\n",
"\n",
"    full_out = (proc.stdout or \"\") + (\"\\n\" + proc.stderr if proc.stderr else \"\")\n",
"    summary_line, counts = _parse_pytest_summary(full_out)\n",
"\n",
"    badges = []\n",
"    for key in (\"passed\", \"failed\", \"errors\", \"skipped\", \"xpassed\", \"xfail\"):\n",
"        val = counts.get(key, 0)\n",
"        if val:\n",
"            badges.append(f\"**{key}: {val}**\")\n",
"    badges = \" • \".join(badges) if badges else \"no tests collected?\"\n",
"\n",
"    human = f\"{summary_line}\\n\\n{badges}\"\n",
"    return human, full_out\n"
]
},
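A condensed, standalone version of the summary parser above, exercised on a fabricated pytest tail; the dict-based key mapping replaces the elif chain but computes the same counts.

```python
import re
from typing import Dict, Tuple

def _parse_pytest_summary(output: str) -> Tuple[str, Dict[str, int]]:
    """Trimmed copy of the parser above, for a standalone check."""
    summary_line = ""
    for line in output.strip().splitlines()[::-1]:  # scan from the end
        if any(k in line for k in (" passed", " failed", " error", " skipped", " deselected")):
            summary_line = line.strip()
            break
    counts = {"passed": 0, "failed": 0, "errors": 0, "skipped": 0, "xfail": 0, "xpassed": 0}
    for num, kind in re.findall(r"(\d+)\s+(passed|failed|errors?|skipped|xfailed|xpassed)", summary_line):
        key = {"xfailed": "xfail"}.get(kind, "errors" if kind.startswith("error") else kind)
        counts[key] += int(num)
    return summary_line or "(no summary line found)", counts

line, counts = _parse_pytest_summary("....F\n3 passed, 1 failed in 0.12s\n")
assert line == "3 passed, 1 failed in 0.12s"
assert counts["passed"] == 3 and counts["failed"] == 1
```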
{
"cell_type": "code",
"execution_count": null,
"id": "5d240ce5",
"metadata": {},
"outputs": [],
"source": [
"LLM = MultiModelChatClient(MODEL_REGISTRY)\n",
"SERVICE = TestGenerator(LLM)\n",
"\n",
"def generate_from_code(model_label: str, module_name: str, code: str, save: bool, out_dir: str) -> Tuple[str, str]:\n",
"    if not model_label or model_label not in MODEL_REGISTRY:\n",
"        return \"\", \"❌ Pick a model (or add API keys for providers in .env).\"\n",
"    if not module_name.strip():\n",
"        return \"\", \"❌ Please provide a module name.\"\n",
"    if not code.strip():\n",
"        return \"\", \"❌ Please paste some Python code.\"\n",
"\n",
"    tests_code = SERVICE.generate_tests(module_name=module_name.strip(), source=code)\n",
"    tests_code = SERVICE.generate_tests(model_label=model_label, module_name=module_name.strip(), source=code)\n",
"    saved = \"\"\n",
"    if save:\n",
"        out = Path(out_dir or \"tests\")\n",
@@ -290,57 +358,107 @@
"    return tests_code, saved\n",
"\n",
"\n",
"def generate_from_file(file_obj, save: bool, out_dir: str) -> tuple[str, str]:\n",
"def generate_from_file(model_label: str, file_obj, save: bool, out_dir: str) -> Tuple[str, str]:\n",
"    if file_obj is None:\n",
"        return \"\", \"❌ Please upload a .py file.\"\n",
"    code = file_obj.decode(\"utf-8\")\n",
"    module_name = build_module_name_from_path(\"uploaded_module.py\")\n",
"    return generate_from_code(module_name, code, save, out_dir)\n"
"    return generate_from_code(model_label, module_name, code, save, out_dir)\n"
]
},
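MODEL_REGISTRY and MultiModelChatClient are defined outside this hunk, so the following is one plausible shape for the label-to-provider dispatch; every name, field, and the stub completion function are assumptions, not the notebook's code.

```python
from dataclasses import dataclass
from typing import Callable, Dict

@dataclass
class ProviderEntry:
    # Hypothetical registry record: which model id to use and how to call it.
    model_id: str
    complete: Callable[[str, str], str]  # (system, user) -> completion text

def _fake_openai(system: str, user: str) -> str:  # stand-in for a real SDK call
    return "import pytest\n\ndef test_placeholder():\n    assert True"

MODEL_REGISTRY: Dict[str, ProviderEntry] = {
    "gpt-4o-mini (OpenAI)": ProviderEntry("gpt-4o-mini", _fake_openai),
    # Gemini and Groq entries would follow the same shape.
}

class MultiModelChatClient:
    """Resolve a UI label to a provider entry and forward the prompt."""
    def __init__(self, registry: Dict[str, ProviderEntry]):
        self._registry = registry

    def complete(self, model_label: str, system: str, user: str) -> str:
        entry = self._registry[model_label]  # KeyError surfaces bad labels early
        return entry.complete(system, user)

llm = MultiModelChatClient(MODEL_REGISTRY)
print(llm.complete("gpt-4o-mini (OpenAI)", "system prompt", "user prompt")[:13])
```

Keying the registry by the human-readable dropdown label keeps the UI and the client on one source of truth, which matches how `generate_from_code` validates `model_label` against MODEL_REGISTRY above.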
{
"cell_type": "code",
"execution_count": null,
"id": "5d240ce5",
"id": "e3e1401a",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks(title=\"Simple PyTest Generator\") as ui:\n",
"    gr.Markdown(\"## 🧪 Simple PyTest Generator (Week 4 • Community Contribution)\\n\"\n",
"                \"Generate **minimal, deterministic** pytest tests from a Python module using a Frontier model.\")\n",
"EXAMPLE_CODE = \"\"\"\\\n",
"def add(a: int, b: int) -> int:\n",
"    return a + b\n",
"\n",
"    with gr.Tab(\"Paste Code\"):\n",
"def divide(a: float, b: float) -> float:\n",
"    if b == 0:\n",
"        raise ZeroDivisionError(\"b must be non-zero\")\n",
"    return a / b\n",
"\n",
"class Counter:\n",
"    def __init__(self, start: int = 0):\n",
"        self.value = start\n",
"\n",
"    def inc(self, by: int = 1):\n",
"        self.value += by\n",
"        return self.value\n",
"\"\"\"\n"
]
},
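For orientation, a hand-written example of the kind of file the generator's SYSTEM rules aim to produce for EXAMPLE_CODE above (assuming it is saved as mymodule.py); actual model output will vary.

```python
import pytest
import mymodule as mod  # EXAMPLE_CODE saved as mymodule.py

def test_add():
    assert mod.add(2, 3) == 5

def test_divide():
    assert mod.divide(6.0, 3.0) == 2.0

def test_divide_by_zero_raises():
    with pytest.raises(ZeroDivisionError):
        mod.divide(1.0, 0.0)

def test_counter_inc():
    c = mod.Counter(start=10)
    assert c.inc() == 11
    assert c.inc(by=4) == 15
```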
{
"cell_type": "code",
"execution_count": null,
"id": "f802450e",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks(title=\"PyTest Generator\") as ui:\n",
"    gr.Markdown(\n",
"        \"## 🧪 PyTest Generator (Week 4 • Community Contribution)\\n\"\n",
"        \"Generate **minimal, deterministic** pytest tests from a Python module using your chosen model/provider.\"\n",
"    )\n",
"\n",
"    with gr.Row(equal_height=True):\n",
"        # LEFT: inputs (module code)\n",
"        with gr.Column(scale=6):\n",
"            with gr.Row():\n",
"                module_name = gr.Textbox(label=\"Module name (used in `import <name> as mod`)\", value=\"mymodule\")\n",
"            code_in = gr.Code(label=\"Python module code\", language=\"python\", lines=22)\n",
"                model_dd = gr.Dropdown(\n",
"                    list(MODEL_REGISTRY.keys()),\n",
"                    value=DEFAULT_MODEL,\n",
"                    label=\"Model (OpenAI, Gemini, Groq)\"\n",
"                )\n",
"                module_name_tb = gr.Textbox(\n",
"                    label=\"Module name (used in `import <name> as mod`)\",\n",
"                    value=\"mymodule\"\n",
"                )\n",
"            code_in = gr.Code(\n",
"                label=\"Python module code\",\n",
"                language=\"python\",\n",
"                lines=24,\n",
"                value=EXAMPLE_CODE\n",
"            )\n",
"            with gr.Row():\n",
"                save_cb = gr.Checkbox(label=\"Save to /tests\", value=True)\n",
"                out_dir = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
"                save_cb = gr.Checkbox(label=\"Also save generated tests to /tests\", value=True)\n",
"                out_dir_tb = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
"            gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n",
"\n",
"        # RIGHT: outputs (generated tests + pytest run)\n",
"        with gr.Column(scale=6):\n",
"            tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=24)\n",
"        with gr.Row():\n",
"            tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n",
"            status = gr.Markdown()\n",
"            run_btn = gr.Button(\"Run PyTest\", variant=\"secondary\")\n",
"            summary_md = gr.Markdown()\n",
"            full_out = gr.Textbox(label=\"Full PyTest output\", lines=12)\n",
"\n",
"    def _on_gen(name, code, save, outdir):\n",
"        tests, msg = generate_from_code(name, code, save, outdir)\n",
"        return tests, (msg or \"✅ Done\")\n",
"    # --- events ---\n",
"\n",
"    gen_btn.click(_on_gen, inputs=[module_name, code_in, save_cb, out_dir], outputs=[tests_out, status])\n",
"    def _on_gen(model_label, name, code, save, outdir):\n",
"        tests, msg = generate_from_code(model_label, name, code, save, outdir)\n",
"        status = msg or \"✅ Done\"\n",
"        return tests, status\n",
"\n",
"    with gr.Tab(\"Upload .py\"):\n",
"        upload = gr.File(file_types=[\".py\"], label=\"Upload a Python module\")\n",
"        with gr.Row():\n",
"            save_cb2 = gr.Checkbox(label=\"Save to /tests\", value=True)\n",
"            out_dir2 = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
"        gen_btn2 = gr.Button(\"Generate tests from file\")\n",
"        tests_out2 = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=20)\n",
"        status2 = gr.Markdown()\n",
"    gen_btn.click(\n",
"        _on_gen,\n",
"        inputs=[model_dd, module_name_tb, code_in, save_cb, out_dir_tb],\n",
"        outputs=[tests_out, summary_md],\n",
"    )\n",
"\n",
"    def _on_gen_file(f, save, outdir):\n",
"        tests, msg = generate_from_file(f.read() if f else None, save, outdir)\n",
"        return tests, (msg or \"✅ Done\")\n",
"    def _on_run(name, code, tests):\n",
"        summary, details = run_pytest_on_snippet(name, code, tests)\n",
"        return summary, details\n",
"\n",
"    gen_btn2.click(_on_gen_file, inputs=[upload, save_cb2, out_dir2], outputs=[tests_out2, status2])\n",
"    run_btn.click(\n",
"        _on_run,\n",
"        inputs=[module_name_tb, code_in, tests_out],\n",
"        outputs=[summary_md, full_out],\n",
"    )\n",
"\n",
"ui.launch(inbrowser=True)\n"
]