Add multiple LLM provider support, add pytest run.
@@ -10,20 +10,63 @@
 "import os\n",
 "import re\n",
 "import ast\n",
 "import sys\n",
 "import uuid\n",
 "import json\n",
 "import textwrap\n",
 "import subprocess\n",
 "from pathlib import Path\n",
 "from dataclasses import dataclass\n",
-"from typing import List, Protocol, Optional\n",
-"from openai import BadRequestError as _OpenAIBadRequest\n",
+"from typing import List, Protocol, Tuple, Dict, Optional\n",
 "\n",
 "from dotenv import load_dotenv\n",
 "from openai import OpenAI\n",
+"from openai import BadRequestError as _OpenAIBadRequest\n",
 "import gradio as gr\n",
 "\n",
 "\n",
 "load_dotenv(override=True)\n",
 "\n",
-"TESTGEN_MODEL = os.getenv(\"TESTGEN_MODEL\", \"gpt-5-nano\")\n",
-"OPENAI_CLIENT = OpenAI()"
+"# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---\n",
+"GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
+"GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
+"\n",
+"# --- API Keys (add these in your .env) ---\n",
+"openai_api_key = os.getenv(\"OPENAI_API_KEY\")  # OpenAI\n",
+"google_api_key = os.getenv(\"GOOGLE_API_KEY\")  # Gemini\n",
+"groq_api_key = os.getenv(\"GROQ_API_KEY\")      # Groq\n",
+"\n",
+"# --- Clients ---\n",
+"openai_client = OpenAI()  # OpenAI default (reads OPENAI_API_KEY)\n",
+"gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None\n",
+"groq_client = OpenAI(api_key=groq_api_key, base_url=GROQ_BASE) if groq_api_key else None\n",
+"\n",
+"# --- Model registry: label -> { client, model } ---\n",
+"MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
+"\n",
+"def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
+"    \"\"\"Add a model to the registry only if its client is configured.\"\"\"\n",
+"    if client is not None:\n",
+"        MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
+"\n",
+"# OpenAI\n",
+"_register(\"OpenAI • GPT-5\", openai_client, \"gpt-5\")\n",
+"_register(\"OpenAI • GPT-5 Nano\", openai_client, \"gpt-5-nano\")\n",
+"_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
+"\n",
+"# Gemini (Google)\n",
+"_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
+"_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
+"\n",
+"# Groq\n",
+"_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
+"_register(\"Groq • Llama 3.1 70B\", groq_client, \"llama-3.1-70b-versatile\")\n",
+"_register(\"Groq • GPT-OSS 20B\", groq_client, \"gpt-oss-20b\")\n",
+"_register(\"Groq • GPT-OSS 120B\", groq_client, \"gpt-oss-120b\")\n",
+"\n",
+"DEFAULT_MODEL = next(iter(MODEL_REGISTRY.keys()), None)\n",
+"\n",
+"print(f\"Providers configured → OpenAI:{bool(openai_api_key)} Gemini:{bool(google_api_key)} Groq:{bool(groq_api_key)}\")\n",
+"print(\"Models available →\", \", \".join(MODEL_REGISTRY.keys()) or \"None (add API keys in .env)\")\n"
 ]
 },
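For orientation, this is how the new registry resolves a dropdown label to a concrete provider call. The label below is one of those registered above; the snippet assumes at least one matching API key is present in .env (a sketch, not part of the commit):

# Sketch: resolving a registry label to a provider client and model id.
info = MODEL_REGISTRY["OpenAI • GPT-4o-mini"]   # {"client": <OpenAI>, "model": "gpt-4o-mini"}
resp = info["client"].chat.completions.create(
    model=info["model"],
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)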
{
@@ -33,39 +76,46 @@
 "metadata": {},
 "outputs": [],
 "source": [
 "\n",
 "class CompletionClient(Protocol):\n",
-"    \"\"\"Any LLM client provides a .complete() method.\"\"\"\n",
-"    def complete(self, *, model: str, system: str, user: str) -> str: ...\n",
+"    \"\"\"Any LLM client provides a .complete() method using a registry label.\"\"\"\n",
+"    def complete(self, *, model_label: str, system: str, user: str) -> str: ...\n",
 "\n",
 "\n",
-"class OpenAIChatClient:\n",
-"    \"\"\"Adapter over OpenAI chat completions that omits unsupported params.\"\"\"\n",
-"    def __init__(self, client: OpenAI):\n",
-"        self._client = client\n",
-"\n",
-"    def _call(self, *, model: str, system: str, user: str, include_temperature: bool = False) -> str:\n",
+"def _extract_code_or_text(s: str) -> str:\n",
+"    \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
+"    m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
+"    return m.group(1).strip() if m else s.strip()\n",
+"\n",
+"\n",
+"class MultiModelChatClient:\n",
+"    \"\"\"Routes requests to the right provider/client based on model label.\"\"\"\n",
+"    def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
+"        self._registry = registry\n",
+"\n",
+"    def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:\n",
 "        params = {\n",
-"            \"model\": model,\n",
+"            \"model\": model_id,\n",
 "            \"messages\": [\n",
 "                {\"role\": \"system\", \"content\": system},\n",
 "                {\"role\": \"user\", \"content\": user},\n",
 "            ],\n",
 "        }\n",
-"        if include_temperature:\n",
-"            params[\"temperature\"] = 0.7\n",
-"\n",
-"        resp = self._client.chat.completions.create(**params)\n",
+"        resp = client.chat.completions.create(**params)  # do NOT send temperature for strict providers\n",
 "        text = (resp.choices[0].message.content or \"\").strip()\n",
 "        return _extract_code_or_text(text)\n",
 "\n",
-"    def complete(self, *, model: str, system: str, user: str) -> str:\n",
+"    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
+"        if model_label not in self._registry:\n",
+"            raise ValueError(f\"Unknown model label: {model_label}\")\n",
+"        info = self._registry[model_label]\n",
+"        client = info[\"client\"]\n",
+"        model = info[\"model\"]\n",
 "        try:\n",
-"            return self._call(model=model, system=system, user=user, include_temperature=False)\n",
+"            return self._call(client=client, model_id=str(model), system=system, user=user)\n",
 "        except _OpenAIBadRequest as e:\n",
-"            # Extra safety: if some lib auto-injected temperature, retry without it\n",
+"            # Providers may reject stray params; we don't send any, but retry anyway.\n",
 "            if \"temperature\" in str(e).lower():\n",
-"                return self._call(model=model, system=system, user=user, include_temperature=False)\n",
+"                return self._call(client=client, model_id=str(model), system=system, user=user)\n",
 "            raise\n"
 ]
 },
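Taken together, this cell replaces the single-provider OpenAIChatClient with a client that routes by registry label. A minimal usage sketch, assuming at least one provider key is configured (not part of the commit):

# Sketch: the fence-stripping helper plus the router, end to end.
assert _extract_code_or_text("```python\nx = 1\n```") == "x = 1"

llm = MultiModelChatClient(MODEL_REGISTRY)
answer = llm.complete(
    model_label=DEFAULT_MODEL,   # first registered label, e.g. "OpenAI • GPT-5"
    system="You are a terse assistant.",
    user="Say hello.",
)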
@@ -197,16 +247,15 @@
 "source": [
 "class TestGenerator:\n",
 "    \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
-"    def __init__(self, llm: CompletionClient, model: str):\n",
+"    def __init__(self, llm: CompletionClient):\n",
 "        self._llm = llm\n",
-"        self._model = model\n",
 "        self._extractor = PublicAPIExtractor()\n",
 "        self._prompts = PromptBuilder()\n",
 "\n",
-"    def generate_tests(self, module_name: str, source: str) -> str:\n",
+"    def generate_tests(self, model_label: str, module_name: str, source: str) -> str:\n",
 "        symbols = self._extractor.extract(source)\n",
 "        user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
-"        raw = self._llm.complete(model=self._model, system=self._prompts.SYSTEM, user=user)\n",
+"        raw = self._llm.complete(model_label=model_label, system=self._prompts.SYSTEM, user=user)\n",
 "        return _ensure_header_and_import(raw, module_name)\n"
 ]
 },
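With the model now chosen per call rather than fixed at construction, wiring the pieces together looks like this. The module name and label are placeholders, and PublicAPIExtractor, PromptBuilder, and _ensure_header_and_import come from notebook cells not shown in this diff (a sketch, not part of the commit):

# Sketch: generating a pytest file for a module with an explicit model label.
gen = TestGenerator(MultiModelChatClient(MODEL_REGISTRY))
test_code = gen.generate_tests(
    model_label="OpenAI • GPT-4o-mini",  # any key of MODEL_REGISTRY
    module_name="calculator",            # hypothetical module under test
    source=Path("calculator.py").read_text(),
)
print(test_code)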
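The commit message also mentions adding a pytest run, but that cell falls outside this excerpt. Judging only from the imports added above (subprocess, sys, uuid, Path), a runner would plausibly take a shape like the following; every name here is hypothetical, not taken from the commit:

# Hypothetical sketch of a pytest-runner helper; the real cell is not shown in this diff.
def run_pytest(test_source: str) -> str:
    tmp = Path(f"test_{uuid.uuid4().hex}.py")  # unique temp file so reruns don't collide
    tmp.write_text(test_source)
    try:
        proc = subprocess.run(
            [sys.executable, "-m", "pytest", str(tmp), "-q"],
            capture_output=True, text=True, timeout=120,
        )
        return proc.stdout + proc.stderr
    finally:
        tmp.unlink(missing_ok=True)  # clean up the generated test file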