{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "56957b7f-e289-4999-8a40-ce1a8378d8cd",
   "metadata": {},
   "source": [
    "# Unit Test Generator\n",
    "\n",
    "The requirement: use a frontier model to generate fast and repeatable unit tests for Python code.\n",
    "\n",
    "How it works:\n",
    "\n",
    "1. Upload a `.py` file in the Gradio UI launched at the end of this notebook.\n",
    "2. The selected model (GPT or Claude) writes unit tests and echoes the original code, so the result runs as one script.\n",
    "3. **Run Python unit test** executes that script and runs every `unittest.TestCase` subclass it defines.\n",
    "\n",
    "> **Warning:** step 3 runs model-generated code with `exec()` inside this kernel, with no sandbox. Read the generated code before running it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ef67ef0-1bda-45bb-abca-f003217602d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports: standard library first, then third-party packages\n",
    "\n",
    "import ast\n",
    "import contextlib\n",
    "import io\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "import unittest\n",
    "\n",
    "import anthropic\n",
    "import google.generativeai\n",
    "import gradio as gr\n",
    "from dotenv import load_dotenv\n",
    "from IPython.display import Markdown, display, update_display\n",
    "from openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c4e2f0-5b3d-4e8a-9c71-0d2f6b8e4a13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment and model configuration\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "# The placeholder fallback only keeps the client constructors below from failing on a\n",
    "# missing key; requests to that provider will still fail. Keep real keys in a .env file\n",
    "# rather than pasting them into this notebook.\n",
    "for _key_name in (\"OPENAI_API_KEY\", \"ANTHROPIC_API_KEY\"):\n",
    "    if not os.getenv(_key_name):\n",
    "        print(f\"Warning: {_key_name} is not set; requests to that provider will fail.\")\n",
    "    os.environ[_key_name] = os.getenv(_key_name, 'your-key-if-not-using-env')\n",
    "\n",
    "openai = OpenAI()\n",
    "claude = anthropic.Anthropic()\n",
    "OPENAI_MODEL = \"gpt-4o\"\n",
    "CLAUDE_MODEL = \"claude-3-7-sonnet-20250219\"\n",
    "\n",
    "# Upper bound on the length of Claude's reply. The reply repeats the uploaded code *and*\n",
    "# adds tests, so a small limit cuts the generated script off part-way through.\n",
    "CLAUDE_MAX_TOKENS = 8192"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d5f3a1-6c4e-4f9b-8d82-1e3a7c9f5b24",
   "metadata": {},
   "source": [
    "## Building the prompt\n",
    "\n",
    "The uploaded file is read, any top-level `if __name__ == \"__main__\":` block is removed, and the remaining source is appended to a fixed set of instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e6a4b2-7d5f-4a0c-9e93-2f4b8d0a6c35",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"You are an assistant that implements unit testing for Python code. \"\n",
    "system_message += \"Respond only with Python code; use comments sparingly and do not provide any explanation other than occasional comments. \"\n",
    "\n",
    "\n",
    "def _is_main_guard(node):\n",
    "    \"\"\"Return True if `node` is an `if __name__ == \"__main__\":` statement (operands in either order).\"\"\"\n",
    "    if not isinstance(node, ast.If):\n",
    "        return False\n",
    "    test = node.test\n",
    "    if not (\n",
    "        isinstance(test, ast.Compare)\n",
    "        and len(test.ops) == 1\n",
    "        and isinstance(test.ops[0], ast.Eq)\n",
    "        and len(test.comparators) == 1\n",
    "    ):\n",
    "        return False\n",
    "    operands = (test.left, test.comparators[0])\n",
    "    has_name = any(isinstance(op, ast.Name) and op.id == \"__name__\" for op in operands)\n",
    "    has_main = any(isinstance(op, ast.Constant) and op.value == \"__main__\" for op in operands)\n",
    "    return has_name and has_main\n",
    "\n",
    "\n",
    "def remove_main_block_from_code(code):\n",
    "    \"\"\"\n",
    "    Remove top-level `if __name__ == \"__main__\":` blocks from code.\n",
    "\n",
    "    Only statements directly at module level are removed; guards nested inside\n",
    "    functions or classes are left alone (removing one could leave an empty,\n",
    "    syntactically invalid body). Blocks are cut out by line range, so the rest\n",
    "    of the source -- comments and formatting included -- comes back unchanged.\n",
    "    If the code cannot be parsed, it is returned as-is.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "    except Exception as e:\n",
    "        print(\"Error removing __main__ block:\", e)\n",
    "        return code  # fallback: return original code if AST fails\n",
    "\n",
    "    # Split on newline characters only: str.splitlines() also breaks on form feeds and\n",
    "    # other characters that ast does not count as line ends, which would misalign lineno.\n",
    "    lines = code.split(\"\\n\")\n",
    "    # Delete bottom-up so each deletion leaves the line numbers of the blocks above it valid.\n",
    "    for node in reversed(tree.body):\n",
    "        if _is_main_guard(node):\n",
    "            del lines[node.lineno - 1:node.end_lineno]\n",
    "    return \"\\n\".join(lines)\n",
    "\n",
    "\n",
    "def _resolve_path(python_file):\n",
    "    \"\"\"Return a filesystem path for a value passed in by Gradio's File component or a caller.\"\"\"\n",
    "    if python_file is None:\n",
    "        raise ValueError(\"No Python file was provided.\")\n",
    "    if isinstance(python_file, (str, os.PathLike)):  # string path or pathlib.Path\n",
    "        # checked before `.name`: for a pathlib.Path, `.name` is only the base name\n",
    "        return python_file\n",
    "    if isinstance(python_file, dict):  # from Gradio; falls back to a \"path\" key if \"name\" is absent\n",
    "        return python_file.get(\"name\") or python_file[\"path\"]\n",
    "    if hasattr(python_file, \"name\"):  # tempfile\n",
    "        return python_file.name\n",
    "    return python_file\n",
    "\n",
    "\n",
    "def user_prompt_for(python_file):\n",
    "    \"\"\"\n",
    "    Build the user prompt: fixed instructions followed by the uploaded file's source.\n",
    "\n",
    "    `python_file` may be a path (str or os.PathLike), a dict from Gradio, or a\n",
    "    tempfile-like object with a `.name` attribute. Raises ValueError if it is None.\n",
    "    \"\"\"\n",
    "    file_path = _resolve_path(python_file)\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        python_code = f.read()\n",
    "\n",
    "    # strip __main__ blocks so the script's entry point is not echoed into the generated tests\n",
    "    python_code = remove_main_block_from_code(python_code)\n",
    "\n",
    "    user_prompt = \"Write unit tests for this Python code. \"\n",
    "    user_prompt += \"Respond only with Python code; do not explain your work other than a few comments. \"\n",
    "    user_prompt += \"The unit testing is done in Jupyterlab, so you should use packages that play nicely with the Jupyter kernel. \\n\\n\"\n",
    "    user_prompt += \"Include the original Python code in your generated output so that I can run all in one fell swoop.\\n\\n\"\n",
    "    user_prompt += python_code\n",
    "\n",
    "    return user_prompt\n",
    "\n",
    "\n",
    "def messages_for(python_file):\n",
    "    \"\"\"Return the system + user message list for the chat-completions API.\"\"\"\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt_for(python_file)}\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4f7b5c3-8e6a-4b1d-8fa4-3a5c9e1b7d46",
   "metadata": {},
   "source": [
    "## Streaming from the models\n",
    "\n",
    "Both models stream their reply. Each partial reply is yielded with Markdown code fences removed, so the UI textbox shows plain Python as it arrives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a8c6d4-9f7b-4c2e-9ab5-4b6d0f2c8e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _strip_code_fences(text):\n",
    "    \"\"\"Remove Markdown code fences so the reply can be executed as plain Python.\"\"\"\n",
    "    return text.replace('```python\\n', '').replace('```', '')\n",
    "\n",
    "\n",
    "def stream_gpt(python_file):\n",
    "    \"\"\"Stream a GPT reply, yielding the accumulated text (fences stripped) after each chunk.\"\"\"\n",
    "    stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python_file), stream=True)\n",
    "    reply = \"\"\n",
    "    for chunk in stream:\n",
    "        if not chunk.choices:  # defensive: a chunk may carry no choices (e.g. a usage-only chunk)\n",
    "            continue\n",
    "        reply += chunk.choices[0].delta.content or \"\"\n",
    "        yield _strip_code_fences(reply)\n",
    "\n",
    "\n",
    "def stream_claude(python_file):\n",
    "    \"\"\"Stream a Claude reply, yielding the accumulated text (fences stripped) after each chunk.\"\"\"\n",
    "    reply = \"\"\n",
    "    with claude.messages.stream(\n",
    "        model=CLAUDE_MODEL,\n",
    "        max_tokens=CLAUDE_MAX_TOKENS,\n",
    "        system=system_message,\n",
    "        messages=[{\"role\": \"user\", \"content\": user_prompt_for(python_file)}],\n",
    "    ) as stream:\n",
    "        for text in stream.text_stream:\n",
    "            reply += text\n",
    "            yield _strip_code_fences(reply)\n",
    "\n",
    "\n",
    "def unit_test(python_file, model):\n",
    "    \"\"\"\n",
    "    Gradio handler: stream generated unit-test code for `python_file` from `model`.\n",
    "\n",
    "    Raises gr.Error (shown in the UI) when no file has been uploaded, and\n",
    "    ValueError for a model name other than \"GPT\" or \"Claude\".\n",
    "    \"\"\"\n",
    "    if python_file is None:\n",
    "        raise gr.Error(\"Please upload a Python file first.\")\n",
    "    if model == \"GPT\":\n",
    "        result = stream_gpt(python_file)\n",
    "    elif model == \"Claude\":\n",
    "        result = stream_claude(python_file)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown model\")\n",
    "    yield from result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6b9d7e5-0a8c-4d3f-8bc6-5c7e1a3d9f68",
   "metadata": {},
   "source": [
    "## Running the generated tests\n",
    "\n",
    "The generated script is executed with `exec()` in a fresh namespace, and every `unittest.TestCase` subclass it defines is run. **This executes model-written code in the kernel process -- review it first.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07cae8f6-1b9d-4e4a-9cd7-6d8f2b4e0a79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _collect_test_cases(namespace):\n",
    "    \"\"\"Return the distinct unittest.TestCase subclasses in `namespace`, in insertion order.\"\"\"\n",
    "    return list(dict.fromkeys(\n",
    "        obj for obj in namespace.values()\n",
    "        if isinstance(obj, type)\n",
    "        and issubclass(obj, unittest.TestCase)\n",
    "        # skip unittest's own classes (e.g. after `from unittest import *`): FunctionTestCase\n",
    "        # defines runTest, so the loader would build a broken test from it\n",
    "        and obj.__module__.split(\".\")[0] != \"unittest\"\n",
    "    ))\n",
    "\n",
    "\n",
    "def execute_python(code):\n",
    "    \"\"\"\n",
    "    Execute generated test code and run every unittest.TestCase subclass it defines.\n",
    "\n",
    "    Returns everything written to stdout/stderr while executing `code`, followed by\n",
    "    the verbose TextTestRunner report, as one string for the UI.\n",
    "\n",
    "    SECURITY: `code` is model output and is run with exec() inside this kernel's\n",
    "    process, with no sandbox. Review it before running.\n",
    "    \"\"\"\n",
    "    buffer = io.StringIO()\n",
    "    try:\n",
    "        with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):\n",
    "            # Fresh namespace per run. It has no \"__name__\" key, so inside the exec'd code\n",
    "            # __name__ resolves to builtins.__name__ (\"builtins\") and any\n",
    "            # `if __name__ == \"__main__\":` block in the generated code is skipped.\n",
    "            ns = {}\n",
    "            try:\n",
    "                exec(code, ns)\n",
    "            except SystemExit as e:\n",
    "                # unittest.main() without exit=False (or sys.exit()) raises SystemExit, which is\n",
    "                # not an Exception subclass and would otherwise escape into Gradio. Report it\n",
    "                # and still run the tests defined before the exit.\n",
    "                print(f\"Generated code raised SystemExit({e.code!r}); running the collected tests.\")\n",
    "\n",
    "            test_cases = _collect_test_cases(ns)\n",
    "            if test_cases:\n",
    "                suite = unittest.TestSuite()\n",
    "                for case in test_cases:\n",
    "                    suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))\n",
    "                runner = unittest.TextTestRunner(stream=buffer, verbosity=2)\n",
    "                runner.run(suite)\n",
    "    except Exception as e:\n",
    "        print(f\"Error during execution: {e}\", file=buffer)\n",
    "\n",
    "    return buffer.getvalue()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18dbf907-2cae-4f5b-8de8-7e9a3c5f1b8a",
   "metadata": {},
   "source": [
    "## Gradio UI\n",
    "\n",
    "Upload a `.py` file, pick a model, click **Generate unit test code**, review the result, then click **Run Python unit test**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670b8b78-0b13-488a-9533-59802b2fe101",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Gradio UI ---\n",
    "with gr.Blocks() as ui:\n",
    "    gr.Markdown(\"## Unit Test Generator\\nUpload a Python file and get structured unit testing.\")\n",
    "    with gr.Row(): # Row 1\n",
    "        orig_code = gr.File(label=\"Upload your Python file\", file_types=[\".py\"])\n",
    "        test_code = gr.Textbox(label=\"Unit test code:\", lines=10)\n",
    "    with gr.Row(): # Row 2\n",
    "        model = gr.Dropdown([\"GPT\", \"Claude\"], label=\"Select model\", value=\"GPT\")\n",
    "    with gr.Row(): # Row 3\n",
    "        generate = gr.Button(\"Generate unit test code\")\n",
    "    with gr.Row(): # Row 4\n",
    "        unit_run = gr.Button(\"Run Python unit test\")\n",
    "    with gr.Row(): # Row 5\n",
    "        test_out = gr.Textbox(label=\"Unit test result:\", lines=10)\n",
    "\n",
    "    generate.click(unit_test, inputs=[orig_code, model], outputs=[test_code])\n",
    "\n",
    "    unit_run.click(execute_python, inputs=[test_code], outputs=[test_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "609bbdc3-1e1c-4538-91dd-7d13134ab381",
   "metadata": {},
   "outputs": [],
   "source": [
    "ui.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}