From b4327b6e84d16bb5de9c6a1b2adce49412d449e0 Mon Sep 17 00:00:00 2001
From: Nik
Date: Sun, 26 Oct 2025 20:15:47 +0530
Subject: [PATCH] Add multiple LLM provider support, add pytest run.

---
 .../pytest_generator/pytest_generator.ipynb | 103 +++++++++++++-----
 1 file changed, 76 insertions(+), 27 deletions(-)

diff --git a/week4/community-contributions/pytest_generator/pytest_generator.ipynb b/week4/community-contributions/pytest_generator/pytest_generator.ipynb
index 84d568b..8d63b20 100644
--- a/week4/community-contributions/pytest_generator/pytest_generator.ipynb
+++ b/week4/community-contributions/pytest_generator/pytest_generator.ipynb
@@ -10,20 +10,63 @@
     "import os\n",
     "import re\n",
     "import ast\n",
+    "import sys\n",
+    "import uuid\n",
+    "import json\n",
     "import textwrap\n",
+    "import subprocess\n",
     "from pathlib import Path\n",
     "from dataclasses import dataclass\n",
-    "from typing import List, Protocol, Optional\n",
-    "from openai import BadRequestError as _OpenAIBadRequest\n",
+    "from typing import List, Protocol, Tuple, Dict, Optional\n",
+    "\n",
     "from dotenv import load_dotenv\n",
     "from openai import OpenAI\n",
+    "from openai import BadRequestError as _OpenAIBadRequest\n",
     "import gradio as gr\n",
     "\n",
-    "\n",
     "load_dotenv(override=True)\n",
     "\n",
-    "TESTGEN_MODEL = os.getenv(\"TESTGEN_MODEL\", \"gpt-5-nano\")\n",
-    "OPENAI_CLIENT = OpenAI()"
+    "# --- Provider base URLs (Gemini & Groq speak the OpenAI-compatible API) ---\n",
+    "GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
+    "GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
+    "\n",
+    "# --- API keys (add these in your .env) ---\n",
+    "openai_api_key = os.getenv(\"OPENAI_API_KEY\")  # OpenAI\n",
+    "google_api_key = os.getenv(\"GOOGLE_API_KEY\")  # Gemini\n",
+    "groq_api_key = os.getenv(\"GROQ_API_KEY\")      # Groq\n",
+    "\n",
+    "# --- Clients ---\n",
+    "openai_client = OpenAI()  # default client; reads OPENAI_API_KEY\n",
+    "gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None\n",
+    "groq_client = OpenAI(api_key=groq_api_key, base_url=GROQ_BASE) if groq_api_key else None\n",
+    "\n",
+    "# --- Model registry: label -> { client, model } ---\n",
+    "MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
+    "\n",
+    "def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
+    "    \"\"\"Add a model to the registry only if its client is configured.\"\"\"\n",
+    "    if client is not None:\n",
+    "        MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
+    "\n",
+    "# OpenAI\n",
+    "_register(\"OpenAI • GPT-5\", openai_client, \"gpt-5\")\n",
+    "_register(\"OpenAI • GPT-5 Nano\", openai_client, \"gpt-5-nano\")\n",
+    "_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
+    "\n",
+    "# Gemini (Google)\n",
+    "_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
+    "_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
+    "\n",
+    "# Groq\n",
+    "_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
+    "_register(\"Groq • Llama 3.3 70B\", groq_client, \"llama-3.3-70b-versatile\")\n",
+    "_register(\"Groq • GPT-OSS 20B\", groq_client, \"openai/gpt-oss-20b\")\n",
+    "_register(\"Groq • GPT-OSS 120B\", groq_client, \"openai/gpt-oss-120b\")\n",
+    "\n",
+    "DEFAULT_MODEL = next(iter(MODEL_REGISTRY.keys()), None)\n",
+    "\n",
+    "print(f\"Providers configured → OpenAI:{bool(openai_api_key)} Gemini:{bool(google_api_key)} Groq:{bool(groq_api_key)}\")\n",
+    "print(\"Models available →\", \", \".join(MODEL_REGISTRY.keys()) or \"None (add API keys in .env)\")\n"
    ]
   },
   {
@@ -33,39 +76,46 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
     "class CompletionClient(Protocol):\n",
-    "    \"\"\"Any LLM client provides a .complete() method.\"\"\"\n",
-    "    def complete(self, *, model: str, system: str, user: str) -> str: ...\n",
+    "    \"\"\"Any LLM client provides a .complete() method using a registry label.\"\"\"\n",
+    "    def complete(self, *, model_label: str, system: str, user: str) -> str: ...\n",
     "\n",
     "\n",
-    "class OpenAIChatClient:\n",
-    "    \"\"\"Adapter over OpenAI chat completions that omits unsupported params.\"\"\"\n",
-    "    def __init__(self, client: OpenAI):\n",
-    "        self._client = client\n",
+    "def _extract_code_or_text(s: str) -> str:\n",
+    "    \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
+    "    m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
+    "    return m.group(1).strip() if m else s.strip()\n",
     "\n",
-    "    def _call(self, *, model: str, system: str, user: str, include_temperature: bool = False) -> str:\n",
+    "\n",
+    "class MultiModelChatClient:\n",
+    "    \"\"\"Routes requests to the right provider/client based on the model label.\"\"\"\n",
+    "    def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
+    "        self._registry = registry\n",
+    "\n",
+    "    def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:\n",
     "        params = {\n",
-    "            \"model\": model,\n",
+    "            \"model\": model_id,\n",
     "            \"messages\": [\n",
     "                {\"role\": \"system\", \"content\": system},\n",
-    "                {\"role\": \"user\", \"content\": user}, \n",
+    "                {\"role\": \"user\", \"content\": user},\n",
     "            ],\n",
     "        }\n",
-    "        if include_temperature: \n",
-    "            params[\"temperature\"] = 0.7 \n",
-    "\n",
-    "        resp = self._client.chat.completions.create(**params)\n",
+    "        resp = client.chat.completions.create(**params)  # do NOT send temperature; strict providers reject it\n",
     "        text = (resp.choices[0].message.content or \"\").strip()\n",
     "        return _extract_code_or_text(text)\n",
     "\n",
-    "    def complete(self, *, model: str, system: str, user: str) -> str:\n",
+    "    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
+    "        if model_label not in self._registry:\n",
+    "            raise ValueError(f\"Unknown model label: {model_label}\")\n",
+    "        info = self._registry[model_label]\n",
+    "        client = info[\"client\"]\n",
+    "        model = info[\"model\"]\n",
     "        try:\n",
-    "            return self._call(model=model, system=system, user=user, include_temperature=False)\n",
+    "            return self._call(client=client, model_id=str(model), system=system, user=user)\n",
     "        except _OpenAIBadRequest as e:\n",
-    "            # Extra safety: if some lib auto-injected temperature, retry without it\n",
+    "            # Defensive: retry once in case a proxy or wrapper injected an unsupported param\n",
     "            if \"temperature\" in str(e).lower():\n",
-    "                return self._call(model=model, system=system, user=user, include_temperature=False)\n",
+    "                return self._call(client=client, model_id=str(model), system=system, user=user)\n",
     "            raise\n"
    ]
   },
@@ -197,16 +247,15 @@
    "source": [
     "class TestGenerator:\n",
     "    \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
-    "    def __init__(self, llm: CompletionClient, model: str):\n",
+    "    def __init__(self, llm: CompletionClient):\n",
     "        self._llm = llm\n",
-    "        self._model = model\n",
     "        self._extractor = PublicAPIExtractor()\n",
     "        self._prompts = PromptBuilder()\n",
     "\n",
-    "    def generate_tests(self, module_name: str, source: str) -> str:\n",
+    "    def generate_tests(self, model_label: str, module_name: str, source: str) -> str:\n",
     "        symbols = self._extractor.extract(source)\n",
     "        user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
-    "        raw = self._llm.complete(model=self._model, system=self._prompts.SYSTEM, user=user)\n",
+    "        raw = self._llm.complete(model_label=model_label, system=self._prompts.SYSTEM, user=user)\n",
     "        return _ensure_header_and_import(raw, module_name)\n"
    ]
   },
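
A minimal sketch of how the new registry and router compose, assuming the
notebook cells above have run and at least one API key is configured (so
DEFAULT_MODEL is not None); the prompts here are illustrative:

    llm = MultiModelChatClient(MODEL_REGISTRY)
    reply = llm.complete(
        model_label=DEFAULT_MODEL,  # any label printed by the setup cell
        system="You are a concise assistant.",
        user="Reply with the single word: ok",
    )
    print(reply)  # _extract_code_or_text has already stripped any code fences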
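
End to end, the generator now takes the model label per call instead of binding
one model at construction. A sketch under the same assumptions; PublicAPIExtractor,
PromptBuilder, and _ensure_header_and_import are defined in other notebook cells
(the hunk above references them), and sample_source plus the module name
"calculator" are illustrative stand-ins:

    sample_source = "def add(a, b):\n    return a + b\n"
    generator = TestGenerator(MultiModelChatClient(MODEL_REGISTRY))
    test_code = generator.generate_tests(
        model_label=DEFAULT_MODEL,  # swap per request, e.g. a Groq or Gemini label
        module_name="calculator",   # hypothetical module under test
        source=sample_source,
    )
    print(test_code)  # pytest source, with header and import added by _ensure_header_and_import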