Merge pull request #842 from rnik12/rnik12-week4
[Bootcamp] - Nikhil - Week 4 Exercise - Multi LLM Testing for Pytest Generattion
This commit is contained in:
@@ -0,0 +1,498 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b8be8252",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!uv pip install pytest"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ba193fd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import ast\n",
|
||||
"import sys\n",
|
||||
"import uuid\n",
|
||||
"import json\n",
|
||||
"import textwrap\n",
|
||||
"import subprocess\n",
|
||||
"from pathlib import Path\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import List, Protocol, Tuple, Dict, Optional\n",
|
||||
"\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from openai import OpenAI\n",
|
||||
"from openai import BadRequestError as _OpenAIBadRequest\n",
|
||||
"import gradio as gr\n",
|
||||
"\n",
|
||||
"load_dotenv(override=True)\n",
|
||||
"\n",
|
||||
"# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---\n",
|
||||
"GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
|
||||
"GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
|
||||
"\n",
|
||||
"# --- API Keys (add these in your .env) ---\n",
|
||||
"openai_api_key = os.getenv(\"OPENAI_API_KEY\") # OpenAI\n",
|
||||
"google_api_key = os.getenv(\"GOOGLE_API_KEY\") # Gemini\n",
|
||||
"groq_api_key = os.getenv(\"GROQ_API_KEY\") # Groq\n",
|
||||
"\n",
|
||||
"# --- Clients ---\n",
|
||||
"openai_client = OpenAI() # OpenAI default (reads OPENAI_API_KEY)\n",
|
||||
"gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None\n",
|
||||
"groq_client = OpenAI(api_key=groq_api_key, base_url=GROQ_BASE) if groq_api_key else None\n",
|
||||
"\n",
|
||||
"# --- Model registry: label -> { client, model } ---\n",
|
||||
"MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
|
||||
"\n",
|
||||
"def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
|
||||
" \"\"\"Add a model to the registry only if its client is configured.\"\"\"\n",
|
||||
" if client is not None:\n",
|
||||
" MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
|
||||
"\n",
|
||||
"# OpenAI\n",
|
||||
"_register(\"OpenAI • GPT-5\", openai_client, \"gpt-5\")\n",
|
||||
"_register(\"OpenAI • GPT-5 Nano\", openai_client, \"gpt-5-nano\")\n",
|
||||
"_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
|
||||
"\n",
|
||||
"# Gemini (Google)\n",
|
||||
"_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
|
||||
"_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
|
||||
"\n",
|
||||
"# Groq\n",
|
||||
"_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
|
||||
"_register(\"Groq • Llama 3.3 70B\", groq_client, \"llama-3.3-70b-versatile\")\n",
|
||||
"_register(\"Groq • GPT-OSS 20B\", groq_client, \"openai/gpt-oss-20b\")\n",
|
||||
"_register(\"Groq • GPT-OSS 120B\", groq_client, \"openai/gpt-oss-120b\")\n",
|
||||
"\n",
|
||||
"DEFAULT_MODEL = next(iter(MODEL_REGISTRY.keys()), None)\n",
|
||||
"\n",
|
||||
"print(f\"Providers configured → OpenAI:{bool(openai_api_key)} Gemini:{bool(google_api_key)} Groq:{bool(groq_api_key)}\")\n",
|
||||
"print(\"Models available →\", \", \".join(MODEL_REGISTRY.keys()) or \"None (add API keys in .env)\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5d6b0f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class CompletionClient(Protocol):\n",
|
||||
" \"\"\"Any LLM client provides a .complete() method using a registry label.\"\"\"\n",
|
||||
" def complete(self, *, model_label: str, system: str, user: str) -> str: ...\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _extract_code_or_text(s: str) -> str:\n",
|
||||
" \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
|
||||
" m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
|
||||
" return m.group(1).strip() if m else s.strip()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MultiModelChatClient:\n",
|
||||
" \"\"\"Routes requests to the right provider/client based on model label.\"\"\"\n",
|
||||
" def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
|
||||
" self._registry = registry\n",
|
||||
"\n",
|
||||
" def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:\n",
|
||||
" params = {\n",
|
||||
" \"model\": model_id,\n",
|
||||
" \"messages\": [\n",
|
||||
" {\"role\": \"system\", \"content\": system},\n",
|
||||
" {\"role\": \"user\", \"content\": user},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
" resp = client.chat.completions.create(**params) # do NOT send temperature for strict providers\n",
|
||||
" text = (resp.choices[0].message.content or \"\").strip()\n",
|
||||
" return _extract_code_or_text(text)\n",
|
||||
"\n",
|
||||
" def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
|
||||
" if model_label not in self._registry:\n",
|
||||
" raise ValueError(f\"Unknown model label: {model_label}\")\n",
|
||||
" info = self._registry[model_label]\n",
|
||||
" client = info[\"client\"]\n",
|
||||
" model = info[\"model\"]\n",
|
||||
" try:\n",
|
||||
" return self._call(client=client, model_id=str(model), system=system, user=user)\n",
|
||||
" except _OpenAIBadRequest as e:\n",
|
||||
" # Providers may reject stray params; we don't send any, but retry anyway.\n",
|
||||
" if \"temperature\" in str(e).lower():\n",
|
||||
" return self._call(client=client, model_id=str(model), system=system, user=user)\n",
|
||||
" raise\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "31558bf0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@dataclass(frozen=True)\n",
|
||||
"class SymbolInfo:\n",
|
||||
" kind: str # \"function\" | \"class\" | \"method\"\n",
|
||||
" name: str\n",
|
||||
" signature: str\n",
|
||||
" lineno: int\n",
|
||||
"\n",
|
||||
"class PublicAPIExtractor:\n",
|
||||
" \"\"\"Extract concise 'public API' summary from a Python module.\"\"\"\n",
|
||||
" def extract(self, source: str) -> List[SymbolInfo]:\n",
|
||||
" tree = ast.parse(source)\n",
|
||||
" out: List[SymbolInfo] = []\n",
|
||||
" for node in tree.body:\n",
|
||||
" if isinstance(node, ast.FunctionDef) and not node.name.startswith(\"_\"):\n",
|
||||
" out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n",
|
||||
" elif isinstance(node, ast.ClassDef) and not node.name.startswith(\"_\"):\n",
|
||||
" out.append(SymbolInfo(\"class\", node.name, node.name, node.lineno))\n",
|
||||
" for sub in node.body:\n",
|
||||
" if isinstance(sub, ast.FunctionDef) and not sub.name.startswith(\"_\"):\n",
|
||||
" out.append(SymbolInfo(\"method\",\n",
|
||||
" f\"{node.name}.{sub.name}\",\n",
|
||||
" self._sig(sub),\n",
|
||||
" sub.lineno))\n",
|
||||
" return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))\n",
|
||||
"\n",
|
||||
" def _sig(self, fn: ast.FunctionDef) -> str:\n",
|
||||
" args = [a.arg for a in fn.args.args]\n",
|
||||
" if fn.args.vararg:\n",
|
||||
" args.append(\"*\" + fn.args.vararg.arg)\n",
|
||||
" args.extend(a.arg + \"=?\" for a in fn.args.kwonlyargs)\n",
|
||||
" if fn.args.kwarg:\n",
|
||||
" args.append(\"**\" + fn.args.kwarg.arg)\n",
|
||||
" ret = \"\"\n",
|
||||
" if fn.returns is not None:\n",
|
||||
" try:\n",
|
||||
" ret = f\" -> {ast.unparse(fn.returns)}\"\n",
|
||||
" except Exception:\n",
|
||||
" pass\n",
|
||||
" return f\"def {fn.name}({', '.join(args)}){ret}:\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aeadedc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class PromptBuilder:\n",
|
||||
" \"\"\"Builds deterministic prompts for pytest generation.\"\"\"\n",
|
||||
" SYSTEM = (\n",
|
||||
" \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n",
|
||||
" \"Rules:\\n\"\n",
|
||||
" \"- Output only Python test code (no prose, no markdown fences).\\n\"\n",
|
||||
" \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n",
|
||||
" \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n",
|
||||
" \"- Import the target module by module name only.\\n\"\n",
|
||||
" \"- Cover every public function and method with at least one tiny test.\\n\"\n",
|
||||
" \"- Prefer straightforward, fast assertions.\\n\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n",
|
||||
" summary = \"\\n\".join(f\"- {s.kind:<6} {s.signature}\" for s in symbols) or \"- (no public symbols)\"\n",
|
||||
" return textwrap.dedent(f\"\"\"\n",
|
||||
" Create pytest tests for module `{module_name}`.\n",
|
||||
"\n",
|
||||
" Public API Summary:\n",
|
||||
" {summary}\n",
|
||||
"\n",
|
||||
" Constraints:\n",
|
||||
" - Import as: `import {module_name} as mod`\n",
|
||||
" - Keep tests tiny, fast, and deterministic.\n",
|
||||
"\n",
|
||||
" Full module source (for reference):\n",
|
||||
" # --- BEGIN SOURCE {module_name}.py ---\n",
|
||||
" {source}\n",
|
||||
" # --- END SOURCE ---\n",
|
||||
" \"\"\").strip()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a45ac5be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
|
||||
" \"\"\"Ensure tests import pytest and the target module as 'mod'.\"\"\"\n",
|
||||
" code = code.strip()\n",
|
||||
" needs_pytest = \"import pytest\" not in code\n",
|
||||
" has_mod = (f\"import {module_name} as mod\" in code) or (f\"from {module_name} import\" in code)\n",
|
||||
" needs_import = not has_mod\n",
|
||||
"\n",
|
||||
" header = []\n",
|
||||
" if needs_pytest:\n",
|
||||
" header.append(\"import pytest\")\n",
|
||||
" if needs_import:\n",
|
||||
" header.append(f\"import {module_name} as mod\")\n",
|
||||
"\n",
|
||||
" return (\"\\n\".join(header) + \"\\n\\n\" + code) if header else code\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def build_module_name_from_path(path: str) -> str:\n",
|
||||
" return Path(path).stem\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "787e58b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TestGenerator:\n",
|
||||
" \"\"\"Extraction → prompt → model → polish.\"\"\"\n",
|
||||
" def __init__(self, llm: CompletionClient):\n",
|
||||
" self._llm = llm\n",
|
||||
" self._extractor = PublicAPIExtractor()\n",
|
||||
" self._prompts = PromptBuilder()\n",
|
||||
"\n",
|
||||
" def generate_tests(self, model_label: str, module_name: str, source: str) -> str:\n",
|
||||
" symbols = self._extractor.extract(source)\n",
|
||||
" user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
|
||||
" raw = self._llm.complete(model_label=model_label, system=self._prompts.SYSTEM, user=user)\n",
|
||||
" return _ensure_header_and_import(raw, module_name)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8402f62f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _parse_pytest_summary(output: str) -> Tuple[str, Dict[str, int]]:\n",
|
||||
" \"\"\"\n",
|
||||
" Parse the final summary line like:\n",
|
||||
" '3 passed, 1 failed, 2 skipped in 0.12s'\n",
|
||||
" Return (summary_line, counts_dict).\n",
|
||||
" \"\"\"\n",
|
||||
" summary_line = \"\"\n",
|
||||
" for line in output.strip().splitlines()[::-1]: # scan from end\n",
|
||||
" if \" passed\" in line or \" failed\" in line or \" error\" in line or \" skipped\" in line or \" deselected\" in line:\n",
|
||||
" summary_line = line.strip()\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" counts = {\"passed\": 0, \"failed\": 0, \"errors\": 0, \"skipped\": 0, \"xfail\": 0, \"xpassed\": 0}\n",
|
||||
" m = re.findall(r\"(\\d+)\\s+(passed|failed|errors?|skipped|xfailed|xpassed)\", summary_line)\n",
|
||||
" for num, kind in m:\n",
|
||||
" if kind.startswith(\"error\"):\n",
|
||||
" counts[\"errors\"] += int(num)\n",
|
||||
" elif kind == \"passed\":\n",
|
||||
" counts[\"passed\"] += int(num)\n",
|
||||
" elif kind == \"failed\":\n",
|
||||
" counts[\"failed\"] += int(num)\n",
|
||||
" elif kind == \"skipped\":\n",
|
||||
" counts[\"skipped\"] += int(num)\n",
|
||||
" elif kind == \"xfailed\":\n",
|
||||
" counts[\"xfail\"] += int(num)\n",
|
||||
" elif kind == \"xpassed\":\n",
|
||||
" counts[\"xpassed\"] += int(num)\n",
|
||||
"\n",
|
||||
" return summary_line or \"(no summary line found)\", counts\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def run_pytest_on_snippet(module_name: str, module_code: str, tests_code: str) -> Tuple[str, str]:\n",
|
||||
" \"\"\"\n",
|
||||
" Create an isolated temp workspace, write module + tests, run pytest,\n",
|
||||
" and return (human_summary, full_cli_output).\n",
|
||||
" \"\"\"\n",
|
||||
" if not module_name or not module_code.strip() or not tests_code.strip():\n",
|
||||
" return \"❌ Provide module name, module code, and tests.\", \"\"\n",
|
||||
"\n",
|
||||
" run_id = uuid.uuid4().hex[:8]\n",
|
||||
" base = Path(\".pytest_runs\") / f\"run_{run_id}\"\n",
|
||||
" tests_dir = base / \"tests\"\n",
|
||||
" tests_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
" # Write module and tests\n",
|
||||
" (base / f\"{module_name}.py\").write_text(module_code, encoding=\"utf-8\")\n",
|
||||
" (tests_dir / f\"test_{module_name}.py\").write_text(tests_code, encoding=\"utf-8\")\n",
|
||||
"\n",
|
||||
" # Run pytest with this temp dir on PYTHONPATH\n",
|
||||
" env = os.environ.copy()\n",
|
||||
" env[\"PYTHONPATH\"] = str(base) + os.pathsep + env.get(\"PYTHONPATH\", \"\")\n",
|
||||
" cmd = [sys.executable, \"-m\", \"pytest\", \"-q\"] # quiet output, but still includes summary\n",
|
||||
" proc = subprocess.run(cmd, cwd=base, env=env, text=True, capture_output=True)\n",
|
||||
"\n",
|
||||
" full_out = (proc.stdout or \"\") + (\"\\n\" + proc.stderr if proc.stderr else \"\")\n",
|
||||
" summary_line, counts = _parse_pytest_summary(full_out)\n",
|
||||
"\n",
|
||||
" badges = []\n",
|
||||
" for key in (\"passed\", \"failed\", \"errors\", \"skipped\", \"xpassed\", \"xfail\"):\n",
|
||||
" val = counts.get(key, 0)\n",
|
||||
" if val:\n",
|
||||
" badges.append(f\"**{key}: {val}**\")\n",
|
||||
" badges = \" • \".join(badges) if badges else \"no tests collected?\"\n",
|
||||
"\n",
|
||||
" human = f\"{summary_line}\\n\\n{badges}\"\n",
|
||||
" return human, full_out\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5d240ce5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"LLM = MultiModelChatClient(MODEL_REGISTRY)\n",
|
||||
"SERVICE = TestGenerator(LLM)\n",
|
||||
"\n",
|
||||
"def generate_from_code(model_label: str, module_name: str, code: str, save: bool, out_dir: str) -> Tuple[str, str]:\n",
|
||||
" if not model_label or model_label not in MODEL_REGISTRY:\n",
|
||||
" return \"\", \"❌ Pick a model (or add API keys for providers in .env).\"\n",
|
||||
" if not module_name.strip():\n",
|
||||
" return \"\", \"❌ Please provide a module name.\"\n",
|
||||
" if not code.strip():\n",
|
||||
" return \"\", \"❌ Please paste some Python code.\"\n",
|
||||
"\n",
|
||||
" tests_code = SERVICE.generate_tests(model_label=model_label, module_name=module_name.strip(), source=code)\n",
|
||||
" saved = \"\"\n",
|
||||
" if save:\n",
|
||||
" out = Path(out_dir or \"tests\")\n",
|
||||
" out.mkdir(parents=True, exist_ok=True)\n",
|
||||
" out_path = out / f\"test_{module_name}.py\"\n",
|
||||
" out_path.write_text(tests_code, encoding=\"utf-8\")\n",
|
||||
" saved = f\"✅ Saved to {out_path}\"\n",
|
||||
" return tests_code, saved\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def generate_from_file(model_label: str, file_obj, save: bool, out_dir: str) -> Tuple[str, str]:\n",
|
||||
" if file_obj is None:\n",
|
||||
" return \"\", \"❌ Please upload a .py file.\"\n",
|
||||
" code = file_obj.decode(\"utf-8\")\n",
|
||||
" module_name = build_module_name_from_path(\"uploaded_module.py\")\n",
|
||||
" return generate_from_code(model_label, module_name, code, save, out_dir)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e3e1401a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"EXAMPLE_CODE = \"\"\"\\\n",
|
||||
"def add(a: int, b: int) -> int:\n",
|
||||
" return a + b\n",
|
||||
"\n",
|
||||
"def divide(a: float, b: float) -> float:\n",
|
||||
" if b == 0:\n",
|
||||
" raise ZeroDivisionError(\"b must be non-zero\")\n",
|
||||
" return a / b\n",
|
||||
"\n",
|
||||
"class Counter:\n",
|
||||
" def __init__(self, start: int = 0):\n",
|
||||
" self.value = start\n",
|
||||
"\n",
|
||||
" def inc(self, by: int = 1):\n",
|
||||
" self.value += by\n",
|
||||
" return self.value\n",
|
||||
"\"\"\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f802450e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gr.Blocks(title=\"PyTest Generator\") as ui:\n",
|
||||
" gr.Markdown(\n",
|
||||
" \"## 🧪 PyTest Generator (Week 4 • Community Contribution)\\n\"\n",
|
||||
" \"Generate **minimal, deterministic** pytest tests from a Python module using your chosen model/provider.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" with gr.Row(equal_height=True):\n",
|
||||
" # LEFT: inputs (module code)\n",
|
||||
" with gr.Column(scale=6):\n",
|
||||
" with gr.Row():\n",
|
||||
" model_dd = gr.Dropdown(\n",
|
||||
" list(MODEL_REGISTRY.keys()),\n",
|
||||
" value=DEFAULT_MODEL,\n",
|
||||
" label=\"Model (OpenAI, Gemini, Groq)\"\n",
|
||||
" )\n",
|
||||
" module_name_tb = gr.Textbox(\n",
|
||||
" label=\"Module name (used in `import <name> as mod`)\",\n",
|
||||
" value=\"mymodule\"\n",
|
||||
" )\n",
|
||||
" code_in = gr.Code(\n",
|
||||
" label=\"Python module code\",\n",
|
||||
" language=\"python\",\n",
|
||||
" lines=24,\n",
|
||||
" value=EXAMPLE_CODE\n",
|
||||
" )\n",
|
||||
" with gr.Row():\n",
|
||||
" save_cb = gr.Checkbox(label=\"Also save generated tests to /tests\", value=True)\n",
|
||||
" out_dir_tb = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
|
||||
" gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n",
|
||||
"\n",
|
||||
" # RIGHT: outputs (generated tests + pytest run)\n",
|
||||
" with gr.Column(scale=6):\n",
|
||||
" tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=24)\n",
|
||||
" with gr.Row():\n",
|
||||
" run_btn = gr.Button(\"Run PyTest\", variant=\"secondary\")\n",
|
||||
" summary_md = gr.Markdown()\n",
|
||||
" full_out = gr.Textbox(label=\"Full PyTest output\", lines=12)\n",
|
||||
"\n",
|
||||
" # --- events ---\n",
|
||||
"\n",
|
||||
" def _on_gen(model_label, name, code, save, outdir):\n",
|
||||
" tests, msg = generate_from_code(model_label, name, code, save, outdir)\n",
|
||||
" status = msg or \"✅ Done\"\n",
|
||||
" return tests, status\n",
|
||||
"\n",
|
||||
" gen_btn.click(\n",
|
||||
" _on_gen,\n",
|
||||
" inputs=[model_dd, module_name_tb, code_in, save_cb, out_dir_tb],\n",
|
||||
" outputs=[tests_out, summary_md],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def _on_run(name, code, tests):\n",
|
||||
" summary, details = run_pytest_on_snippet(name, code, tests)\n",
|
||||
" return summary, details\n",
|
||||
"\n",
|
||||
" run_btn.click(\n",
|
||||
" _on_run,\n",
|
||||
" inputs=[module_name_tb, code_in, tests_out],\n",
|
||||
" outputs=[summary_md, full_out],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"ui.launch(inbrowser=True)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "llm-engineering",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user