{ "cells": [ { "cell_type": "markdown", "id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9", "metadata": {}, "source": [ "# Unit Test Generator\n", "\n", "Generate unit tests for Python code using GPT, Claude, Gemini, or a local Llama model." ] }, { "cell_type": "code", "execution_count": null, "id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3", "metadata": {}, "outputs": [], "source": [ "# imports\n", "\n", "import os\n", "from dotenv import load_dotenv\n", "from openai import OpenAI\n", "from google import genai\n", "from google.genai import types\n", "import anthropic\n", "import ollama\n", "import gradio as gr" ] }, { "cell_type": "code", "execution_count": null, "id": "4f672e1c-87e9-4865-b760-370fa605e614", "metadata": {}, "outputs": [], "source": [ "# environment\n", "\n", "load_dotenv(override=True)\n", "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", "os.environ['GOOGLE_API_KEY'] = os.getenv('GOOGLE_API_KEY', 'your-key-if-not-using-env')" ] }, { "cell_type": "code", "execution_count": null, "id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da", "metadata": {}, "outputs": [], "source": [ "# initialize\n", "\n", "openai = OpenAI()\n", "claude = anthropic.Anthropic()\n", "client = genai.Client()\n", "\n", "OPENAI_MODEL = \"gpt-4o\"\n", "CLAUDE_MODEL = \"claude-sonnet-4-20250514\"\n", "GEMINI_MODEL = \"gemini-2.5-flash\"\n", "LLAMA_MODEL = \"llama3.2\"" ] }, { "cell_type": "code", "execution_count": null, "id": "6896636f-923e-4a2c-9d6c-fac07828a201", "metadata": {}, "outputs": [], "source": [ "system_message = \"\"\"\n", "You are an effective programming assistant specialized in generating Python code based on the input.\n", "Respond only with Python code; use comments sparingly and do not provide any explanation other than occasional comments.\n", "Do not include Markdown formatting (```), language tags (python), or extra text.\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb", "metadata": {}, "outputs": [], "source": [ "def user_prompt_for_unit_test(python):\n", "    user_prompt = f\"\"\"\n", "    Consider the following Python code:\n", "\n", "    {python}\n", "\n", "    Generate unit tests for this code and return them alongside the original Python code.\n", "    Response rule: in your response do not include Markdown formatting (```), language tags (python), or extra text.\n", "    \"\"\"\n", "    return user_prompt" ] },
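{ "cell_type": "markdown", "id": "b0a1c2d3-e4f5-4a6b-8c7d-9e0f1a2b3c4d", "metadata": {}, "source": [ "Optional sanity check: inspect the exact prompt before spending any API calls. The `add` function here is just a hypothetical stand-in for real input code." ] }, { "cell_type": "code", "execution_count": null, "id": "a1b2c3d4-f5a6-4b7c-9d8e-0f1a2b3c4d5e", "metadata": {}, "outputs": [], "source": [ "# Preview the prompt that will be sent to the models\n", "sample_code = \"def add(a, b):\\n    return a + b\"\n", "print(user_prompt_for_unit_test(sample_code))" ] },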
{ "cell_type": "code", "execution_count": null, "id": "c6190659-f54c-4951-bef4-4960f8e51cc4", "metadata": {}, "outputs": [], "source": [ "def messages_for_unit_test(python):\n", "    return [\n", "        {\"role\": \"system\", \"content\": system_message},\n", "        {\"role\": \"user\", \"content\": user_prompt_for_unit_test(python)}\n", "    ]" ] }, { "cell_type": "code", "execution_count": null, "id": "c3b497b3-f569-420e-b92e-fb0f49957ce0", "metadata": {}, "outputs": [], "source": [ "python_hard = \"\"\"\n", "\n", "def lcg(seed, a=1664525, c=1013904223, m=2**32):\n", "    value = seed\n", "    while True:\n", "        value = (a * value + c) % m\n", "        yield value\n", "\n", "def max_subarray_sum(n, seed, min_val, max_val):\n", "    lcg_gen = lcg(seed)\n", "    random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n", "    max_sum = float('-inf')\n", "    for i in range(n):\n", "        current_sum = 0\n", "        for j in range(i, n):\n", "            current_sum += random_numbers[j]\n", "            if current_sum > max_sum:\n", "                max_sum = current_sum\n", "    return max_sum\n", "\n", "def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n", "    total_sum = 0\n", "    lcg_gen = lcg(initial_seed)\n", "    for _ in range(20):\n", "        seed = next(lcg_gen)\n", "        total_sum += max_subarray_sum(n, seed, min_val, max_val)\n", "    return total_sum\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "0be9f47d-5213-4700-b0e2-d444c7c738c0", "metadata": {}, "outputs": [], "source": [ "def stream_gpt(python):\n", "    stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for_unit_test(python), stream=True)\n", "    reply = \"\"\n", "    for chunk in stream:\n", "        fragment = chunk.choices[0].delta.content or \"\"\n", "        reply += fragment\n", "        yield reply" ] }, { "cell_type": "code", "execution_count": null, "id": "8669f56b-8314-4582-a167-78842caea131", "metadata": {}, "outputs": [], "source": [ "def stream_claude(python):\n", "    result = claude.messages.stream(\n", "        model=CLAUDE_MODEL,\n", "        max_tokens=2000,\n", "        system=system_message,\n", "        messages=[{\"role\": \"user\", \"content\": user_prompt_for_unit_test(python)}],\n", "    )\n", "    reply = \"\"\n", "    with result as stream:\n", "        for text in stream.text_stream:\n", "            reply += text\n", "            yield reply" ] }, { "cell_type": "code", "execution_count": null, "id": "97205162", "metadata": {}, "outputs": [], "source": [ "def stream_gemini(python):\n", "    response = client.models.generate_content_stream(\n", "        model=GEMINI_MODEL,\n", "        config=types.GenerateContentConfig(\n", "            system_instruction=system_message),\n", "        contents=user_prompt_for_unit_test(python)\n", "    )\n", "\n", "    reply = \"\"\n", "    for chunk in response:\n", "        fragment = chunk.text or \"\"\n", "        reply += fragment\n", "        yield reply" ] }, { "cell_type": "code", "execution_count": null, "id": "4f94b13e", "metadata": {}, "outputs": [], "source": [ "def stream_llama_local(python):\n", "    stream = ollama.chat(\n", "        model=LLAMA_MODEL,\n", "        messages=messages_for_unit_test(python),\n", "        stream=True,\n", "    )\n", "\n", "    reply = \"\"\n", "    # Accumulate the streamed chunks into a growing reply\n", "    for chunk in stream:\n", "        if 'content' in chunk['message']:\n", "            fragment = chunk['message']['content']\n", "            reply += fragment\n", "        yield reply" ] },
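{ "cell_type": "markdown", "id": "b2c3d4e5-a6b7-4c8d-9e0f-1a2b3c4d5e6f", "metadata": {}, "source": [ "Each streamer yields the cumulative reply, so the last value yielded is the complete response. As an optional sketch (not part of the core flow), the helper below writes a generated test suite to a temporary file and runs it in a fresh interpreter. It assumes the reply is plain Python source, as the system prompt requests, and that it calls `unittest.main()` when executed as a script; models usually include this, but it is not guaranteed." ] }, { "cell_type": "code", "execution_count": null, "id": "c3d4e5f6-b7c8-4d9e-8f0a-2b3c4d5e6f7a", "metadata": {}, "outputs": [], "source": [ "# Optional helper: persist a generated test suite and execute it\n", "\n", "import subprocess\n", "import sys\n", "import tempfile\n", "\n", "def run_generated_tests(test_code):\n", "    # Write the generated tests to a temporary file\n", "    with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:\n", "        f.write(test_code)\n", "        path = f.name\n", "    # Run in a separate interpreter so a failing or exiting test\n", "    # cannot take the notebook kernel down with it\n", "    result = subprocess.run([sys.executable, path], capture_output=True, text=True)\n", "    print(result.stdout or result.stderr)\n", "\n", "# Example (makes one API call); the last cumulative chunk is the full reply:\n", "# run_generated_tests(list(stream_gpt(python_hard))[-1])" ] },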
"code", "execution_count": null, "id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d", "metadata": {}, "outputs": [], "source": [ "def generate_unit_test(python, model):\n", " if model==\"GPT\":\n", " result = stream_gpt(python)\n", " elif model==\"Claude\":\n", " result = stream_claude(python)\n", " elif model==\"Gemini\":\n", " result = stream_gemini(python)\n", " elif model==\"Llama\":\n", " result = stream_llama_local(python)\n", "\n", " else:\n", " raise ValueError(\"Unknown model\")\n", " for stream_so_far in result:\n", " yield stream_so_far" ] }, { "cell_type": "code", "execution_count": null, "id": "f1ddb38e-6b0a-4c37-baa4-ace0b7de887a", "metadata": {}, "outputs": [], "source": [ "with gr.Blocks() as ui:\n", " with gr.Row():\n", " python = gr.Textbox(label=\"Python code:\", lines=10, value=python_hard)\n", " unit_test = gr.Textbox(label=\"Unit test\", lines=10)\n", " with gr.Row():\n", " model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\", \"Llama\"], label=\"Select model\", value=\"GPT\")\n", " generate_ut = gr.Button(\"Generate Unit tests\")\n", "\n", " generate_ut.click(generate_unit_test, inputs=[python, model], outputs=[unit_test])\n", "\n", "ui.launch(inbrowser=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }