Add multiple LLM provider support, add pytest run.

Author: Nik
Date: 2025-10-26 20:15:47 +05:30
Parent: 74e48bba2c
Commit: b4327b6e84


@@ -10,20 +10,63 @@
"import os\n",
"import re\n",
"import ast\n",
"import sys\n",
"import uuid\n",
"import json\n",
"import textwrap\n",
"import subprocess\n",
"from pathlib import Path\n",
"from dataclasses import dataclass\n",
"from typing import List, Protocol, Optional\n",
"from openai import BadRequestError as _OpenAIBadRequest\n",
"from typing import List, Protocol, Tuple, Dict, Optional\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"from openai import BadRequestError as _OpenAIBadRequest\n",
"import gradio as gr\n",
"\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"TESTGEN_MODEL = os.getenv(\"TESTGEN_MODEL\", \"gpt-5-nano\")\n",
"OPENAI_CLIENT = OpenAI()"
"# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---\n",
"GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
"GROQ_BASE = \"https://api.groq.com/openai/v1\"\n",
"\n",
"# --- API Keys (add these in your .env) ---\n",
"openai_api_key = os.getenv(\"OPENAI_API_KEY\") # OpenAI\n",
"google_api_key = os.getenv(\"GOOGLE_API_KEY\") # Gemini\n",
"groq_api_key = os.getenv(\"GROQ_API_KEY\") # Groq\n",
"\n",
"# --- Clients ---\n",
"openai_client = OpenAI() # OpenAI default (reads OPENAI_API_KEY)\n",
"gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None\n",
"groq_client = OpenAI(api_key=groq_api_key, base_url=GROQ_BASE) if groq_api_key else None\n",
"\n",
"# --- Model registry: label -> { client, model } ---\n",
"MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
"\n",
"def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
" \"\"\"Add a model to the registry only if its client is configured.\"\"\"\n",
" if client is not None:\n",
" MODEL_REGISTRY[label] = {\"client\": client, \"model\": model_id}\n",
"\n",
"# OpenAI\n",
"_register(\"OpenAI • GPT-5\", openai_client, \"gpt-5\")\n",
"_register(\"OpenAI • GPT-5 Nano\", openai_client, \"gpt-5-nano\")\n",
"_register(\"OpenAI • GPT-4o-mini\", openai_client, \"gpt-4o-mini\")\n",
"\n",
"# Gemini (Google)\n",
"_register(\"Gemini • 2.5 Pro\", gemini_client, \"gemini-2.5-pro\")\n",
"_register(\"Gemini • 2.5 Flash\", gemini_client, \"gemini-2.5-flash\")\n",
"\n",
"# Groq\n",
"_register(\"Groq • Llama 3.1 8B\", groq_client, \"llama-3.1-8b-instant\")\n",
"_register(\"Groq • Llama 3.1 70B\", groq_client, \"llama-3.1-70b-versatile\")\n",
"_register(\"Groq • GPT-OSS 20B\", groq_client, \"gpt-oss-20b\")\n",
"_register(\"Groq • GPT-OSS 120B\", groq_client, \"gpt-oss-120b\")\n",
"\n",
"DEFAULT_MODEL = next(iter(MODEL_REGISTRY.keys()), None)\n",
"\n",
"print(f\"Providers configured → OpenAI:{bool(openai_api_key)} Gemini:{bool(google_api_key)} Groq:{bool(groq_api_key)}\")\n",
"print(\"Models available →\", \", \".join(MODEL_REGISTRY.keys()) or \"None (add API keys in .env)\")\n"
]
},
{
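Note: Gemini and Groq both expose OpenAI-compatible endpoints, which is why one OpenAI SDK client per provider (pointed at the right base_url) covers all three. A minimal sketch of how the registry above resolves a label into an actual call, assuming at least one API key is configured in .env (the prompt string is illustrative):

    # Resolve a registry label to its (client, model id) pair and call it.
    entry = MODEL_REGISTRY[DEFAULT_MODEL]   # e.g. {"client": openai_client, "model": "gpt-4o-mini"}
    resp = entry["client"].chat.completions.create(
        model=entry["model"],
        messages=[{"role": "user", "content": "ping"}],  # illustrative prompt
    )
    print(resp.choices[0].message.content)

Because _register silently skips providers whose client is None, the registry only ever contains models the user can actually call, and DEFAULT_MODEL falls back to None when no keys are set.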
@@ -33,39 +76,46 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"class CompletionClient(Protocol):\n",
" \"\"\"Any LLM client provides a .complete() method.\"\"\"\n",
" def complete(self, *, model: str, system: str, user: str) -> str: ...\n",
" \"\"\"Any LLM client provides a .complete() method using a registry label.\"\"\"\n",
" def complete(self, *, model_label: str, system: str, user: str) -> str: ...\n",
"\n",
"\n",
"class OpenAIChatClient:\n",
" \"\"\"Adapter over OpenAI chat completions that omits unsupported params.\"\"\"\n",
" def __init__(self, client: OpenAI):\n",
" self._client = client\n",
"def _extract_code_or_text(s: str) -> str:\n",
" \"\"\"Prefer fenced python if present; otherwise return raw text.\"\"\"\n",
" m = re.search(r\"```(?:python)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
" return m.group(1).strip() if m else s.strip()\n",
"\n",
" def _call(self, *, model: str, system: str, user: str, include_temperature: bool = False) -> str:\n",
"\n",
"class MultiModelChatClient:\n",
" \"\"\"Routes requests to the right provider/client based on model label.\"\"\"\n",
" def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
" self._registry = registry\n",
"\n",
" def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:\n",
" params = {\n",
" \"model\": model,\n",
" \"model\": model_id,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": system},\n",
" {\"role\": \"user\", \"content\": user},\n",
" {\"role\": \"user\", \"content\": user},\n",
" ],\n",
" }\n",
" if include_temperature: \n",
" params[\"temperature\"] = 0.7 \n",
"\n",
" resp = self._client.chat.completions.create(**params)\n",
" resp = client.chat.completions.create(**params) # do NOT send temperature for strict providers\n",
" text = (resp.choices[0].message.content or \"\").strip()\n",
" return _extract_code_or_text(text)\n",
"\n",
" def complete(self, *, model: str, system: str, user: str) -> str:\n",
" def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
" if model_label not in self._registry:\n",
" raise ValueError(f\"Unknown model label: {model_label}\")\n",
" info = self._registry[model_label]\n",
" client = info[\"client\"]\n",
" model = info[\"model\"]\n",
" try:\n",
" return self._call(model=model, system=system, user=user, include_temperature=False)\n",
" return self._call(client=client, model_id=str(model), system=system, user=user)\n",
" except _OpenAIBadRequest as e:\n",
" # Extra safety: if some lib auto-injected temperature, retry without it\n",
" # Providers may reject stray params; we don't send any, but retry anyway.\n",
" if \"temperature\" in str(e).lower():\n",
" return self._call(model=model, system=system, user=user, include_temperature=False)\n",
" return self._call(client=client, model_id=str(model), system=system, user=user)\n",
" raise\n"
]
},
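Note: MultiModelChatClient keeps a single code path for every provider. complete() validates the label and looks up the (client, model) pair; _call() sends a plain chat request with no sampling parameters, since some OpenAI-compatible backends reject a temperature value that a given model does not support, and the BadRequestError branch is a belt-and-braces retry for that case. A minimal usage sketch against the registry defined earlier (the system/user strings are illustrative):

    llm = MultiModelChatClient(MODEL_REGISTRY)
    generated = llm.complete(
        model_label=DEFAULT_MODEL,                  # any key of MODEL_REGISTRY
        system="You write pytest test modules.",    # illustrative prompts
        user="Write tests for add(a, b) -> a + b.",
    )
    print(generated)  # fenced code, if any, was stripped by _extract_code_or_text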
@@ -197,16 +247,15 @@
"source": [
"class TestGenerator:\n",
" \"\"\"Orchestrates extraction, prompt, model call, and final polish.\"\"\"\n",
" def __init__(self, llm: CompletionClient, model: str):\n",
" def __init__(self, llm: CompletionClient):\n",
" self._llm = llm\n",
" self._model = model\n",
" self._extractor = PublicAPIExtractor()\n",
" self._prompts = PromptBuilder()\n",
"\n",
" def generate_tests(self, module_name: str, source: str) -> str:\n",
" def generate_tests(self, model_label: str, module_name: str, source: str) -> str:\n",
" symbols = self._extractor.extract(source)\n",
" user = self._prompts.build_user(module_name=module_name, source=source, symbols=symbols)\n",
" raw = self._llm.complete(model=self._model, system=self._prompts.SYSTEM, user=user)\n",
" raw = self._llm.complete(model_label=model_label, system=self._prompts.SYSTEM, user=user)\n",
" return _ensure_header_and_import(raw, module_name)\n"
]
},
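Note: the "add pytest run" half of the commit title is not visible in these hunks, but the notebook imports subprocess, sys, uuid, and Path, which points at writing the generated module to disk and shelling out to pytest. A plausible sketch under that assumption (run_generated_tests and the file layout are hypothetical, not taken from the diff):

    import subprocess
    import sys
    import uuid
    from pathlib import Path

    def run_generated_tests(test_source: str, workdir: Path = Path(".")) -> str:
        """Write generated tests to a unique file and return pytest's combined output."""
        test_file = workdir / f"test_generated_{uuid.uuid4().hex[:8]}.py"
        test_file.write_text(test_source)
        try:
            proc = subprocess.run(
                [sys.executable, "-m", "pytest", str(test_file), "-q"],
                capture_output=True, text=True, timeout=120,
            )
            return (proc.stdout or "") + (proc.stderr or "")
        finally:
            test_file.unlink(missing_ok=True)  # clean up the temporary test module

End to end, the pieces shown here compose as: generator = TestGenerator(MultiModelChatClient(MODEL_REGISTRY)), then run_generated_tests(generator.generate_tests(DEFAULT_MODEL, "mymodule", source)), where "mymodule" and source stand in for the user's actual module name and code.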