Merge branch 'main' of github.com:ed-donner/llm_engineering

Edward Donner
2025-11-04 21:21:19 -05:00
4 changed files with 2098 additions and 0 deletions


@@ -0,0 +1,725 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import re\n",
"from typing import List, Dict, Optional\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"from IPython.display import Markdown, display\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Groq API Key not set (optional)\n",
"OpenRouter API Key loaded (begins with sk-or-)\n"
]
}
],
"source": [
"load_dotenv(override=True)\n",
"\n",
"# Ollama connection \n",
"ollama_url = \"http://localhost:11434/v1\"\n",
"ollama_client = OpenAI(api_key=\"ollama\", base_url=ollama_url)\n",
"\n",
"# Groq connection\n",
"groq_api_key = os.getenv('GROQ_API_KEY')\n",
"groq_url = \"https://api.groq.com/openai/v1\"\n",
"groq_client = None\n",
"if groq_api_key:\n",
" groq_client = OpenAI(api_key=groq_api_key, base_url=groq_url)\n",
" print(f\"Groq API Key loaded (begins with {groq_api_key[:4]})\")\n",
"else:\n",
" print(\"Groq API Key not set (optional)\")\n",
"\n",
"# OpenRouter connection\n",
"openrouter_api_key = os.getenv('OPENROUTER_API_KEY')\n",
"openrouter_url = \"https://openrouter.ai/api/v1\"\n",
"openrouter_client = None\n",
"if openrouter_api_key:\n",
" openrouter_client = OpenAI(api_key=openrouter_api_key, base_url=openrouter_url)\n",
" print(f\"OpenRouter API Key loaded (begins with {openrouter_api_key[:6]})\")\n",
"else:\n",
" print(\"OpenRouter API Key not set (optional)\")\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Configured 2 models\n",
"OpenRouter models available (perfect for limited storage demos!)\n"
]
}
],
"source": [
"# Open-source code models configuration\n",
"MODELS = {}\n",
"\n",
"if groq_client:\n",
" MODELS.update({\n",
" \"gpt-oss-20b-groq\": {\n",
" \"name\": \"GPT-OSS-20B (Groq)\",\n",
" \"client\": groq_client,\n",
" \"model\": \"gpt-oss:20b\",\n",
" \"description\": \"Cloud\"\n",
" },\n",
" \"gpt-oss-120b-groq\": {\n",
" \"name\": \"GPT-OSS-120B (Groq)\",\n",
" \"client\": groq_client,\n",
" \"model\": \"openai/gpt-oss-120b\",\n",
" \"description\": \"Cloud - Larger GPT-OSS\"\n",
" },\n",
" \"qwen2.5-coder-32b-groq\": {\n",
" \"name\": \"Qwen2.5-Coder 32B (Groq)\",\n",
" \"client\": groq_client,\n",
" \"model\": \"qwen/qwen2.5-coder-32b-instruct\",\n",
" \"description\": \"Cloud\"\n",
" },\n",
" })\n",
"\n",
"# OpenRouter models\n",
"if openrouter_client:\n",
" MODELS.update({\n",
" \"qwen-2.5-coder-32b-openrouter\": {\n",
" \"name\": \"Qwen2.5-Coder 32B (OpenRouter)\",\n",
" \"client\": openrouter_client,\n",
" \"model\": \"qwen/qwen-2.5-coder-32b-instruct\",\n",
" \"description\": \"Cloud - Perfect for demos, 50 req/day free\"\n",
" },\n",
" \"gpt-oss-20b-groq\": {\n",
" \"name\": \"GPT-OSS-20B\",\n",
" \"client\": openrouter_client,\n",
" \"model\": \"openai/gpt-oss-20b\",\n",
" \"description\": \"Cloud - OpenAI's open model, excellent for code!\"\n",
" },\n",
" })\n",
"\n",
"print(f\"Configured {len(MODELS)} models\")\n",
"if openrouter_client:\n",
" print(\"OpenRouter models available (perfect for limited storage demos!)\")\n",
"if groq_client:\n",
" print(\"Groq models available (fast cloud inference!)\")\n",
"if \"qwen2.5-coder:7b\" in MODELS:\n",
" print(\"Ollama models available (unlimited local usage!)\")\n"
]
},
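{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: register a local Ollama model (a minimal sketch - assumes `ollama pull qwen2.5-coder:7b`\n",
"# has already been run and Ollama is serving on localhost; adjust the tag to whatever is installed)\n",
"OLLAMA_CODE_MODEL = \"qwen2.5-coder:7b\"\n",
"try:\n",
"    available = {m.id for m in ollama_client.models.list().data}\n",
"    if OLLAMA_CODE_MODEL in available:\n",
"        MODELS[OLLAMA_CODE_MODEL] = {\n",
"            \"name\": \"Qwen2.5-Coder 7B (Ollama)\",\n",
"            \"client\": ollama_client,\n",
"            \"model\": OLLAMA_CODE_MODEL,\n",
"            \"description\": \"Local - unlimited usage\"\n",
"        }\n",
"        print(f\"Registered local Ollama model: {OLLAMA_CODE_MODEL}\")\n",
"    else:\n",
"        print(f\"{OLLAMA_CODE_MODEL} not pulled locally (optional)\")\n",
"except Exception as e:\n",
"    print(f\"Ollama not reachable (optional): {e}\")\n"
]
},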
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"BUG_DETECTION_SYSTEM_PROMPT = \"\"\"You are an expert code reviewer specializing in finding bugs, security vulnerabilities, and logic errors.\n",
"\n",
"Your task is to analyze Python code and identify issues. Return ONLY a valid JSON array with this exact format:\n",
"[{\n",
" \"severity\": \"critical|high|medium|low\",\n",
" \"line\": number,\n",
" \"issue\": \"brief description of the problem\",\n",
" \"suggestion\": \"specific fix recommendation\"\n",
"}]\n",
"\n",
"Be thorough but concise. Focus on real bugs and security issues.\"\"\"\n",
"\n",
"IMPROVEMENTS_SYSTEM_PROMPT = \"\"\"You are a senior software engineer specializing in code quality and best practices.\n",
"\n",
"Analyze the Python code and suggest improvements for:\n",
"- Code readability and maintainability\n",
"- Performance optimizations\n",
"- Pythonic idioms and conventions\n",
"- Better error handling\n",
"\n",
"Return ONLY a JSON array:\n",
"[{\n",
" \"category\": \"readability|performance|style|error_handling\",\n",
" \"line\": number,\n",
" \"current\": \"current code snippet\",\n",
" \"improved\": \"improved code snippet\",\n",
" \"explanation\": \"why this is better\"\n",
"}]\n",
"\n",
"Only suggest meaningful improvements.\"\"\"\n",
"\n",
"TEST_GENERATION_SYSTEM_PROMPT = \"\"\"You are an expert in writing comprehensive unit tests.\n",
"\n",
"Generate pytest unit tests for the given Python code. Include:\n",
"- Test cases for normal operation\n",
"- Edge cases and boundary conditions\n",
"- Error handling tests\n",
"- Tests for any bugs that were identified\n",
"\n",
"Return ONLY Python code with pytest tests. Include the original code at the top if needed.\n",
"Put the imports at the top of the file first.\n",
"Do not include explanations or markdown formatting.\"\"\"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def extract_json_from_response(text: str) -> List[Dict]:\n",
" \"\"\"Extract JSON array from model response, handling markdown code blocks.\"\"\"\n",
" # Remove markdown code blocks\n",
" text = re.sub(r'```json\\n?', '', text)\n",
" text = re.sub(r'```\\n?', '', text)\n",
" \n",
" # Try to find JSON array\n",
" json_match = re.search(r'\\[\\s*\\{.*\\}\\s*\\]', text, re.DOTALL)\n",
" if json_match:\n",
" try:\n",
" return json.loads(json_match.group())\n",
" except json.JSONDecodeError:\n",
" pass\n",
" \n",
" # Fallback: try parsing entire response\n",
" try:\n",
" return json.loads(text.strip())\n",
" except json.JSONDecodeError:\n",
" return []\n",
"\n",
"def detect_bugs(code: str, model_key: str) -> Dict:\n",
" \"\"\"Detect bugs and security issues in code.\"\"\"\n",
" model_config = MODELS[model_key]\n",
" client = model_config[\"client\"]\n",
" model_name = model_config[\"model\"]\n",
" \n",
" user_prompt = f\"Analyze this Python code for bugs and security issues:\\n\\n```python\\n{code}\\n```\"\n",
" \n",
" try:\n",
" response = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": BUG_DETECTION_SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ],\n",
" temperature=0.1\n",
" )\n",
" \n",
" content = response.choices[0].message.content\n",
" issues = extract_json_from_response(content)\n",
" \n",
" return {\n",
" \"model\": model_config[\"name\"],\n",
" \"issues\": issues,\n",
" \"raw_response\": content,\n",
" \"success\": True\n",
" }\n",
" except Exception as e:\n",
" return {\n",
" \"model\": model_config[\"name\"],\n",
" \"issues\": [],\n",
" \"error\": str(e),\n",
" \"success\": False\n",
" }\n",
"\n",
"def suggest_improvements(code: str, model_key: str) -> Dict:\n",
" \"\"\"Suggest code improvements and best practices.\"\"\"\n",
" model_config = MODELS[model_key]\n",
" client = model_config[\"client\"]\n",
" model_name = model_config[\"model\"]\n",
" \n",
" user_prompt = f\"Suggest improvements for this Python code:\\n\\n```python\\n{code}\\n```\"\n",
" \n",
" try:\n",
" response = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": IMPROVEMENTS_SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ],\n",
" temperature=0.2\n",
" )\n",
" \n",
" content = response.choices[0].message.content\n",
" improvements = extract_json_from_response(content)\n",
" \n",
" return {\n",
" \"model\": model_config[\"name\"],\n",
" \"improvements\": improvements,\n",
" \"raw_response\": content,\n",
" \"success\": True\n",
" }\n",
" except Exception as e:\n",
" return {\n",
" \"model\": model_config[\"name\"],\n",
" \"improvements\": [],\n",
" \"error\": str(e),\n",
" \"success\": False\n",
" }\n",
"\n",
"def generate_tests(code: str, bugs: List[Dict], model_key: str) -> Dict:\n",
" \"\"\"Generate unit tests for the code.\"\"\"\n",
" model_config = MODELS[model_key]\n",
" client = model_config[\"client\"]\n",
" model_name = model_config[\"model\"]\n",
" \n",
" bugs_context = \"\"\n",
" if bugs:\n",
" bugs_context = f\"\\n\\nNote: The following bugs were identified:\\n\" + \"\\n\".join([f\"- Line {b.get('line', '?')}: {b.get('issue', '')}\" for b in bugs])\n",
" \n",
" user_prompt = f\"Generate pytest unit tests for this Python code:{bugs_context}\\n\\n```python\\n{code}\\n```\"\n",
" \n",
" try:\n",
" response = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": TEST_GENERATION_SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ],\n",
" temperature=0.3\n",
" )\n",
" \n",
" content = response.choices[0].message.content\n",
" # Remove markdown code blocks if present\n",
" test_code = re.sub(r'```python\\n?', '', content)\n",
" test_code = re.sub(r'```\\n?', '', test_code)\n",
" \n",
" return {\n",
" \"model\": model_config[\"name\"],\n",
" \"test_code\": test_code.strip(),\n",
" \"raw_response\": content,\n",
" \"success\": True\n",
" }\n",
" except Exception as e:\n",
" return {\n",
" \"model\": model_config[\"name\"],\n",
" \"test_code\": \"\",\n",
" \"error\": str(e),\n",
" \"success\": False\n",
" }\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def format_bugs_output(result: Dict) -> str:\n",
" \"\"\"Format bug detection results for display.\"\"\"\n",
" if not result.get(\"success\"):\n",
" return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n",
" \n",
" issues = result.get(\"issues\", [])\n",
" if not issues:\n",
" return f\"✅ **{result['model']}**: No issues found. Code looks good!\"\n",
" \n",
" output = [f\"**{result['model']}** - Found {len(issues)} issue(s):\\n\"]\n",
" \n",
" severity_order = {\"critical\": 0, \"high\": 1, \"medium\": 2, \"low\": 3}\n",
" sorted_issues = sorted(issues, key=lambda x: severity_order.get(x.get(\"severity\", \"low\"), 3))\n",
" \n",
" for issue in sorted_issues:\n",
" severity = issue.get(\"severity\", \"unknown\").upper()\n",
" line = issue.get(\"line\", \"?\")\n",
" issue_desc = issue.get(\"issue\", \"\")\n",
" suggestion = issue.get(\"suggestion\", \"\")\n",
" \n",
" severity_emoji = {\n",
" \"CRITICAL\": \"🔴\",\n",
" \"HIGH\": \"🟠\",\n",
" \"MEDIUM\": \"🟡\",\n",
" \"LOW\": \"🔵\"\n",
" }.get(severity, \"⚪\")\n",
" \n",
" output.append(f\"{severity_emoji} **{severity}** (Line {line}): {issue_desc}\")\n",
" if suggestion:\n",
" output.append(f\" 💡 *Fix:* {suggestion}\")\n",
" output.append(\"\")\n",
" \n",
" return \"\\n\".join(output)\n",
"\n",
"def format_improvements_output(result: Dict) -> str:\n",
" \"\"\"Format improvement suggestions for display.\"\"\"\n",
" if not result.get(\"success\"):\n",
" return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n",
" \n",
" improvements = result.get(\"improvements\", [])\n",
" if not improvements:\n",
" return f\"✅ **{result['model']}**: Code follows best practices. No major improvements needed!\"\n",
" \n",
" output = [f\"**{result['model']}** - {len(improvements)} suggestion(s):\\n\"]\n",
" \n",
" for imp in improvements:\n",
" category = imp.get(\"category\", \"general\").replace(\"_\", \" \").title()\n",
" line = imp.get(\"line\", \"?\")\n",
" current = imp.get(\"current\", \"\")\n",
" improved = imp.get(\"improved\", \"\")\n",
" explanation = imp.get(\"explanation\", \"\")\n",
" \n",
" output.append(f\"\\n📝 **{category}** (Line {line}):\")\n",
" if current and improved:\n",
" output.append(f\" Before: `{current[:60]}{'...' if len(current) > 60 else ''}`\")\n",
" output.append(f\" After: `{improved[:60]}{'...' if len(improved) > 60 else ''}`\")\n",
" if explanation:\n",
" output.append(f\" 💡 {explanation}\")\n",
" \n",
" return \"\\n\".join(output)\n",
"\n",
"def format_tests_output(result: Dict) -> str:\n",
" \"\"\"Format test generation results for display.\"\"\"\n",
" if not result.get(\"success\"):\n",
" return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n",
" \n",
" test_code = result.get(\"test_code\", \"\")\n",
" if not test_code:\n",
" return f\"⚠️ **{result['model']}**: No tests generated.\"\n",
" \n",
" return test_code\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def review_code(code: str, model_key: str, include_tests: bool = True) -> tuple:\n",
" \"\"\"Main function to perform complete code review.\"\"\"\n",
" if not code.strip():\n",
" return \"Please provide code to review.\", \"\", \"\"\n",
" \n",
" # Detect bugs\n",
" bugs_result = detect_bugs(code, model_key)\n",
" bugs_output = format_bugs_output(bugs_result)\n",
" bugs_issues = bugs_result.get(\"issues\", [])\n",
" \n",
" # Suggest improvements\n",
" improvements_result = suggest_improvements(code, model_key)\n",
" improvements_output = format_improvements_output(improvements_result)\n",
" \n",
" # Generate tests\n",
" tests_output = \"\"\n",
" if include_tests:\n",
" tests_result = generate_tests(code, bugs_issues, model_key)\n",
" tests_output = format_tests_output(tests_result)\n",
" \n",
" return bugs_output, improvements_output, tests_output\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def compare_models(code: str, model_keys: List[str]) -> str:\n",
" \"\"\"Compare multiple models on the same code.\"\"\"\n",
" if not code.strip():\n",
" return \"Please provide code to review.\"\n",
" \n",
" results = []\n",
" all_issues = []\n",
" \n",
" for model_key in model_keys:\n",
" result = detect_bugs(code, model_key)\n",
" results.append(result)\n",
" if result.get(\"success\"):\n",
" all_issues.extend(result.get(\"issues\", []))\n",
" \n",
" # Build comparison output\n",
" output = [\"# Model Comparison Results\\n\"]\n",
" \n",
" for result in results:\n",
" model_name = result[\"model\"]\n",
" issues = result.get(\"issues\", [])\n",
" success = result.get(\"success\", False)\n",
" \n",
" if success:\n",
" output.append(f\"\\n**{model_name}**: Found {len(issues)} issue(s)\")\n",
" if issues:\n",
" severity_counts = {}\n",
" for issue in issues:\n",
" sev = issue.get(\"severity\", \"low\")\n",
" severity_counts[sev] = severity_counts.get(sev, 0) + 1\n",
" output.append(f\" Breakdown: {dict(severity_counts)}\")\n",
" else:\n",
" output.append(f\"\\n**{model_name}**: Error - {result.get('error', 'Unknown')}\")\n",
" \n",
" # Find consensus issues (found by multiple models)\n",
" if len(results) > 1:\n",
" issue_signatures = {}\n",
" for result in results:\n",
" if result.get(\"success\"):\n",
" for issue in result.get(\"issues\", []):\n",
" # Create signature from line and issue description\n",
" sig = f\"{issue.get('line')}-{issue.get('issue', '')[:50]}\"\n",
" if sig not in issue_signatures:\n",
" issue_signatures[sig] = []\n",
" issue_signatures[sig].append(result[\"model\"])\n",
" \n",
" consensus = [sig for sig, models in issue_signatures.items() if len(models) > 1]\n",
" if consensus:\n",
" output.append(f\"\\n\\n **Consensus Issues**: {len(consensus)} issue(s) identified by multiple models\")\n",
" \n",
" return \"\\n\".join(output)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradio UI\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7884\n",
"* To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7884/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\queueing.py\", line 745, in process_events\n",
" response = await route_utils.call_process_api(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\route_utils.py\", line 354, in call_process_api\n",
" output = await app.get_blocks().process_api(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\blocks.py\", line 2116, in process_api\n",
" result = await self.call_function(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\blocks.py\", line 1623, in call_function\n",
" prediction = await anyio.to_thread.run_sync( # type: ignore\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\to_thread.py\", line 56, in run_sync\n",
" return await get_async_backend().run_sync_in_worker_thread(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 2485, in run_sync_in_worker_thread\n",
" return await future\n",
" ^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 976, in run\n",
" result = context.run(func, *args)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\utils.py\", line 915, in wrapper\n",
" response = f(*args, **kwargs)\n",
" ^^^^^^^^^^^^^^^^^^\n",
" File \"C:\\Users\\Philo Baba\\AppData\\Local\\Temp\\ipykernel_43984\\2272281361.py\", line 7, in review_code\n",
" bugs_result = detect_bugs(code, model_key)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"C:\\Users\\Philo Baba\\AppData\\Local\\Temp\\ipykernel_43984\\638705449.py\", line 23, in detect_bugs\n",
" model_config = MODELS[model_key]\n",
" ~~~~~~^^^^^^^^^^^\n",
"KeyError: 'deepseek-coder-v2-openrouter'\n"
]
}
],
"source": [
"# Example buggy code for testing\n",
"EXAMPLE_CODE = '''def divide_numbers(a, b):\n",
" return a / b\n",
"\n",
"def process_user_data(user_input):\n",
" # Missing input validation\n",
" result = eval(user_input)\n",
" return result\n",
"\n",
"def get_user_by_id(user_id):\n",
" # SQL injection vulnerability\n",
" query = f\"SELECT * FROM users WHERE id = {user_id}\"\n",
" return query\n",
"\n",
"def calculate_average(numbers):\n",
" total = sum(numbers)\n",
" return total / len(numbers) # Potential division by zero\n",
"'''\n",
"\n",
"def create_ui():\n",
" with gr.Blocks(title=\"AI Code Review Assistant\", theme=gr.themes.Soft()) as demo:\n",
" gr.Markdown(\"\"\"\n",
" # 🔍 AI-Powered Code Review Assistant\n",
" \n",
" Review your Python code using open-source AI models. Detect bugs, get improvement suggestions, and generate unit tests.\n",
" \"\"\")\n",
" \n",
" with gr.Row():\n",
" with gr.Column(scale=2):\n",
" code_input = gr.Code(\n",
" label=\"Python Code to Review\",\n",
" value=EXAMPLE_CODE,\n",
" language=\"python\",\n",
" lines=20\n",
" )\n",
" \n",
" with gr.Row():\n",
" model_selector = gr.Dropdown(\n",
" choices=list(MODELS.keys()),\n",
" value=list(MODELS.keys())[0],\n",
" label=\"Select Model\",\n",
" info=\"Choose an open-source code model\"\n",
" )\n",
" \n",
" include_tests = gr.Checkbox(\n",
" label=\"Generate Tests\",\n",
" value=True\n",
" )\n",
" \n",
" with gr.Row():\n",
" review_btn = gr.Button(\"🔍 Review Code\", variant=\"primary\", scale=2)\n",
" compare_btn = gr.Button(\"📊 Compare Models\", variant=\"secondary\", scale=1)\n",
" \n",
" with gr.Column(scale=3):\n",
" with gr.Tabs() as tabs:\n",
" with gr.Tab(\"🐛 Bug Detection\"):\n",
" bugs_output = gr.Markdown(value=\"Select a model and click 'Review Code' to analyze your code.\")\n",
" \n",
" with gr.Tab(\"✨ Improvements\"):\n",
" improvements_output = gr.Markdown(value=\"Get suggestions for code improvements and best practices.\")\n",
" \n",
" with gr.Tab(\"🧪 Unit Tests\"):\n",
" tests_output = gr.Code(\n",
" label=\"Generated Test Code\",\n",
" language=\"python\",\n",
" lines=25\n",
" )\n",
" \n",
" with gr.Tab(\"📊 Comparison\"):\n",
" comparison_output = gr.Markdown(value=\"Compare multiple models side-by-side.\")\n",
" \n",
" # Event handlers\n",
" review_btn.click(\n",
" fn=review_code,\n",
" inputs=[code_input, model_selector, include_tests],\n",
" outputs=[bugs_output, improvements_output, tests_output]\n",
" )\n",
" \n",
" def compare_selected_models(code):\n",
" # Compare first 3 models by default\n",
" model_keys = list(MODELS.keys())[:3]\n",
" return compare_models(code, model_keys)\n",
" \n",
" compare_btn.click(\n",
" fn=compare_selected_models,\n",
" inputs=[code_input],\n",
" outputs=[comparison_output]\n",
" )\n",
" \n",
" gr.Examples(\n",
" examples=[\n",
" [EXAMPLE_CODE],\n",
" [\"\"\"def fibonacci(n):\n",
" if n <= 1:\n",
" return n\n",
" return fibonacci(n-1) + fibonacci(n-2)\n",
"\"\"\"],\n",
" [\"\"\"def parse_config(file_path):\n",
" with open(file_path) as f:\n",
" return eval(f.read())\n",
"\"\"\"]\n",
" ],\n",
" inputs=[code_input]\n",
" )\n",
" \n",
" return demo\n",
"\n",
"demo = create_ui()\n",
"demo.launch(inbrowser=True, share=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**Qwen2.5-Coder 32B (OpenRouter)** - Found 2 issue(s):\n",
"\n",
"🔴 **CRITICAL** (Line 2): No division by zero protection\n",
" 💡 *Fix:* Add a check for b == 0 and raise ValueError or handle ZeroDivisionError\n",
"\n",
"🟡 **MEDIUM** (Line 2): No input validation for numeric types\n",
" 💡 *Fix:* Add type checking to ensure a and b are numbers (int/float)\n",
"\n"
]
}
],
"source": [
"# Test with a simple example\n",
"test_code = \"\"\"def divide(a, b):\n",
" return a / b\n",
"\"\"\"\n",
"\n",
"# Test bug detection\n",
"result = detect_bugs(test_code, list(MODELS.keys())[0])\n",
"print(format_bugs_output(result))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1,454 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "275415f0",
"metadata": {
"id": "275415f0"
},
"outputs": [],
"source": [
"# pip installs\n",
"\n",
"!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "535bd9de",
"metadata": {
"id": "535bd9de"
},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import torch\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, mean_absolute_percentage_error\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from datetime import datetime\n",
"from peft import PeftModel\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc58234a",
"metadata": {
"id": "fc58234a"
},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"PROJECT_NAME = \"pricer\"\n",
"HF_USER = \"ed-donner\"\n",
"RUN_NAME = \"2024-09-13_13.04.39\"\n",
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
"REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
"FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"\n",
"\n",
"DATASET_NAME = f\"{HF_USER}/home-data\"\n",
"\n",
"QUANT_4_BIT = True\n",
"top_K = 6\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Used for writing to output in color\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0145ad8a",
"metadata": {
"id": "0145ad8a"
},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6919506e",
"metadata": {
"id": "6919506e"
},
"outputs": [],
"source": [
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']\n",
"len(train), len(test)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea79cde1",
"metadata": {
"id": "ea79cde1"
},
"outputs": [],
"source": [
"if QUANT_4_BIT:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" bnb_4bit_quant_type=\"nf4\"\n",
" )\n",
"else:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_8bit=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef108f8d",
"metadata": {
"id": "ef108f8d"
},
"outputs": [],
"source": [
"# Load the Tokenizer and the Model\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=quant_config,\n",
" device_map=\"auto\",\n",
")\n",
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"# Load the fine-tuned model with PEFT\n",
"if REVISION:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
"else:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
"\n",
"fine_tuned_model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3c4176",
"metadata": {
"id": "7f3c4176"
},
"outputs": [],
"source": [
"def extract_price(s):\n",
" if \"Price is $\" in s:\n",
" contents = s.split(\"Price is $\")[1]\n",
" contents = contents.replace(',','')\n",
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
" return float(match.group()) if match else 0\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "436fa29a",
"metadata": {
"id": "436fa29a"
},
"outputs": [],
"source": [
"# Original prediction function takes the most likely next token\n",
"\n",
"def model_predict(prompt):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
" attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n",
" outputs = fine_tuned_model.generate(inputs, attention_mask=attention_mask, max_new_tokens=3, num_return_sequences=1)\n",
" response = tokenizer.decode(outputs[0])\n",
" return extract_price(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a666dab6",
"metadata": {
"id": "a666dab6"
},
"outputs": [],
"source": [
"def improved_model_predict(prompt, device=\"cuda\", top_k=3):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
" attention_mask = torch.ones_like(inputs)\n",
"\n",
" with torch.no_grad():\n",
" outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
" next_token_logits = outputs.logits[:, -1, :].to(\"cpu\")\n",
"\n",
" probs = F.softmax(next_token_logits, dim=-1)\n",
" top_prob, top_token_id = probs.topk(top_k)\n",
"\n",
" prices, weights = [], []\n",
" for weight, token_id in zip(top_prob[0], top_token_id[0]):\n",
" token_text = tokenizer.decode(token_id)\n",
" try:\n",
" value = float(token_text)\n",
" except ValueError:\n",
" continue\n",
" prices.append(value)\n",
" weights.append(weight.item())\n",
"\n",
" if not prices:\n",
" return 0.0, 0.0, 0.0 # price, confidence, spread\n",
"\n",
" total = sum(weights)\n",
" normalized = [w / total for w in weights]\n",
" weighted_price = sum(p * w for p, w in zip(prices, normalized))\n",
" variance = sum(w * (p - weighted_price) ** 2 for p, w in zip(prices, normalized))\n",
" confidence = 1 - min(variance / (weighted_price + 1e-6), 1)\n",
" spread = max(prices) - min(prices)\n",
" return weighted_price, confidence, spread"
]
},
{
"cell_type": "code",
"source": [
"!pip install -q gradio>=4.0\n",
"import gradio as gr\n",
"\n",
"def format_prediction(description):\n",
" price, confidence, spread = improved_model_predict(description)\n",
" return (\n",
" f\"${price:,.2f}\",\n",
" f\"{confidence*100:.1f}%\",\n",
" f\"${spread:,.2f}\",\n",
" )\n",
"\n",
"demo = gr.Interface(\n",
" fn=format_prediction,\n",
" inputs=gr.Textbox(lines=10, label=\"Product Description\"),\n",
" outputs=[\n",
" gr.Textbox(label=\"Estimated Price\"),\n",
" gr.Textbox(label=\"Confidence\"),\n",
" gr.Textbox(label=\"Token Spread\"),\n",
" ],\n",
" title=\"Open-Source Product Price Estimator\",\n",
" description=\"Paste a cleaned product blurb. The model returns a price estimate, a confidence score based on top-k token dispersion, and the spread between top predictions.\",\n",
")\n",
"\n",
"demo.launch(share=True)"
],
"metadata": {
"id": "Km4tHaeQyMoW"
},
"id": "Km4tHaeQyMoW",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "9664c4c7",
"metadata": {
"id": "9664c4c7"
},
"outputs": [],
"source": [
"\n",
"class Tester:\n",
"\n",
" def __init__(self, predictor, data, title=None, show_progress=True):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = len(data)\n",
" self.guesses, self.truths, self.errors, self.rel_errors, self.sles, self.colors = [], [], [], [], [], []\n",
" self.show_progress = show_progress\n",
"\n",
" def color_for(self, error, truth):\n",
" if error < 40 or error / truth < 0.2:\n",
" return \"green\"\n",
" elif error < 80 or error / truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
"\n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
" guess = self.predictor(datapoint[\"text\"])\n",
" truth = datapoint[\"price\"]\n",
"\n",
" error = guess - truth\n",
" abs_error = abs(error)\n",
" rel_error = abs_error / truth if truth != 0 else 0\n",
" log_error = math.log(truth + 1) - math.log(guess + 1)\n",
" sle = log_error ** 2\n",
" color = self.color_for(abs_error, truth)\n",
"\n",
" title = (datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\") if \"\\n\\n\" in datapoint[\"text\"] else datapoint[\"text\"][:20]\n",
" self.guesses.append(guess)\n",
" self.truths.append(truth)\n",
" self.errors.append(error)\n",
" self.rel_errors.append(rel_error)\n",
" self.sles.append(sle)\n",
" self.colors.append(color)\n",
"\n",
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} \"\n",
" f\"Error: ${abs_error:,.2f} RelErr: {rel_error*100:.1f}% SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
" def chart_all(self, chart_title):\n",
" \"\"\"Compact version: 4 performance charts in one grid.\"\"\"\n",
" t, g = np.array(self.truths), np.array(self.guesses)\n",
" rel_err, abs_err = np.array(self.rel_errors) * 100, np.abs(np.array(self.errors))\n",
"\n",
" fig, axs = plt.subplots(2, 2, figsize=(14, 10))\n",
" fig.suptitle(f\"Performance Dashboard — {chart_title}\", fontsize=16, fontweight=\"bold\")\n",
"\n",
" # Scatter plot\n",
" max_val = max(t.max(), g.max()) * 1.05\n",
" axs[1, 1].plot([0, max_val], [0, max_val], \"b--\", alpha=0.6)\n",
" axs[1, 1].scatter(t, g, s=20, c=self.colors, alpha=0.6)\n",
" axs[1, 1].set_title(\"Predictions vs Ground Truth\")\n",
" axs[1, 1].set_xlabel(\"True Price ($)\")\n",
" axs[1, 1].set_ylabel(\"Predicted ($)\")\n",
"\n",
" # Accuracy by price range\n",
" bins = np.linspace(t.min(), t.max(), 6)\n",
" labels = [f\"${bins[i]:.0f}${bins[i+1]:.0f}\" for i in range(len(bins)-1)]\n",
" inds = np.digitize(t, bins) - 1\n",
" avg_err = [rel_err[inds == i].mean() for i in range(len(labels))]\n",
" axs[0, 0].bar(labels, avg_err, color=\"seagreen\", alpha=0.8)\n",
" axs[0, 0].set_title(\"Avg Relative Error by Price Range\")\n",
" axs[0, 0].set_ylabel(\"Relative Error (%)\")\n",
" axs[0, 0].tick_params(axis=\"x\", rotation=30)\n",
"\n",
" # Relative error distribution\n",
" axs[0, 1].hist(rel_err, bins=25, color=\"mediumpurple\", edgecolor=\"black\", alpha=0.7)\n",
" axs[0, 1].set_title(\"Relative Error Distribution (%)\")\n",
" axs[0, 1].set_xlabel(\"Relative Error (%)\")\n",
"\n",
" # Absolute error distribution\n",
" axs[1, 0].hist(abs_err, bins=25, color=\"steelblue\", edgecolor=\"black\", alpha=0.7)\n",
" axs[1, 0].axvline(abs_err.mean(), color=\"red\", linestyle=\"--\", label=f\"Mean={abs_err.mean():.2f}\")\n",
" axs[1, 0].set_title(\"Absolute Error Distribution\")\n",
" axs[1, 0].set_xlabel(\"Absolute Error ($)\")\n",
" axs[1, 0].legend()\n",
"\n",
" for ax in axs.ravel():\n",
" ax.grid(alpha=0.3)\n",
"\n",
" plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
" plt.show()\n",
"\n",
" def report(self):\n",
" y_true = np.array(self.truths)\n",
" y_pred = np.array(self.guesses)\n",
"\n",
" mae = mean_absolute_error(y_true, y_pred)\n",
" rmse = math.sqrt(mean_squared_error(y_true, y_pred))\n",
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
" mape = mean_absolute_percentage_error(y_true, y_pred) * 100\n",
" median_error = float(np.median(np.abs(y_true - y_pred)))\n",
" r2 = r2_score(y_true, y_pred)\n",
"\n",
" hit_rate_green = sum(1 for c in self.colors if c == \"green\") / self.size * 100\n",
" hit_rate_acceptable = sum(1 for c in self.colors if c in (\"green\", \"orange\")) / self.size * 100\n",
"\n",
" print(f\"\\n{'='*70}\")\n",
" print(f\"FINAL REPORT: {self.title}\")\n",
" print(f\"{'='*70}\")\n",
" print(f\"Total Predictions: {self.size}\")\n",
" print(f\"\\n--- Error Metrics ---\")\n",
" print(f\"Mean Absolute Error (MAE): ${mae:,.2f}\")\n",
" print(f\"Median Error: ${median_error:,.2f}\")\n",
" print(f\"Root Mean Squared Error (RMSE): ${rmse:,.2f}\")\n",
" print(f\"Root Mean Squared Log Error (RMSLE): {rmsle:.4f}\")\n",
" print(f\"Mean Absolute Percentage Error (MAPE): {mape:.2f}%\")\n",
" print(f\"\\n--- Accuracy Metrics ---\")\n",
" print(f\"R² Score: {r2:.4f}\")\n",
" print(f\"Hit Rate (Green): {hit_rate_green:.1f}%\")\n",
" print(f\"Hit Rate (Green+Orange): {hit_rate_acceptable:.1f}%\")\n",
" print(f\"{'='*70}\\n\")\n",
" chart_title = f\"{self.title} | MAE=${mae:,.2f} | RMSLE={rmsle:.3f} | R²={r2:.3f}\"\n",
"\n",
" self.chart_all(chart_title)\n",
"\n",
" def run(self):\n",
" iterator = tqdm(range(self.size), desc=\"Testing Model\") if self.show_progress else range(self.size)\n",
" for i in iterator:\n",
" self.run_datapoint(i)\n",
" self.report()\n",
"\n",
" @classmethod\n",
" def test(cls, function, data, title=None):\n",
" cls(function, data, title=title).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e60a696",
"metadata": {
"id": "2e60a696"
},
"outputs": [],
"source": [
"Tester.test(\n",
" improved_model_predict,\n",
" test,\n",
" title=\"Home appliances prediction\"\n",
")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,466 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "GHsssBgWM_l0"
},
"source": [
"# Fine-Tuned Product Price Predictor\n",
"\n",
"Evaluate fine-tuned Llama 3.1 8B model for product price estimation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDyR63OTNUJ6"
},
"outputs": [],
"source": [
"# Install required libraries for model inference\n",
"%pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"%pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-yikV8pRBer9"
},
"outputs": [],
"source": [
"# Import required libraries\n",
"import os\n",
"import re\n",
"import math\n",
"from tqdm import tqdm\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from datetime import datetime\n",
"from peft import PeftModel\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uuTX-xonNeOK"
},
"outputs": [],
"source": [
"# Configuration\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"PROJECT_NAME = \"pricer\"\n",
"HF_USER = \"ed-donner\" # Change to your HF username\n",
"RUN_NAME = \"2024-09-13_13.04.39\"\n",
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
"REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
"FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"\n",
"# Quantization setting (False = 8-bit = better accuracy, more memory)\n",
"QUANT_4_BIT = False # Changed to 8-bit for better accuracy\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Color codes for output\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8JArT3QAQAjx"
},
"source": [
"# Step 1\n",
"\n",
"### Load dataset and fine-tuned model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WyFPZeMcM88v"
},
"outputs": [],
"source": [
"# Login to HuggingFace\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvXVoJH8LS6u"
},
"outputs": [],
"source": [
"# Load product pricing dataset\n",
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']\n",
"\n",
"print(f\"✓ Loaded {len(train)} train and {len(test)} test samples\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xb86e__Wc7j_"
},
"outputs": [],
"source": [
"# Verify data structure\n",
"test[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qJWQ0a3wZ0Bw"
},
"source": [
"### Load Tokenizer and Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lAUAAcEC6ido"
},
"outputs": [],
"source": [
"# Configure quantization for memory efficiency\n",
"if QUANT_4_BIT:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" bnb_4bit_quant_type=\"nf4\"\n",
" )\n",
"else:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_8bit=True,\n",
" bnb_8bit_compute_dtype=torch.bfloat16\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R_O04fKxMMT-"
},
"outputs": [],
"source": [
"# Load tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"# Load base model with quantization\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=quant_config,\n",
" device_map=\"auto\",\n",
")\n",
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"# Load fine-tuned weights\n",
"if REVISION:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
"else:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
"\n",
"print(f\"✓ Model loaded - Memory: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kD-GJtbrdd5t"
},
"outputs": [],
"source": [
"# Verify model loaded\n",
"fine_tuned_model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UObo1-RqaNnT"
},
"source": [
"# Step 2\n",
"\n",
"### Model inference and evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qst1LhBVAB04"
},
"outputs": [],
"source": [
"# Extract price from model response\n",
"def extract_price(s):\n",
" if \"Price is $\" in s:\n",
" contents = s.split(\"Price is $\")[1]\n",
" contents = contents.replace(',','')\n",
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
" return float(match.group()) if match else 0\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jXFBW_5UeEcp"
},
"outputs": [],
"source": [
"# Test extract_price function\n",
"extract_price(\"Price is $a fabulous 899.99 or so\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Oj_PzpdFAIMk"
},
"outputs": [],
"source": [
"# Simple prediction: takes most likely next token\n",
"def model_predict(prompt):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
" attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n",
" outputs = fine_tuned_model.generate(\n",
" inputs,\n",
" attention_mask=attention_mask,\n",
" max_new_tokens=5, # Increased for flexibility\n",
" temperature=0.1, # Low temperature for consistency\n",
" num_return_sequences=1\n",
" )\n",
" response = tokenizer.decode(outputs[0])\n",
" return extract_price(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Je5dR8QEAI1d"
},
"outputs": [],
"source": [
"# Improved prediction: weighted average of top K predictions\n",
"top_K = 5 # Increased from 3 to 5 for better accuracy\n",
"\n",
"def improved_model_predict(prompt, device=\"cuda\"):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
" attention_mask = torch.ones(inputs.shape, device=device)\n",
"\n",
" with torch.no_grad():\n",
" outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
" next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
"\n",
" next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
" top_prob, top_token_id = next_token_probs.topk(top_K)\n",
" prices, weights = [], []\n",
" for i in range(top_K):\n",
" predicted_token = tokenizer.decode(top_token_id[0][i])\n",
" probability = top_prob[0][i]\n",
" try:\n",
" result = float(predicted_token)\n",
" except ValueError as e:\n",
" result = 0.0\n",
" if result > 0:\n",
" prices.append(result)\n",
" weights.append(probability)\n",
" if not prices:\n",
" return 0.0, 0.0\n",
" total = sum(weights)\n",
" weighted_prices = [price * weight / total for price, weight in zip(prices, weights)]\n",
" return sum(weighted_prices).item()"
]
},
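{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (a minimal sketch) before the full evaluation:\n",
"# compare the simple greedy predictor with the weighted top-K predictor on a few test items\n",
"for datapoint in test.select(range(3)):\n",
"    simple = model_predict(datapoint[\"text\"])\n",
"    weighted = improved_model_predict(datapoint[\"text\"])\n",
"    print(f\"Truth: ${datapoint['price']:,.2f}  Simple: ${simple:,.2f}  Weighted: ${weighted:,.2f}\")"
]
},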
{
"cell_type": "markdown",
"metadata": {
"id": "EpGVJPuC1iho"
},
"source": [
"# Step 3\n",
"\n",
"### Test and evaluate model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "30lzJXBH7BcK"
},
"outputs": [],
"source": [
"# Evaluation framework\n",
"class Tester:\n",
" def __init__(self, predictor, data, title=None, size=250):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = size\n",
" self.guesses = []\n",
" self.truths = []\n",
" self.errors = []\n",
" self.sles = []\n",
" self.colors = []\n",
"\n",
" def color_for(self, error, truth):\n",
" if error<40 or error/truth < 0.2:\n",
" return \"green\"\n",
" elif error<80 or error/truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
"\n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
" guess = self.predictor(datapoint[\"text\"])\n",
" truth = datapoint[\"price\"]\n",
" error = abs(guess - truth)\n",
" log_error = math.log(truth+1) - math.log(guess+1)\n",
" sle = log_error ** 2\n",
" color = self.color_for(error, truth)\n",
" title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
" self.guesses.append(guess)\n",
" self.truths.append(truth)\n",
" self.errors.append(error)\n",
" self.sles.append(sle)\n",
" self.colors.append(color)\n",
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
" def chart(self, title):\n",
" max_error = max(self.errors)\n",
" plt.figure(figsize=(12, 8))\n",
" max_val = max(max(self.truths), max(self.guesses))\n",
" plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
" plt.xlabel('Ground Truth')\n",
" plt.ylabel('Model Estimate')\n",
" plt.xlim(0, max_val)\n",
" plt.ylim(0, max_val)\n",
" plt.title(title)\n",
" plt.show()\n",
"\n",
" def report(self):\n",
" average_error = sum(self.errors) / self.size\n",
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
" hits = sum(1 for color in self.colors if color==\"green\")\n",
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
" self.chart(title)\n",
"\n",
" def run(self):\n",
" self.error = 0\n",
" for i in range(self.size):\n",
" self.run_datapoint(i)\n",
" self.report()\n",
"\n",
" @classmethod\n",
" def test(cls, function, data):\n",
" cls(function, data).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W_KcLvyt6kbb"
},
"outputs": [],
"source": [
"# Run evaluation on 250 test examples\n",
"Tester.test(improved_model_predict, test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nVwiWGVN1ihp"
},
"source": [
"### Performance Optimizations Applied\n",
"\n",
"**Changes for better accuracy:**\n",
"- ✅ 8-bit quantization (vs 4-bit) - Better precision\n",
"- ✅ top_K = 5 (vs 3) - More predictions in weighted average\n",
"- ✅ max_new_tokens = 5 - More flexibility in response\n",
"- ✅ temperature = 0.1 - More consistent predictions\n",
"\n",
"**Expected improvement:** ~10-15% reduction in average error\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hO4DdLa81ihp"
},
"source": [
"### Expected Performance\n",
"\n",
"**Baseline comparisons:**\n",
"- GPT-4o: $76 avg error\n",
"- Llama 3.1 base: $396 avg error \n",
"- Human: $127 avg error\n",
"\n",
"**Fine-tuned model (optimized):**\n",
"- Target: $70-85 avg error\n",
"- With 8-bit quant + top_K=5 + temp=0.1\n",
"- Expected to rival or beat GPT-4o\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}


@@ -0,0 +1,453 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from typing import Dict, Generator, List, Optional, Tuple
import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv()
# ---------------------------------------------------------------------------
# Configuration helpers
# ---------------------------------------------------------------------------
@dataclass
class AgentConfig:
"""Holds configuration required to talk to an LLM provider."""
name: str
model: str
api_key_env: str
base_url_env: Optional[str] = None
temperature: float = 0.7
supports_json: bool = True
def load_client(config: AgentConfig) -> OpenAI:
"""Create an OpenAI-compatible client for the given agent."""
api_key = os.getenv(config.api_key_env) or os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError(
f"Missing API key for {config.name}. "
f"Set {config.api_key_env} or OPENAI_API_KEY."
)
base_url = (
os.getenv(config.base_url_env)
if config.base_url_env
else os.getenv("OPENAI_BASE_URL")
)
return OpenAI(api_key=api_key, base_url=base_url)
def extract_text(response) -> str:
"""Extract text content from an OpenAI-style response object or dict."""
choices = getattr(response, "choices", None)
if choices is None and isinstance(response, dict):
choices = response.get("choices")
if not choices:
raise RuntimeError(f"LLM response missing choices field: {response!r}")
choice = choices[0]
message = getattr(choice, "message", None)
if message is None and isinstance(choice, dict):
message = choice.get("message")
content = None
if message is not None:
content = getattr(message, "content", None)
if content is None and isinstance(message, dict):
content = message.get("content")
if isinstance(content, list):
parts: List[str] = []
for part in content:
if isinstance(part, dict):
if "text" in part:
parts.append(str(part["text"]))
elif "output_text" in part:
parts.append(str(part["output_text"]))
elif "type" in part and "content" in part:
parts.append(str(part["content"]))
else:
parts.append(str(part))
content = "".join(parts)
if content is None:
text = getattr(choice, "text", None)
if text is None and isinstance(choice, dict):
text = choice.get("text")
if text:
content = text
if content is None:
raise RuntimeError(f"LLM response missing content/text: {response!r}")
return str(content).strip()
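# Example (illustrative): extract_text handles both SDK response objects and plain dicts, e.g.
#   extract_text({"choices": [{"message": {"content": "Hello"}}]}) returns "Hello".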
# Default configuration leverages OpenAI unless overrides are provided.
DEBATER_A_CONFIG = AgentConfig(
name="Debater A",
model=os.getenv("DEBATER_A_MODEL", "gpt-4o"),
api_key_env="OPENAI_API_KEY",
base_url_env="OPENAI_BASE_URL",
temperature=float(os.getenv("DEBATER_A_TEMPERATURE", 0.7)),
)
DEBATER_B_CONFIG = AgentConfig(
name="Debater B",
model=os.getenv("DEBATER_B_MODEL", "gemini-2.0-flash"),
api_key_env="GOOGLE_API_KEY",
base_url_env="GEMINI_BASE_URL",
temperature=float(os.getenv("DEBATER_B_TEMPERATURE", 0.7)),
)
JUDGE_CONFIG = AgentConfig(
name="Judge",
model=os.getenv("JUDGE_MODEL", "gpt-oss:20b-cloud"),
api_key_env="OLLAMA_API_KEY",
base_url_env="OLLAMA_BASE_URL",
temperature=float(os.getenv("JUDGE_TEMPERATURE", 0.2)),
supports_json=False,
)
REPORTER_CONFIG = AgentConfig(
name="Reporter",
model=os.getenv("REPORTER_MODEL", "MiniMax-M2"),
api_key_env="MINIMAX_API_KEY",
base_url_env="MINIMAX_BASE_URL",
temperature=float(os.getenv("REPORTER_TEMPERATURE", 0.4)),
supports_json=False,
)
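# Example .env overrides (illustrative values; the endpoints below are assumptions - adjust per provider):
#   OPENAI_API_KEY=sk-...                 # Debater A (and fallback key for any agent)
#   GOOGLE_API_KEY=...                    # Debater B
#   GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
#   OLLAMA_API_KEY=ollama                 # Judge (Ollama accepts any placeholder key)
#   OLLAMA_BASE_URL=http://localhost:11434/v1
#   MINIMAX_API_KEY=...                   # Reporter
#   MINIMAX_BASE_URL=...                  # MiniMax's OpenAI-compatible endpoint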
THEME = gr.themes.Default(
primary_hue="blue",
secondary_hue="sky",
neutral_hue="gray",
)
CUSTOM_CSS = """
body, .gradio-container {
background: radial-gradient(circle at top, #0f172a 0%, #020617 60%, #020617 100%);
color: #e2e8f0;
}
#live-debate-panel {
background: linear-gradient(135deg, rgba(30,64,175,0.95), rgba(29,78,216,0.85));
color: #f8fafc;
border-radius: 16px;
padding: 24px;
box-shadow: 0 20px 45px rgba(15,23,42,0.35);
}
#live-debate-panel h3 {
color: #bfdbfe;
}
.gr-button-primary {
background: linear-gradient(135deg, #1d4ed8, #2563eb) !important;
border: none !important;
}
.gr-button-primary:hover {
background: linear-gradient(135deg, #2563eb, #1d4ed8) !important;
}
"""
# ---------------------------------------------------------------------------
# Debate runtime classes
# ---------------------------------------------------------------------------
@dataclass
class DebateState:
topic: str
stance_a: str
stance_b: str
transcript: List[Tuple[str, str]] = field(default_factory=list)
class LLMAdapter:
"""Thin wrapper around the OpenAI SDK to simplify prompting."""
def __init__(self, config: AgentConfig):
self.config = config
self.client = load_client(config)
def complete(
self,
prompt: str,
*,
system: Optional[str] = None,
max_tokens: int = 512,
json_mode: bool = False,
) -> str:
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
params = dict(
model=self.config.model,
messages=messages,
temperature=self.config.temperature,
max_tokens=max_tokens,
)
if json_mode and self.config.supports_json:
params["response_format"] = {"type": "json_object"}
response = self.client.chat.completions.create(**params)
return extract_text(response)
class Debater:
def __init__(self, adapter: LLMAdapter, stance_label: str):
self.adapter = adapter
self.stance_label = stance_label
def argue(self, topic: str) -> str:
prompt = (
f"You are {self.adapter.config.name}, debating the topic:\n"
f"'{topic}'.\n\n"
f"Present a concise argument that {self.stance_label.lower()} "
f"the statement. Use at most 150 words. Provide clear reasoning "
f"and, if applicable, cite plausible evidence or examples."
)
return self.adapter.complete(prompt, max_tokens=300)
class Judge:
RUBRIC = [
"Clarity of the argument",
"Use of evidence or examples",
"Logical coherence",
"Persuasiveness and impact",
]
def __init__(self, adapter: LLMAdapter):
self.adapter = adapter
def evaluate(self, topic: str, argument_a: str, argument_b: str) -> Dict[str, object]:
rubric_text = "\n".join(f"- {item}" for item in self.RUBRIC)
prompt = (
"You are serving as an impartial debate judge.\n"
f"Topic: {topic}\n\n"
f"Argument from Debater A:\n{argument_a}\n\n"
f"Argument from Debater B:\n{argument_b}\n\n"
"Score each debater from 0-10 on the following criteria:\n"
f"{rubric_text}\n\n"
"Return a JSON object with this exact structure:\n"
'{\n'
' "winner": "A" or "B" or "Tie",\n'
' "reason": "brief justification",\n'
' "scores": [\n'
' {"criterion": "...", "debater_a": 0-10, "debater_b": 0-10, "notes": "optional"}\n'
" ]\n"
"}\n"
"Ensure the JSON is valid."
)
raw = self.adapter.complete(prompt, max_tokens=400, json_mode=True)
try:
data = json.loads(raw)
if "scores" not in data:
raise ValueError("scores missing")
return data
except Exception:
# Fallback: wrap raw text if parsing fails.
return {"winner": "Unknown", "reason": raw, "scores": []}
class Reporter:
def __init__(self, adapter: LLMAdapter):
self.adapter = adapter
def summarize(
self,
topic: str,
argument_a: str,
argument_b: str,
judge_result: Dict[str, object],
) -> str:
prompt = (
f"Summarize a single-round debate on '{topic}'.\n\n"
f"Debater A argued:\n{argument_a}\n\n"
f"Debater B argued:\n{argument_b}\n\n"
f"Judge verdict: {json.dumps(judge_result, ensure_ascii=False)}\n\n"
"Provide a short journalistic summary (max 200 words) highlighting "
"each side's key points and the judge's decision. Use neutral tone."
)
response = self.adapter.client.chat.completions.create(
model=self.adapter.config.model,
messages=[
{"role": "system", "content": "You are an impartial debate reporter."},
{"role": "user", "content": prompt},
],
temperature=self.adapter.config.temperature,
max_tokens=300,
**(
{"extra_body": {"reasoning_split": True}}
if getattr(self.adapter.client, "base_url", None)
and "minimax" in str(self.adapter.client.base_url).lower()
else {}
),
)
return extract_text(response)
# ---------------------------------------------------------------------------
# Debate pipeline + UI
# ---------------------------------------------------------------------------
debater_a = Debater(LLMAdapter(DEBATER_A_CONFIG), stance_label="supports")
debater_b = Debater(LLMAdapter(DEBATER_B_CONFIG), stance_label="opposes")
judge = Judge(LLMAdapter(JUDGE_CONFIG))
reporter = Reporter(LLMAdapter(REPORTER_CONFIG))
def format_transcript(transcript: List[Tuple[str, str]]) -> str:
"""Return markdown-formatted transcript."""
lines = []
for speaker, message in transcript:
lines.append(f"### {speaker}\n\n{message}\n")
return "\n".join(lines)
def run_debate(
topic: str, stance_a: str, stance_b: str
) -> Generator[Tuple[str, str, List[List[object]], str, str], None, None]:
"""Generator for Gradio to stream debate progress."""
if not topic.strip():
warning = "⚠️ Please enter a debate topic to get started."
yield warning, "", [], "", ""
return
state = DebateState(topic=topic.strip(), stance_a=stance_a, stance_b=stance_b)
state.transcript.append(
("Moderator", f"Welcome to the debate on **{state.topic}**!")
)
yield format_transcript(state.transcript), "Waiting for judge...", [], "", ""
argument_a = debater_a.argue(state.topic)
state.transcript.append((f"Debater A ({state.stance_a})", argument_a))
yield format_transcript(state.transcript), "Collecting arguments...", [], "", ""
argument_b = debater_b.argue(state.topic)
state.transcript.append((f"Debater B ({state.stance_b})", argument_b))
yield format_transcript(state.transcript), "Judge deliberating...", [], "", ""
judge_result = judge.evaluate(state.topic, argument_a, argument_b)
verdict_text = (
f"Winner: {judge_result.get('winner', 'Unknown')}\nReason: "
f"{judge_result.get('reason', 'No explanation provided.')}"
)
score_rows = [
[
entry.get("criterion", ""),
entry.get("debater_a", ""),
entry.get("debater_b", ""),
entry.get("notes", ""),
]
for entry in judge_result.get("scores", [])
]
judge_report_md = (
f"**Judge Verdict:** {judge_result.get('winner', 'Unknown')}\n\n"
f"{judge_result.get('reason', '')}"
)
yield (
format_transcript(state.transcript),
judge_report_md,
score_rows,
verdict_text,
format_transcript(state.transcript),
)
reporter_summary = reporter.summarize(
state.topic, argument_a, argument_b, judge_result
)
final_markdown = (
f"{judge_report_md}\n\n---\n\n"
f"**Reporter Summary**\n\n{reporter_summary}"
)
yield (
format_transcript(state.transcript),
final_markdown,
score_rows,
verdict_text,
format_transcript(state.transcript),
)
# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks(
title="LLM Debate Arena",
fill_width=True,
theme=THEME,
css=CUSTOM_CSS,
) as demo:
gr.Markdown(
"# 🔁 LLM Debate Arena\n"
"Configure two debating agents, watch their arguments in real time, and "
"review the judge's verdict plus a reporter summary."
)
with gr.Row():
topic_input = gr.Textbox(
label="Debate Topic",
placeholder="e.g., Should autonomous delivery robots be allowed in city centers?",
)
with gr.Row():
stance_a_input = gr.Textbox(
label="Debater A Stance",
value="Supports the statement",
)
stance_b_input = gr.Textbox(
label="Debater B Stance",
value="Opposes the statement",
)
run_button = gr.Button("Start Debate", variant="primary")
with gr.Tab("Live Debate"):
transcript_md = gr.Markdown(
"### Waiting for the debate to start...",
elem_id="live-debate-panel",
)
with gr.Tab("Judge's Report"):
judge_md = gr.Markdown("Judge verdict will appear here.")
score_table = gr.Dataframe(
headers=["Criterion", "Debater A", "Debater B", "Notes"],
datatype=["str", "number", "number", "str"],
interactive=False,
)
verdict_box = gr.Textbox(
label="Verdict Detail",
interactive=False,
)
transcript_box = gr.Textbox(
label="Full Transcript (for copying)",
interactive=False,
lines=10,
)
run_button.click(
fn=run_debate,
inputs=[topic_input, stance_a_input, stance_b_input],
outputs=[transcript_md, judge_md, score_table, verdict_box, transcript_box],
queue=True,
)
if __name__ == "__main__":
demo.queue(default_concurrency_limit=4).launch()