Merge branch 'main' of github.com:ed-donner/llm_engineering
This commit is contained in:
725
week4/community-contributions/philip/week4_EXERCISE.ipynb
Normal file
@@ -0,0 +1,725 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import re\n",
"from typing import List, Dict, Optional\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"from IPython.display import Markdown, display\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Groq API Key not set (optional)\n",
"OpenRouter API Key loaded (begins with sk-or-)\n"
]
}
],
"source": [
"load_dotenv(override=True)\n",
"\n",
"# Ollama connection\n",
"ollama_url = \"http://localhost:11434/v1\"\n",
"ollama_client = OpenAI(api_key=\"ollama\", base_url=ollama_url)\n",
"\n",
"# Groq connection\n",
"groq_api_key = os.getenv('GROQ_API_KEY')\n",
"groq_url = \"https://api.groq.com/openai/v1\"\n",
"groq_client = None\n",
"if groq_api_key:\n",
"    groq_client = OpenAI(api_key=groq_api_key, base_url=groq_url)\n",
"    print(f\"Groq API Key loaded (begins with {groq_api_key[:4]})\")\n",
"else:\n",
"    print(\"Groq API Key not set (optional)\")\n",
"\n",
"# OpenRouter connection\n",
"openrouter_api_key = os.getenv('OPENROUTER_API_KEY')\n",
"openrouter_url = \"https://openrouter.ai/api/v1\"\n",
"openrouter_client = None\n",
"if openrouter_api_key:\n",
"    openrouter_client = OpenAI(api_key=openrouter_api_key, base_url=openrouter_url)\n",
"    print(f\"OpenRouter API Key loaded (begins with {openrouter_api_key[:6]})\")\n",
"else:\n",
"    print(\"OpenRouter API Key not set (optional)\")\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Configured 2 models\n",
"OpenRouter models available (perfect for limited storage demos!)\n"
]
}
],
"source": [
"# Open-source code models configuration\n",
"MODELS = {}\n",
"\n",
"if groq_client:\n",
"    MODELS.update({\n",
"        \"gpt-oss-20b-groq\": {\n",
"            \"name\": \"GPT-OSS-20B (Groq)\",\n",
"            \"client\": groq_client,\n",
" \"model\": \"gpt-oss:20b\",\n",
|
||||||
|
" \"description\": \"Cloud\"\n",
|
||||||
|
" },\n",
|
||||||
|
" \"gpt-oss-120b-groq\": {\n",
|
||||||
|
" \"name\": \"GPT-OSS-120B (Groq)\",\n",
|
||||||
|
" \"client\": groq_client,\n",
|
||||||
|
" \"model\": \"openai/gpt-oss-120b\",\n",
|
||||||
|
" \"description\": \"Cloud - Larger GPT-OSS\"\n",
|
||||||
|
" },\n",
|
||||||
|
" \"qwen2.5-coder-32b-groq\": {\n",
|
||||||
|
" \"name\": \"Qwen2.5-Coder 32B (Groq)\",\n",
|
||||||
|
" \"client\": groq_client,\n",
|
||||||
|
" \"model\": \"qwen/qwen2.5-coder-32b-instruct\",\n",
|
||||||
|
" \"description\": \"Cloud\"\n",
|
||||||
|
" },\n",
|
||||||
|
" })\n",
|
||||||
|
"\n",
|
||||||
|
"# OpenRouter models\n",
|
||||||
|
"if openrouter_client:\n",
|
||||||
|
" MODELS.update({\n",
|
||||||
|
" \"qwen-2.5-coder-32b-openrouter\": {\n",
|
||||||
|
" \"name\": \"Qwen2.5-Coder 32B (OpenRouter)\",\n",
|
||||||
|
" \"client\": openrouter_client,\n",
|
||||||
|
" \"model\": \"qwen/qwen-2.5-coder-32b-instruct\",\n",
|
||||||
|
" \"description\": \"Cloud - Perfect for demos, 50 req/day free\"\n",
|
||||||
|
" },\n",
|
||||||
|
" \"gpt-oss-20b-groq\": {\n",
|
||||||
|
" \"name\": \"GPT-OSS-20B\",\n",
|
||||||
|
" \"client\": openrouter_client,\n",
|
||||||
|
" \"model\": \"openai/gpt-oss-20b\",\n",
|
||||||
|
" \"description\": \"Cloud - OpenAI's open model, excellent for code!\"\n",
|
||||||
|
" },\n",
|
||||||
|
" })\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Configured {len(MODELS)} models\")\n",
|
||||||
|
"if openrouter_client:\n",
|
||||||
|
" print(\"OpenRouter models available (perfect for limited storage demos!)\")\n",
|
||||||
|
"if groq_client:\n",
|
||||||
|
" print(\"Groq models available (fast cloud inference!)\")\n",
|
||||||
|
"if \"qwen2.5-coder:7b\" in MODELS:\n",
|
||||||
|
" print(\"Ollama models available (unlimited local usage!)\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
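{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(Added note)* Before wiring the models into the review pipeline, it can help to smoke-test whichever client is configured. The cell below is a minimal sketch, assuming at least one entry exists in `MODELS`; it is not part of the original notebook run.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional connectivity check (added sketch): send a trivial prompt to the first configured model\n",
"if MODELS:\n",
"    key, cfg = next(iter(MODELS.items()))\n",
"    reply = cfg[\"client\"].chat.completions.create(\n",
"        model=cfg[\"model\"],\n",
"        messages=[{\"role\": \"user\", \"content\": \"Reply with the single word: OK\"}],\n",
"    )\n",
"    print(f\"{key} -> {reply.choices[0].message.content}\")\n"
]
},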
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"BUG_DETECTION_SYSTEM_PROMPT = \"\"\"You are an expert code reviewer specializing in finding bugs, security vulnerabilities, and logic errors.\n",
"\n",
"Your task is to analyze Python code and identify issues. Return ONLY a valid JSON array with this exact format:\n",
"[{\n",
"  \"severity\": \"critical|high|medium|low\",\n",
"  \"line\": number,\n",
"  \"issue\": \"brief description of the problem\",\n",
"  \"suggestion\": \"specific fix recommendation\"\n",
"}]\n",
"\n",
"Be thorough but concise. Focus on real bugs and security issues.\"\"\"\n",
"\n",
"IMPROVEMENTS_SYSTEM_PROMPT = \"\"\"You are a senior software engineer specializing in code quality and best practices.\n",
"\n",
"Analyze the Python code and suggest improvements for:\n",
"- Code readability and maintainability\n",
"- Performance optimizations\n",
"- Pythonic idioms and conventions\n",
"- Better error handling\n",
"\n",
"Return ONLY a JSON array:\n",
"[{\n",
"  \"category\": \"readability|performance|style|error_handling\",\n",
"  \"line\": number,\n",
"  \"current\": \"current code snippet\",\n",
"  \"improved\": \"improved code snippet\",\n",
"  \"explanation\": \"why this is better\"\n",
"}]\n",
"\n",
"Only suggest meaningful improvements.\"\"\"\n",
"\n",
"TEST_GENERATION_SYSTEM_PROMPT = \"\"\"You are an expert in writing comprehensive unit tests.\n",
"\n",
"Generate pytest unit tests for the given Python code. Include:\n",
"- Test cases for normal operation\n",
"- Edge cases and boundary conditions\n",
"- Error handling tests\n",
"- Tests for any bugs that were identified\n",
"\n",
"Return ONLY Python code with pytest tests. Include the original code at the top if needed.\n",
"Put the imports at the top of the file first.\n",
"Do not include explanations or markdown formatting.\"\"\"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def extract_json_from_response(text: str) -> List[Dict]:\n",
"    \"\"\"Extract JSON array from model response, handling markdown code blocks.\"\"\"\n",
"    # Remove markdown code blocks\n",
"    text = re.sub(r'```json\\n?', '', text)\n",
"    text = re.sub(r'```\\n?', '', text)\n",
"\n",
"    # Try to find JSON array\n",
"    json_match = re.search(r'\\[\\s*\\{.*\\}\\s*\\]', text, re.DOTALL)\n",
"    if json_match:\n",
"        try:\n",
"            return json.loads(json_match.group())\n",
"        except json.JSONDecodeError:\n",
"            pass\n",
"\n",
"    # Fallback: try parsing entire response\n",
"    try:\n",
"        return json.loads(text.strip())\n",
"    except json.JSONDecodeError:\n",
"        return []\n",
"\n",
"def detect_bugs(code: str, model_key: str) -> Dict:\n",
"    \"\"\"Detect bugs and security issues in code.\"\"\"\n",
"    model_config = MODELS[model_key]\n",
"    client = model_config[\"client\"]\n",
"    model_name = model_config[\"model\"]\n",
"\n",
"    user_prompt = f\"Analyze this Python code for bugs and security issues:\\n\\n```python\\n{code}\\n```\"\n",
"\n",
"    try:\n",
"        response = client.chat.completions.create(\n",
"            model=model_name,\n",
"            messages=[\n",
"                {\"role\": \"system\", \"content\": BUG_DETECTION_SYSTEM_PROMPT},\n",
"                {\"role\": \"user\", \"content\": user_prompt}\n",
"            ],\n",
"            temperature=0.1\n",
"        )\n",
"\n",
"        content = response.choices[0].message.content\n",
"        issues = extract_json_from_response(content)\n",
"\n",
"        return {\n",
"            \"model\": model_config[\"name\"],\n",
"            \"issues\": issues,\n",
"            \"raw_response\": content,\n",
"            \"success\": True\n",
"        }\n",
"    except Exception as e:\n",
"        return {\n",
"            \"model\": model_config[\"name\"],\n",
"            \"issues\": [],\n",
"            \"error\": str(e),\n",
"            \"success\": False\n",
"        }\n",
"\n",
"def suggest_improvements(code: str, model_key: str) -> Dict:\n",
"    \"\"\"Suggest code improvements and best practices.\"\"\"\n",
"    model_config = MODELS[model_key]\n",
"    client = model_config[\"client\"]\n",
"    model_name = model_config[\"model\"]\n",
"\n",
"    user_prompt = f\"Suggest improvements for this Python code:\\n\\n```python\\n{code}\\n```\"\n",
"\n",
"    try:\n",
"        response = client.chat.completions.create(\n",
"            model=model_name,\n",
"            messages=[\n",
"                {\"role\": \"system\", \"content\": IMPROVEMENTS_SYSTEM_PROMPT},\n",
"                {\"role\": \"user\", \"content\": user_prompt}\n",
"            ],\n",
"            temperature=0.2\n",
"        )\n",
"\n",
"        content = response.choices[0].message.content\n",
"        improvements = extract_json_from_response(content)\n",
"\n",
"        return {\n",
"            \"model\": model_config[\"name\"],\n",
"            \"improvements\": improvements,\n",
"            \"raw_response\": content,\n",
"            \"success\": True\n",
"        }\n",
"    except Exception as e:\n",
"        return {\n",
"            \"model\": model_config[\"name\"],\n",
"            \"improvements\": [],\n",
"            \"error\": str(e),\n",
"            \"success\": False\n",
"        }\n",
"\n",
"def generate_tests(code: str, bugs: List[Dict], model_key: str) -> Dict:\n",
"    \"\"\"Generate unit tests for the code.\"\"\"\n",
"    model_config = MODELS[model_key]\n",
"    client = model_config[\"client\"]\n",
"    model_name = model_config[\"model\"]\n",
"\n",
"    bugs_context = \"\"\n",
"    if bugs:\n",
"        bugs_context = f\"\\n\\nNote: The following bugs were identified:\\n\" + \"\\n\".join([f\"- Line {b.get('line', '?')}: {b.get('issue', '')}\" for b in bugs])\n",
"\n",
"    user_prompt = f\"Generate pytest unit tests for this Python code:{bugs_context}\\n\\n```python\\n{code}\\n```\"\n",
"\n",
"    try:\n",
"        response = client.chat.completions.create(\n",
"            model=model_name,\n",
"            messages=[\n",
"                {\"role\": \"system\", \"content\": TEST_GENERATION_SYSTEM_PROMPT},\n",
"                {\"role\": \"user\", \"content\": user_prompt}\n",
"            ],\n",
"            temperature=0.3\n",
"        )\n",
"\n",
"        content = response.choices[0].message.content\n",
"        # Remove markdown code blocks if present\n",
"        test_code = re.sub(r'```python\\n?', '', content)\n",
"        test_code = re.sub(r'```\\n?', '', test_code)\n",
"\n",
"        return {\n",
"            \"model\": model_config[\"name\"],\n",
"            \"test_code\": test_code.strip(),\n",
"            \"raw_response\": content,\n",
"            \"success\": True\n",
"        }\n",
"    except Exception as e:\n",
"        return {\n",
"            \"model\": model_config[\"name\"],\n",
"            \"test_code\": \"\",\n",
"            \"error\": str(e),\n",
"            \"success\": False\n",
"        }\n"
]
},
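{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(Added note)* A quick offline check of `extract_json_from_response`: models often wrap their JSON in markdown fences, and the helper should strip the fences before parsing. The sample string below is illustrative, not a captured model response.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test for the JSON extraction helper (added sketch; no API call needed)\n",
"sample = '```json\\n[{\"severity\": \"high\", \"line\": 3, \"issue\": \"eval on user input\", \"suggestion\": \"use ast.literal_eval\"}]\\n```'\n",
"extract_json_from_response(sample)  # expect a one-element list of dicts\n"
]
},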
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def format_bugs_output(result: Dict) -> str:\n",
"    \"\"\"Format bug detection results for display.\"\"\"\n",
"    if not result.get(\"success\"):\n",
"        return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n",
"\n",
"    issues = result.get(\"issues\", [])\n",
"    if not issues:\n",
"        return f\"✅ **{result['model']}**: No issues found. Code looks good!\"\n",
"\n",
"    output = [f\"**{result['model']}** - Found {len(issues)} issue(s):\\n\"]\n",
"\n",
"    severity_order = {\"critical\": 0, \"high\": 1, \"medium\": 2, \"low\": 3}\n",
"    sorted_issues = sorted(issues, key=lambda x: severity_order.get(x.get(\"severity\", \"low\"), 3))\n",
"\n",
"    for issue in sorted_issues:\n",
"        severity = issue.get(\"severity\", \"unknown\").upper()\n",
"        line = issue.get(\"line\", \"?\")\n",
"        issue_desc = issue.get(\"issue\", \"\")\n",
"        suggestion = issue.get(\"suggestion\", \"\")\n",
"\n",
"        severity_emoji = {\n",
"            \"CRITICAL\": \"🔴\",\n",
"            \"HIGH\": \"🟠\",\n",
"            \"MEDIUM\": \"🟡\",\n",
"            \"LOW\": \"🔵\"\n",
"        }.get(severity, \"⚪\")\n",
"\n",
"        output.append(f\"{severity_emoji} **{severity}** (Line {line}): {issue_desc}\")\n",
"        if suggestion:\n",
"            output.append(f\"   💡 *Fix:* {suggestion}\")\n",
"        output.append(\"\")\n",
"\n",
"    return \"\\n\".join(output)\n",
"\n",
"def format_improvements_output(result: Dict) -> str:\n",
"    \"\"\"Format improvement suggestions for display.\"\"\"\n",
"    if not result.get(\"success\"):\n",
"        return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n",
"\n",
"    improvements = result.get(\"improvements\", [])\n",
"    if not improvements:\n",
"        return f\"✅ **{result['model']}**: Code follows best practices. No major improvements needed!\"\n",
"\n",
"    output = [f\"**{result['model']}** - {len(improvements)} suggestion(s):\\n\"]\n",
"\n",
"    for imp in improvements:\n",
"        category = imp.get(\"category\", \"general\").replace(\"_\", \" \").title()\n",
"        line = imp.get(\"line\", \"?\")\n",
"        current = imp.get(\"current\", \"\")\n",
"        improved = imp.get(\"improved\", \"\")\n",
"        explanation = imp.get(\"explanation\", \"\")\n",
"\n",
"        output.append(f\"\\n📝 **{category}** (Line {line}):\")\n",
"        if current and improved:\n",
"            output.append(f\"   Before: `{current[:60]}{'...' if len(current) > 60 else ''}`\")\n",
"            output.append(f\"   After: `{improved[:60]}{'...' if len(improved) > 60 else ''}`\")\n",
"        if explanation:\n",
"            output.append(f\"   💡 {explanation}\")\n",
"\n",
"    return \"\\n\".join(output)\n",
"\n",
"def format_tests_output(result: Dict) -> str:\n",
"    \"\"\"Format test generation results for display.\"\"\"\n",
"    if not result.get(\"success\"):\n",
"        return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n",
"\n",
"    test_code = result.get(\"test_code\", \"\")\n",
"    if not test_code:\n",
"        return f\"⚠️ **{result['model']}**: No tests generated.\"\n",
"\n",
"    return test_code\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def review_code(code: str, model_key: str, include_tests: bool = True) -> tuple:\n",
"    \"\"\"Main function to perform complete code review.\"\"\"\n",
"    if not code.strip():\n",
"        return \"Please provide code to review.\", \"\", \"\"\n",
"\n",
"    # Detect bugs\n",
"    bugs_result = detect_bugs(code, model_key)\n",
"    bugs_output = format_bugs_output(bugs_result)\n",
"    bugs_issues = bugs_result.get(\"issues\", [])\n",
"\n",
"    # Suggest improvements\n",
"    improvements_result = suggest_improvements(code, model_key)\n",
"    improvements_output = format_improvements_output(improvements_result)\n",
"\n",
"    # Generate tests\n",
"    tests_output = \"\"\n",
"    if include_tests:\n",
"        tests_result = generate_tests(code, bugs_issues, model_key)\n",
"        tests_output = format_tests_output(tests_result)\n",
"\n",
"    return bugs_output, improvements_output, tests_output\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def compare_models(code: str, model_keys: List[str]) -> str:\n",
"    \"\"\"Compare multiple models on the same code.\"\"\"\n",
"    if not code.strip():\n",
"        return \"Please provide code to review.\"\n",
"\n",
"    results = []\n",
"    all_issues = []\n",
"\n",
"    for model_key in model_keys:\n",
"        result = detect_bugs(code, model_key)\n",
"        results.append(result)\n",
"        if result.get(\"success\"):\n",
"            all_issues.extend(result.get(\"issues\", []))\n",
"\n",
"    # Build comparison output\n",
"    output = [\"# Model Comparison Results\\n\"]\n",
"\n",
"    for result in results:\n",
"        model_name = result[\"model\"]\n",
"        issues = result.get(\"issues\", [])\n",
"        success = result.get(\"success\", False)\n",
"\n",
"        if success:\n",
"            output.append(f\"\\n**{model_name}**: Found {len(issues)} issue(s)\")\n",
"            if issues:\n",
"                severity_counts = {}\n",
"                for issue in issues:\n",
"                    sev = issue.get(\"severity\", \"low\")\n",
"                    severity_counts[sev] = severity_counts.get(sev, 0) + 1\n",
"                output.append(f\"   Breakdown: {dict(severity_counts)}\")\n",
"        else:\n",
"            output.append(f\"\\n**{model_name}**: Error - {result.get('error', 'Unknown')}\")\n",
"\n",
"    # Find consensus issues (found by multiple models)\n",
"    if len(results) > 1:\n",
"        issue_signatures = {}\n",
"        for result in results:\n",
"            if result.get(\"success\"):\n",
"                for issue in result.get(\"issues\", []):\n",
"                    # Create signature from line and issue description\n",
"                    sig = f\"{issue.get('line')}-{issue.get('issue', '')[:50]}\"\n",
"                    if sig not in issue_signatures:\n",
"                        issue_signatures[sig] = []\n",
"                    issue_signatures[sig].append(result[\"model\"])\n",
"\n",
"        consensus = [sig for sig, models in issue_signatures.items() if len(models) > 1]\n",
"        if consensus:\n",
"            output.append(f\"\\n\\n**Consensus Issues**: {len(consensus)} issue(s) identified by multiple models\")\n",
"\n",
"    return \"\\n\".join(output)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradio UI\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7884\n",
"* To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7884/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\queueing.py\", line 745, in process_events\n",
"    response = await route_utils.call_process_api(\n",
"               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\route_utils.py\", line 354, in call_process_api\n",
"    output = await app.get_blocks().process_api(\n",
"             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\blocks.py\", line 2116, in process_api\n",
"    result = await self.call_function(\n",
"             ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\blocks.py\", line 1623, in call_function\n",
"    prediction = await anyio.to_thread.run_sync(  # type: ignore\n",
"                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\to_thread.py\", line 56, in run_sync\n",
"    return await get_async_backend().run_sync_in_worker_thread(\n",
"           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 2485, in run_sync_in_worker_thread\n",
"    return await future\n",
"           ^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 976, in run\n",
"    result = context.run(func, *args)\n",
"             ^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\utils.py\", line 915, in wrapper\n",
"    response = f(*args, **kwargs)\n",
"               ^^^^^^^^^^^^^^^^^^\n",
"  File \"C:\\Users\\Philo Baba\\AppData\\Local\\Temp\\ipykernel_43984\\2272281361.py\", line 7, in review_code\n",
"    bugs_result = detect_bugs(code, model_key)\n",
"                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"  File \"C:\\Users\\Philo Baba\\AppData\\Local\\Temp\\ipykernel_43984\\638705449.py\", line 23, in detect_bugs\n",
"    model_config = MODELS[model_key]\n",
"                   ~~~~~~^^^^^^^^^^^\n",
"KeyError: 'deepseek-coder-v2-openrouter'\n"
]
}
],
"source": [
"# Example buggy code for testing\n",
"EXAMPLE_CODE = '''def divide_numbers(a, b):\n",
"    return a / b\n",
"\n",
"def process_user_data(user_input):\n",
"    # Missing input validation\n",
"    result = eval(user_input)\n",
"    return result\n",
"\n",
"def get_user_by_id(user_id):\n",
"    # SQL injection vulnerability\n",
"    query = f\"SELECT * FROM users WHERE id = {user_id}\"\n",
"    return query\n",
"\n",
"def calculate_average(numbers):\n",
"    total = sum(numbers)\n",
"    return total / len(numbers)  # Potential division by zero\n",
"'''\n",
"\n",
"def create_ui():\n",
"    with gr.Blocks(title=\"AI Code Review Assistant\", theme=gr.themes.Soft()) as demo:\n",
"        gr.Markdown(\"\"\"\n",
"        # 🔍 AI-Powered Code Review Assistant\n",
"\n",
"        Review your Python code using open-source AI models. Detect bugs, get improvement suggestions, and generate unit tests.\n",
"        \"\"\")\n",
"\n",
"        with gr.Row():\n",
"            with gr.Column(scale=2):\n",
"                code_input = gr.Code(\n",
"                    label=\"Python Code to Review\",\n",
"                    value=EXAMPLE_CODE,\n",
"                    language=\"python\",\n",
"                    lines=20\n",
"                )\n",
"\n",
"                with gr.Row():\n",
"                    model_selector = gr.Dropdown(\n",
"                        choices=list(MODELS.keys()),\n",
"                        value=list(MODELS.keys())[0],\n",
"                        label=\"Select Model\",\n",
"                        info=\"Choose an open-source code model\"\n",
"                    )\n",
"\n",
"                    include_tests = gr.Checkbox(\n",
"                        label=\"Generate Tests\",\n",
"                        value=True\n",
"                    )\n",
"\n",
"                with gr.Row():\n",
"                    review_btn = gr.Button(\"🔍 Review Code\", variant=\"primary\", scale=2)\n",
"                    compare_btn = gr.Button(\"📊 Compare Models\", variant=\"secondary\", scale=1)\n",
"\n",
"            with gr.Column(scale=3):\n",
"                with gr.Tabs() as tabs:\n",
"                    with gr.Tab(\"🐛 Bug Detection\"):\n",
"                        bugs_output = gr.Markdown(value=\"Select a model and click 'Review Code' to analyze your code.\")\n",
"\n",
"                    with gr.Tab(\"✨ Improvements\"):\n",
"                        improvements_output = gr.Markdown(value=\"Get suggestions for code improvements and best practices.\")\n",
"\n",
"                    with gr.Tab(\"🧪 Unit Tests\"):\n",
"                        tests_output = gr.Code(\n",
"                            label=\"Generated Test Code\",\n",
"                            language=\"python\",\n",
"                            lines=25\n",
"                        )\n",
"\n",
"                    with gr.Tab(\"📊 Comparison\"):\n",
"                        comparison_output = gr.Markdown(value=\"Compare multiple models side-by-side.\")\n",
"\n",
"        # Event handlers\n",
"        review_btn.click(\n",
"            fn=review_code,\n",
"            inputs=[code_input, model_selector, include_tests],\n",
"            outputs=[bugs_output, improvements_output, tests_output]\n",
"        )\n",
"\n",
"        def compare_selected_models(code):\n",
"            # Compare first 3 models by default\n",
"            model_keys = list(MODELS.keys())[:3]\n",
"            return compare_models(code, model_keys)\n",
"\n",
"        compare_btn.click(\n",
"            fn=compare_selected_models,\n",
"            inputs=[code_input],\n",
"            outputs=[comparison_output]\n",
"        )\n",
"\n",
"        gr.Examples(\n",
"            examples=[\n",
"                [EXAMPLE_CODE],\n",
"                [\"\"\"def fibonacci(n):\n",
"    if n <= 1:\n",
"        return n\n",
"    return fibonacci(n-1) + fibonacci(n-2)\n",
"\"\"\"],\n",
"                [\"\"\"def parse_config(file_path):\n",
"    with open(file_path) as f:\n",
"        return eval(f.read())\n",
"\"\"\"]\n",
"            ],\n",
"            inputs=[code_input]\n",
"        )\n",
"\n",
"    return demo\n",
"\n",
"demo = create_ui()\n",
"demo.launch(inbrowser=True, share=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**Qwen2.5-Coder 32B (OpenRouter)** - Found 2 issue(s):\n",
"\n",
"🔴 **CRITICAL** (Line 2): No division by zero protection\n",
"   💡 *Fix:* Add a check for b == 0 and raise ValueError or handle ZeroDivisionError\n",
"\n",
"🟡 **MEDIUM** (Line 2): No input validation for numeric types\n",
"   💡 *Fix:* Add type checking to ensure a and b are numbers (int/float)\n",
"\n"
]
}
],
"source": [
"# Test with a simple example\n",
"test_code = \"\"\"def divide(a, b):\n",
"    return a / b\n",
"\"\"\"\n",
"\n",
"# Test bug detection\n",
"result = detect_bugs(test_code, list(MODELS.keys())[0])\n",
"print(format_bugs_output(result))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
454
week7/community_contributions/emmy/product_price_estimator.ipynb
Normal file
@@ -0,0 +1,454 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "275415f0",
"metadata": {
"id": "275415f0"
},
"outputs": [],
"source": [
"# pip installs\n",
"\n",
"!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "535bd9de",
"metadata": {
"id": "535bd9de"
},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import torch\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, mean_absolute_percentage_error\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from datetime import datetime\n",
"from peft import PeftModel\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc58234a",
"metadata": {
"id": "fc58234a"
},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"PROJECT_NAME = \"pricer\"\n",
"HF_USER = \"ed-donner\"\n",
"RUN_NAME = \"2024-09-13_13.04.39\"\n",
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
"REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
"FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"\n",
"DATASET_NAME = f\"{HF_USER}/home-data\"\n",
"\n",
"QUANT_4_BIT = True\n",
"top_K = 6\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Used for writing to output in color\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0145ad8a",
"metadata": {
"id": "0145ad8a"
},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6919506e",
"metadata": {
"id": "6919506e"
},
"outputs": [],
"source": [
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']\n",
"len(train), len(test)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea79cde1",
"metadata": {
"id": "ea79cde1"
},
"outputs": [],
"source": [
"if QUANT_4_BIT:\n",
"    quant_config = BitsAndBytesConfig(\n",
"        load_in_4bit=True,\n",
"        bnb_4bit_use_double_quant=True,\n",
"        bnb_4bit_compute_dtype=torch.bfloat16,\n",
"        bnb_4bit_quant_type=\"nf4\"\n",
"    )\n",
"else:\n",
"    quant_config = BitsAndBytesConfig(\n",
"        load_in_8bit=True,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef108f8d",
"metadata": {
"id": "ef108f8d"
},
"outputs": [],
"source": [
"# Load the Tokenizer and the Model\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
"    BASE_MODEL,\n",
"    quantization_config=quant_config,\n",
"    device_map=\"auto\",\n",
")\n",
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"# Load the fine-tuned model with PEFT\n",
"if REVISION:\n",
"    fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
"else:\n",
"    fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
"\n",
"fine_tuned_model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3c4176",
"metadata": {
"id": "7f3c4176"
},
"outputs": [],
"source": [
"def extract_price(s):\n",
"    if \"Price is $\" in s:\n",
"        contents = s.split(\"Price is $\")[1]\n",
"        contents = contents.replace(',','')\n",
"        match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
"        return float(match.group()) if match else 0\n",
"    return 0"
]
},
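{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(Added note)* A one-line sanity check for `extract_price`: the comma should be stripped and the first number after \"Price is $\" returned. Illustrative input, not from the original run.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick check of the price parser (added sketch; the expected value is illustrative)\n",
"extract_price(\"Price is $1,299.50 for this model\")  # -> 1299.5\n"
]
},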
{
"cell_type": "code",
"execution_count": null,
"id": "436fa29a",
"metadata": {
"id": "436fa29a"
},
"outputs": [],
"source": [
"# Original prediction function takes the most likely next token\n",
"\n",
"def model_predict(prompt):\n",
"    set_seed(42)\n",
"    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
"    attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n",
"    outputs = fine_tuned_model.generate(inputs, attention_mask=attention_mask, max_new_tokens=3, num_return_sequences=1)\n",
"    response = tokenizer.decode(outputs[0])\n",
"    return extract_price(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a666dab6",
"metadata": {
"id": "a666dab6"
},
"outputs": [],
"source": [
"def improved_model_predict(prompt, device=\"cuda\", top_k=3):\n",
"    set_seed(42)\n",
"    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
"    attention_mask = torch.ones_like(inputs)\n",
"\n",
"    with torch.no_grad():\n",
"        outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
"        next_token_logits = outputs.logits[:, -1, :].to(\"cpu\")\n",
"\n",
"    probs = F.softmax(next_token_logits, dim=-1)\n",
"    top_prob, top_token_id = probs.topk(top_k)\n",
"\n",
"    prices, weights = [], []\n",
"    for weight, token_id in zip(top_prob[0], top_token_id[0]):\n",
"        token_text = tokenizer.decode(token_id)\n",
"        try:\n",
"            value = float(token_text)\n",
"        except ValueError:\n",
"            continue\n",
"        prices.append(value)\n",
"        weights.append(weight.item())\n",
"\n",
"    if not prices:\n",
"        return 0.0, 0.0, 0.0  # price, confidence, spread\n",
"\n",
"    total = sum(weights)\n",
"    normalized = [w / total for w in weights]\n",
"    weighted_price = sum(p * w for p, w in zip(prices, normalized))\n",
"    variance = sum(w * (p - weighted_price) ** 2 for p, w in zip(prices, normalized))\n",
"    confidence = 1 - min(variance / (weighted_price + 1e-6), 1)\n",
"    spread = max(prices) - min(prices)\n",
"    return weighted_price, confidence, spread"
]
},
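{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(Added note)* Worked example of the weighted-price arithmetic above, with illustrative numbers rather than real model output: if the top-3 numeric tokens are 89, 90 and 91 with probabilities 0.5, 0.3 and 0.2, the weights already sum to 1, so the weighted price is 0.5·89 + 0.3·90 + 0.2·91 = 89.7, the spread is 91 − 89 = 2, and the variance (≈ 0.61) yields a confidence of about 1 − 0.61/89.7 ≈ 0.99. Widely scattered candidate prices drive the variance term up and the confidence toward 0.\n"
]
},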
{
"cell_type": "code",
"source": [
"!pip install -q gradio>=4.0\n",
|
||||||
|
"import gradio as gr\n",
|
||||||
|
"\n",
|
||||||
|
"def format_prediction(description):\n",
|
||||||
|
" price, confidence, spread = improved_model_predict(description)\n",
|
||||||
|
" return (\n",
|
||||||
|
" f\"${price:,.2f}\",\n",
|
||||||
|
" f\"{confidence*100:.1f}%\",\n",
|
||||||
|
" f\"${spread:,.2f}\",\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"demo = gr.Interface(\n",
|
||||||
|
" fn=format_prediction,\n",
|
||||||
|
" inputs=gr.Textbox(lines=10, label=\"Product Description\"),\n",
|
||||||
|
" outputs=[\n",
|
||||||
|
" gr.Textbox(label=\"Estimated Price\"),\n",
|
||||||
|
" gr.Textbox(label=\"Confidence\"),\n",
|
||||||
|
" gr.Textbox(label=\"Token Spread\"),\n",
|
||||||
|
" ],\n",
|
||||||
|
" title=\"Open-Source Product Price Estimator\",\n",
|
||||||
|
" description=\"Paste a cleaned product blurb. The model returns a price estimate, a confidence score based on top-k token dispersion, and the spread between top predictions.\",\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"demo.launch(share=True)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "Km4tHaeQyMoW"
|
||||||
|
},
|
||||||
|
"id": "Km4tHaeQyMoW",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9664c4c7",
|
||||||
|
"metadata": {
|
||||||
|
"id": "9664c4c7"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"class Tester:\n",
|
||||||
|
"\n",
|
||||||
|
" def __init__(self, predictor, data, title=None, show_progress=True):\n",
|
||||||
|
" self.predictor = predictor\n",
|
||||||
|
" self.data = data\n",
|
||||||
|
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
|
||||||
|
" self.size = len(data)\n",
|
||||||
|
" self.guesses, self.truths, self.errors, self.rel_errors, self.sles, self.colors = [], [], [], [], [], []\n",
|
||||||
|
" self.show_progress = show_progress\n",
|
||||||
|
"\n",
|
||||||
|
" def color_for(self, error, truth):\n",
|
||||||
|
" if error < 40 or error / truth < 0.2:\n",
|
||||||
|
" return \"green\"\n",
|
||||||
|
" elif error < 80 or error / truth < 0.4:\n",
|
||||||
|
" return \"orange\"\n",
|
||||||
|
" else:\n",
|
||||||
|
" return \"red\"\n",
|
||||||
|
"\n",
|
||||||
|
" def run_datapoint(self, i):\n",
|
||||||
|
" datapoint = self.data[i]\n",
|
||||||
|
" guess = self.predictor(datapoint[\"text\"])\n",
|
||||||
|
" truth = datapoint[\"price\"]\n",
|
||||||
|
"\n",
|
||||||
|
" error = guess - truth\n",
|
||||||
|
" abs_error = abs(error)\n",
|
||||||
|
" rel_error = abs_error / truth if truth != 0 else 0\n",
|
||||||
|
" log_error = math.log(truth + 1) - math.log(guess + 1)\n",
|
||||||
|
" sle = log_error ** 2\n",
|
||||||
|
" color = self.color_for(abs_error, truth)\n",
|
||||||
|
"\n",
|
||||||
|
" title = (datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\") if \"\\n\\n\" in datapoint[\"text\"] else datapoint[\"text\"][:20]\n",
|
||||||
|
" self.guesses.append(guess)\n",
|
||||||
|
" self.truths.append(truth)\n",
|
||||||
|
" self.errors.append(error)\n",
|
||||||
|
" self.rel_errors.append(rel_error)\n",
|
||||||
|
" self.sles.append(sle)\n",
|
||||||
|
" self.colors.append(color)\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} \"\n",
|
||||||
|
" f\"Error: ${abs_error:,.2f} RelErr: {rel_error*100:.1f}% SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
|
||||||
|
"\n",
|
||||||
|
" def chart_all(self, chart_title):\n",
|
||||||
|
" \"\"\"Compact version: 4 performance charts in one grid.\"\"\"\n",
|
||||||
|
" t, g = np.array(self.truths), np.array(self.guesses)\n",
|
||||||
|
" rel_err, abs_err = np.array(self.rel_errors) * 100, np.abs(np.array(self.errors))\n",
|
||||||
|
"\n",
|
||||||
|
" fig, axs = plt.subplots(2, 2, figsize=(14, 10))\n",
|
||||||
|
" fig.suptitle(f\"Performance Dashboard — {chart_title}\", fontsize=16, fontweight=\"bold\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Scatter plot\n",
|
||||||
|
" max_val = max(t.max(), g.max()) * 1.05\n",
|
||||||
|
" axs[1, 1].plot([0, max_val], [0, max_val], \"b--\", alpha=0.6)\n",
|
||||||
|
" axs[1, 1].scatter(t, g, s=20, c=self.colors, alpha=0.6)\n",
|
||||||
|
" axs[1, 1].set_title(\"Predictions vs Ground Truth\")\n",
|
||||||
|
" axs[1, 1].set_xlabel(\"True Price ($)\")\n",
|
||||||
|
" axs[1, 1].set_ylabel(\"Predicted ($)\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Accuracy by price range\n",
|
||||||
|
" bins = np.linspace(t.min(), t.max(), 6)\n",
|
||||||
|
" labels = [f\"${bins[i]:.0f}–${bins[i+1]:.0f}\" for i in range(len(bins)-1)]\n",
|
||||||
|
" inds = np.digitize(t, bins) - 1\n",
|
||||||
|
" avg_err = [rel_err[inds == i].mean() for i in range(len(labels))]\n",
|
||||||
|
" axs[0, 0].bar(labels, avg_err, color=\"seagreen\", alpha=0.8)\n",
|
||||||
|
" axs[0, 0].set_title(\"Avg Relative Error by Price Range\")\n",
|
||||||
|
" axs[0, 0].set_ylabel(\"Relative Error (%)\")\n",
|
||||||
|
" axs[0, 0].tick_params(axis=\"x\", rotation=30)\n",
|
||||||
|
"\n",
|
||||||
|
" # Relative error distribution\n",
|
||||||
|
" axs[0, 1].hist(rel_err, bins=25, color=\"mediumpurple\", edgecolor=\"black\", alpha=0.7)\n",
|
||||||
|
" axs[0, 1].set_title(\"Relative Error Distribution (%)\")\n",
|
||||||
|
" axs[0, 1].set_xlabel(\"Relative Error (%)\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Absolute error distribution\n",
|
||||||
|
" axs[1, 0].hist(abs_err, bins=25, color=\"steelblue\", edgecolor=\"black\", alpha=0.7)\n",
|
||||||
|
" axs[1, 0].axvline(abs_err.mean(), color=\"red\", linestyle=\"--\", label=f\"Mean={abs_err.mean():.2f}\")\n",
|
||||||
|
" axs[1, 0].set_title(\"Absolute Error Distribution\")\n",
|
||||||
|
" axs[1, 0].set_xlabel(\"Absolute Error ($)\")\n",
|
||||||
|
" axs[1, 0].legend()\n",
|
||||||
|
"\n",
|
||||||
|
" for ax in axs.ravel():\n",
|
||||||
|
" ax.grid(alpha=0.3)\n",
|
||||||
|
"\n",
|
||||||
|
" plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
|
||||||
|
" plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
" def report(self):\n",
|
||||||
|
" y_true = np.array(self.truths)\n",
|
||||||
|
" y_pred = np.array(self.guesses)\n",
|
||||||
|
"\n",
|
||||||
|
" mae = mean_absolute_error(y_true, y_pred)\n",
|
||||||
|
" rmse = math.sqrt(mean_squared_error(y_true, y_pred))\n",
|
||||||
|
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
|
||||||
|
" mape = mean_absolute_percentage_error(y_true, y_pred) * 100\n",
|
||||||
|
" median_error = float(np.median(np.abs(y_true - y_pred)))\n",
|
||||||
|
" r2 = r2_score(y_true, y_pred)\n",
|
||||||
|
"\n",
|
||||||
|
" hit_rate_green = sum(1 for c in self.colors if c == \"green\") / self.size * 100\n",
|
||||||
|
" hit_rate_acceptable = sum(1 for c in self.colors if c in (\"green\", \"orange\")) / self.size * 100\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"\\n{'='*70}\")\n",
|
||||||
|
" print(f\"FINAL REPORT: {self.title}\")\n",
|
||||||
|
" print(f\"{'='*70}\")\n",
|
||||||
|
" print(f\"Total Predictions: {self.size}\")\n",
|
||||||
|
" print(f\"\\n--- Error Metrics ---\")\n",
|
||||||
|
" print(f\"Mean Absolute Error (MAE): ${mae:,.2f}\")\n",
|
||||||
|
" print(f\"Median Error: ${median_error:,.2f}\")\n",
|
||||||
|
" print(f\"Root Mean Squared Error (RMSE): ${rmse:,.2f}\")\n",
|
||||||
|
" print(f\"Root Mean Squared Log Error (RMSLE): {rmsle:.4f}\")\n",
|
||||||
|
" print(f\"Mean Absolute Percentage Error (MAPE): {mape:.2f}%\")\n",
|
||||||
|
" print(f\"\\n--- Accuracy Metrics ---\")\n",
|
||||||
|
" print(f\"R² Score: {r2:.4f}\")\n",
|
||||||
|
" print(f\"Hit Rate (Green): {hit_rate_green:.1f}%\")\n",
|
||||||
|
" print(f\"Hit Rate (Green+Orange): {hit_rate_acceptable:.1f}%\")\n",
|
||||||
|
" print(f\"{'='*70}\\n\")\n",
|
||||||
|
" chart_title = f\"{self.title} | MAE=${mae:,.2f} | RMSLE={rmsle:.3f} | R²={r2:.3f}\"\n",
|
||||||
|
"\n",
|
||||||
|
" self.chart_all(chart_title)\n",
|
||||||
|
"\n",
|
||||||
|
" def run(self):\n",
|
||||||
|
" iterator = tqdm(range(self.size), desc=\"Testing Model\") if self.show_progress else range(self.size)\n",
|
||||||
|
" for i in iterator:\n",
|
||||||
|
" self.run_datapoint(i)\n",
|
||||||
|
" self.report()\n",
|
||||||
|
"\n",
|
||||||
|
" @classmethod\n",
|
||||||
|
" def test(cls, function, data, title=None):\n",
|
||||||
|
" cls(function, data, title=title).run()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2e60a696",
|
||||||
|
"metadata": {
|
||||||
|
"id": "2e60a696"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"Tester.test(\n",
|
||||||
|
" improved_model_predict,\n",
|
||||||
|
" test,\n",
|
||||||
|
" title=\"Home appliances prediction\"\n",
|
||||||
|
")"
|
||||||
|
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}
466
week7/community_contributions/hopeogbons/week7_EXERCISE.ipynb
Normal file
@@ -0,0 +1,466 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "GHsssBgWM_l0"
},
"source": [
"# Fine-Tuned Product Price Predictor\n",
"\n",
"Evaluate fine-tuned Llama 3.1 8B model for product price estimation"
|
||||||
|
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDyR63OTNUJ6"
},
"outputs": [],
"source": [
"# Install required libraries for model inference\n",
"%pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"%pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-yikV8pRBer9"
},
"outputs": [],
"source": [
"# Import required libraries\n",
"import os\n",
"import re\n",
"import math\n",
"from tqdm import tqdm\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from datetime import datetime\n",
"from peft import PeftModel\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uuTX-xonNeOK"
},
"outputs": [],
"source": [
"# Configuration\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"PROJECT_NAME = \"pricer\"\n",
"HF_USER = \"ed-donner\"  # Change to your HF username\n",
"RUN_NAME = \"2024-09-13_13.04.39\"\n",
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
"REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
"FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"\n",
"# Quantization setting (False = 8-bit = better accuracy, more memory)\n",
"QUANT_4_BIT = False  # Changed to 8-bit for better accuracy\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Color codes for output\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8JArT3QAQAjx"
},
"source": [
"# Step 1\n",
"\n",
"### Load dataset and fine-tuned model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WyFPZeMcM88v"
},
"outputs": [],
"source": [
"# Login to HuggingFace\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvXVoJH8LS6u"
},
"outputs": [],
"source": [
"# Load product pricing dataset\n",
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']\n",
"\n",
"print(f\"✓ Loaded {len(train)} train and {len(test)} test samples\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xb86e__Wc7j_"
},
"outputs": [],
"source": [
"# Verify data structure\n",
"test[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qJWQ0a3wZ0Bw"
},
"source": [
"### Load Tokenizer and Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lAUAAcEC6ido"
},
"outputs": [],
"source": [
"# Configure quantization for memory efficiency\n",
"if QUANT_4_BIT:\n",
"    quant_config = BitsAndBytesConfig(\n",
"        load_in_4bit=True,\n",
"        bnb_4bit_use_double_quant=True,\n",
"        bnb_4bit_compute_dtype=torch.bfloat16,\n",
"        bnb_4bit_quant_type=\"nf4\"\n",
"    )\n",
"else:\n",
"    quant_config = BitsAndBytesConfig(\n",
" load_in_8bit=True,\n",
|
||||||
|
" bnb_8bit_compute_dtype=torch.bfloat16\n",
|
||||||
|
" )"
|
||||||
|
]
},
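{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(Added note)* Back-of-envelope memory arithmetic for the flag above: an 8B-parameter model weighs roughly 8 GB in 8-bit and roughly 4-6 GB in 4-bit NF4 with double quantization (plus activations and the LoRA adapter), which is why 4-bit is the usual choice on a 16 GB T4 and 8-bit is reserved for accuracy comparisons. Figures are approximate.\n"
]
},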
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "R_O04fKxMMT-"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load tokenizer\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
|
||||||
|
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||||
|
"tokenizer.padding_side = \"right\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Load base model with quantization\n",
|
||||||
|
"base_model = AutoModelForCausalLM.from_pretrained(\n",
|
||||||
|
" BASE_MODEL,\n",
|
||||||
|
" quantization_config=quant_config,\n",
|
||||||
|
" device_map=\"auto\",\n",
|
||||||
|
")\n",
|
||||||
|
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
|
||||||
|
"\n",
|
||||||
|
"# Load fine-tuned weights\n",
|
||||||
|
"if REVISION:\n",
|
||||||
|
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
|
||||||
|
"else:\n",
|
||||||
|
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"✓ Model loaded - Memory: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
|
||||||
|
]
|
||||||
|
},
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kD-GJtbrdd5t"
   },
   "outputs": [],
   "source": [
    "# Verify model loaded\n",
    "fine_tuned_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UObo1-RqaNnT"
   },
   "source": [
    "# Step 2\n",
    "\n",
    "### Model inference and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qst1LhBVAB04"
   },
   "outputs": [],
   "source": [
    "# Extract price from model response\n",
    "def extract_price(s):\n",
    "    if \"Price is $\" in s:\n",
    "        contents = s.split(\"Price is $\")[1]\n",
    "        contents = contents.replace(',','')\n",
    "        match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
    "        return float(match.group()) if match else 0\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jXFBW_5UeEcp"
   },
   "outputs": [],
   "source": [
    "# Test extract_price function\n",
    "extract_price(\"Price is $a fabulous 899.99 or so\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Oj_PzpdFAIMk"
   },
   "outputs": [],
   "source": [
    "# Simple prediction: takes most likely next token\n",
    "def model_predict(prompt):\n",
    "    set_seed(42)\n",
    "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "    attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n",
    "    outputs = fine_tuned_model.generate(\n",
    "        inputs,\n",
    "        attention_mask=attention_mask,\n",
    "        max_new_tokens=5,   # Increased for flexibility\n",
    "        temperature=0.1,    # Low temperature for consistency (only takes effect if do_sample=True)\n",
    "        num_return_sequences=1\n",
    "    )\n",
    "    response = tokenizer.decode(outputs[0])\n",
    "    return extract_price(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Je5dR8QEAI1d"
   },
   "outputs": [],
   "source": [
    "# Improved prediction: weighted average of top K predictions\n",
    "top_K = 5  # Increased from 3 to 5 for better accuracy\n",
    "\n",
    "def improved_model_predict(prompt, device=\"cuda\"):\n",
    "    set_seed(42)\n",
    "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
    "    attention_mask = torch.ones(inputs.shape, device=device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
    "        next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
    "\n",
    "    next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
    "    top_prob, top_token_id = next_token_probs.topk(top_K)\n",
    "    prices, weights = [], []\n",
    "    for i in range(top_K):\n",
    "        predicted_token = tokenizer.decode(top_token_id[0][i])\n",
    "        probability = top_prob[0][i]\n",
    "        try:\n",
    "            result = float(predicted_token)\n",
    "        except ValueError:\n",
    "            result = 0.0\n",
    "        if result > 0:\n",
    "            prices.append(result)\n",
    "            weights.append(probability)\n",
    "    if not prices:\n",
    "        return 0.0\n",
    "    total = sum(weights)\n",
    "    weighted_prices = [price * weight / total for price, weight in zip(prices, weights)]\n",
    "    return sum(weighted_prices).item()"
   ]
  },
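  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the weighted-average trick in isolation, here is a self-contained toy version with made-up top-K tokens and probabilities (no model or GPU required):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the weighted average above, using hypothetical\n",
    "# top-K tokens and probabilities instead of real logits.\n",
    "toy_top_k = [(\"99\", 0.50), (\"89\", 0.25), (\".\", 0.15), (\"79\", 0.10)]\n",
    "prices, weights = [], []\n",
    "for token, prob in toy_top_k:\n",
    "    try:\n",
    "        value = float(token)\n",
    "    except ValueError:\n",
    "        continue  # non-numeric tokens are skipped, as in improved_model_predict\n",
    "    if value > 0:\n",
    "        prices.append(value)\n",
    "        weights.append(prob)\n",
    "total = sum(weights)  # renormalize over the numeric tokens only\n",
    "estimate = sum(p * w / total for p, w in zip(prices, weights))\n",
    "print(f\"Weighted estimate: ${estimate:.2f}\")  # (99*0.50 + 89*0.25 + 79*0.10) / 0.85"
   ]
  },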
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EpGVJPuC1iho"
   },
   "source": [
    "# Step 3\n",
    "\n",
    "### Test and evaluate model performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "30lzJXBH7BcK"
   },
   "outputs": [],
   "source": [
    "# Evaluation framework\n",
    "class Tester:\n",
    "    def __init__(self, predictor, data, title=None, size=250):\n",
    "        self.predictor = predictor\n",
    "        self.data = data\n",
    "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
    "        self.size = size\n",
    "        self.guesses = []\n",
    "        self.truths = []\n",
    "        self.errors = []\n",
    "        self.sles = []\n",
    "        self.colors = []\n",
    "\n",
    "    def color_for(self, error, truth):\n",
    "        if error < 40 or error / truth < 0.2:\n",
    "            return \"green\"\n",
    "        elif error < 80 or error / truth < 0.4:\n",
    "            return \"orange\"\n",
    "        else:\n",
    "            return \"red\"\n",
    "\n",
    "    def run_datapoint(self, i):\n",
    "        datapoint = self.data[i]\n",
    "        guess = self.predictor(datapoint[\"text\"])\n",
    "        truth = datapoint[\"price\"]\n",
    "        error = abs(guess - truth)\n",
    "        log_error = math.log(truth + 1) - math.log(guess + 1)\n",
    "        sle = log_error ** 2\n",
    "        color = self.color_for(error, truth)\n",
    "        title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
    "        self.guesses.append(guess)\n",
    "        self.truths.append(truth)\n",
    "        self.errors.append(error)\n",
    "        self.sles.append(sle)\n",
    "        self.colors.append(color)\n",
    "        print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
    "\n",
    "    def chart(self, title):\n",
    "        plt.figure(figsize=(12, 8))\n",
    "        max_val = max(max(self.truths), max(self.guesses))\n",
    "        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
    "        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
    "        plt.xlabel('Ground Truth')\n",
    "        plt.ylabel('Model Estimate')\n",
    "        plt.xlim(0, max_val)\n",
    "        plt.ylim(0, max_val)\n",
    "        plt.title(title)\n",
    "        plt.show()\n",
    "\n",
    "    def report(self):\n",
    "        average_error = sum(self.errors) / self.size\n",
    "        rmsle = math.sqrt(sum(self.sles) / self.size)\n",
    "        hits = sum(1 for color in self.colors if color == \"green\")\n",
    "        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
    "        self.chart(title)\n",
    "\n",
    "    def run(self):\n",
    "        for i in range(self.size):\n",
    "            self.run_datapoint(i)\n",
    "        self.report()\n",
    "\n",
    "    @classmethod\n",
    "    def test(cls, function, data):\n",
    "        cls(function, data).run()"
   ]
  },
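  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, each datapoint's `sle` above is a squared log error, and `report` aggregates them as\n",
    "\n",
    "$$\\mathrm{RMSLE} = \\sqrt{\\frac{1}{N}\\sum_{i=1}^{N}\\bigl(\\log(y_i + 1) - \\log(\\hat{y}_i + 1)\\bigr)^2}$$\n",
    "\n",
    "where $y_i$ is the true price and $\\hat{y}_i$ the guess, so relative errors are penalized evenly for cheap and expensive items."
   ]
  },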
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "W_KcLvyt6kbb"
   },
   "outputs": [],
   "source": [
    "# Run evaluation on 250 test examples\n",
    "Tester.test(improved_model_predict, test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nVwiWGVN1ihp"
   },
   "source": [
    "### Performance Optimizations Applied\n",
    "\n",
    "**Changes for better accuracy:**\n",
    "\n",
    "- ✅ 8-bit quantization (vs 4-bit) - Better precision\n",
    "- ✅ top_K = 5 (vs 3) - More predictions in weighted average\n",
    "- ✅ max_new_tokens = 5 - More flexibility in response\n",
    "- ✅ temperature = 0.1 - More consistent predictions\n",
    "\n",
    "**Expected improvement:** ~10-15% reduction in average error\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hO4DdLa81ihp"
   },
   "source": [
    "### Expected Performance\n",
    "\n",
    "**Baseline comparisons:**\n",
    "\n",
    "- GPT-4o: $76 avg error\n",
    "- Llama 3.1 base: $396 avg error\n",
    "- Human: $127 avg error\n",
    "\n",
    "**Fine-tuned model (optimized):**\n",
    "\n",
    "- Target: $70-85 avg error\n",
    "- With 8-bit quant + top_K=5 + temp=0.1\n",
    "- Expected to rival or beat GPT-4o\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
453
week8/community_contributions/emmy/llm_battle.py
Normal file
@@ -0,0 +1,453 @@
from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from typing import Dict, Generator, List, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()


# ---------------------------------------------------------------------------
# Configuration helpers
# ---------------------------------------------------------------------------


@dataclass
class AgentConfig:
    """Holds configuration required to talk to an LLM provider."""

    name: str
    model: str
    api_key_env: str
    base_url_env: Optional[str] = None
    temperature: float = 0.7
    supports_json: bool = True


def load_client(config: AgentConfig) -> OpenAI:
    """Create an OpenAI-compatible client for the given agent."""
    api_key = os.getenv(config.api_key_env) or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            f"Missing API key for {config.name}. "
            f"Set {config.api_key_env} or OPENAI_API_KEY."
        )

    base_url = (
        os.getenv(config.base_url_env)
        if config.base_url_env
        else os.getenv("OPENAI_BASE_URL")
    )

    return OpenAI(api_key=api_key, base_url=base_url)


def extract_text(response) -> str:
    """Extract text content from an OpenAI-style response object or dict."""

    choices = getattr(response, "choices", None)
    if choices is None and isinstance(response, dict):
        choices = response.get("choices")
    if not choices:
        raise RuntimeError(f"LLM response missing choices field: {response!r}")

    choice = choices[0]
    message = getattr(choice, "message", None)
    if message is None and isinstance(choice, dict):
        message = choice.get("message")

    content = None
    if message is not None:
        content = getattr(message, "content", None)
        if content is None and isinstance(message, dict):
            content = message.get("content")

    if isinstance(content, list):
        parts: List[str] = []
        for part in content:
            if isinstance(part, dict):
                if "text" in part:
                    parts.append(str(part["text"]))
                elif "output_text" in part:
                    parts.append(str(part["output_text"]))
                elif "type" in part and "content" in part:
                    parts.append(str(part["content"]))
            else:
                parts.append(str(part))
        content = "".join(parts)

    if content is None:
        text = getattr(choice, "text", None)
        if text is None and isinstance(choice, dict):
            text = choice.get("text")
        if text:
            content = text

    if content is None:
        raise RuntimeError(f"LLM response missing content/text: {response!r}")

    return str(content).strip()
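
# The helper above tolerates several response shapes; a few illustrative
# calls (hypothetical payloads, not from any specific provider):
#   extract_text({"choices": [{"message": {"content": "Hello"}}]})           -> "Hello"
#   extract_text({"choices": [{"message": {"content": [{"text": "Hi"}]}}]})  -> "Hi"
#   extract_text({"choices": [{"text": "legacy completion"}]})               -> "legacy completion"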


# Default configuration leverages OpenAI unless overrides are provided.
DEBATER_A_CONFIG = AgentConfig(
    name="Debater A",
    model=os.getenv("DEBATER_A_MODEL", "gpt-4o"),
    api_key_env="OPENAI_API_KEY",
    base_url_env="OPENAI_BASE_URL",
    temperature=float(os.getenv("DEBATER_A_TEMPERATURE", 0.7)),
)

DEBATER_B_CONFIG = AgentConfig(
    name="Debater B",
    model=os.getenv("DEBATER_B_MODEL", "gemini-2.0-flash"),
    api_key_env="GOOGLE_API_KEY",
    base_url_env="GEMINI_BASE_URL",
    temperature=float(os.getenv("DEBATER_B_TEMPERATURE", 0.7)),
)

JUDGE_CONFIG = AgentConfig(
    name="Judge",
    model=os.getenv("JUDGE_MODEL", "gpt-oss:20b-cloud"),
    api_key_env="OLLAMA_API_KEY",
    base_url_env="OLLAMA_BASE_URL",
    temperature=float(os.getenv("JUDGE_TEMPERATURE", 0.2)),
    supports_json=False,
)

REPORTER_CONFIG = AgentConfig(
    name="Reporter",
    model=os.getenv("REPORTER_MODEL", "MiniMax-M2"),
    api_key_env="MINIMAX_API_KEY",
    base_url_env="MINIMAX_BASE_URL",
    temperature=float(os.getenv("REPORTER_TEMPERATURE", 0.4)),
    supports_json=False,
)

THEME = gr.themes.Default(
    primary_hue="blue",
    secondary_hue="sky",
    neutral_hue="gray",
)

CUSTOM_CSS = """
body, .gradio-container {
    background: radial-gradient(circle at top, #0f172a 0%, #020617 60%, #020617 100%);
    color: #e2e8f0;
}
#live-debate-panel {
    background: linear-gradient(135deg, rgba(30,64,175,0.95), rgba(29,78,216,0.85));
    color: #f8fafc;
    border-radius: 16px;
    padding: 24px;
    box-shadow: 0 20px 45px rgba(15,23,42,0.35);
}
#live-debate-panel h3 {
    color: #bfdbfe;
}
.gr-button-primary {
    background: linear-gradient(135deg, #1d4ed8, #2563eb) !important;
    border: none !important;
}
.gr-button-primary:hover {
    background: linear-gradient(135deg, #2563eb, #1d4ed8) !important;
}
"""

# ---------------------------------------------------------------------------
# Debate runtime classes
# ---------------------------------------------------------------------------


@dataclass
class DebateState:
    topic: str
    stance_a: str
    stance_b: str
    transcript: List[Tuple[str, str]] = field(default_factory=list)


class LLMAdapter:
    """Thin wrapper around the OpenAI SDK to simplify prompting."""

    def __init__(self, config: AgentConfig):
        self.config = config
        self.client = load_client(config)

    def complete(
        self,
        prompt: str,
        *,
        system: Optional[str] = None,
        max_tokens: int = 512,
        json_mode: bool = False,
    ) -> str:
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        params = dict(
            model=self.config.model,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=max_tokens,
        )
        if json_mode and self.config.supports_json:
            params["response_format"] = {"type": "json_object"}

        response = self.client.chat.completions.create(**params)
        return extract_text(response)
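
# Illustrative use (assumes the relevant API key env var is set):
#   adapter = LLMAdapter(DEBATER_A_CONFIG)
#   adapter.complete("Reply with one word: ready?", system="Be terse.")
#   adapter.complete('Return {"ok": true}', json_mode=True)  # JSON mode when supported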


class Debater:
    def __init__(self, adapter: LLMAdapter, stance_label: str):
        self.adapter = adapter
        self.stance_label = stance_label

    def argue(self, topic: str) -> str:
        prompt = (
            f"You are {self.adapter.config.name}, debating the topic:\n"
            f"'{topic}'.\n\n"
            f"Present a concise argument that {self.stance_label.lower()} "
            f"the statement. Use at most 150 words. Provide clear reasoning "
            f"and, if applicable, cite plausible evidence or examples."
        )
        return self.adapter.complete(prompt, max_tokens=300)


class Judge:
    RUBRIC = [
        "Clarity of the argument",
        "Use of evidence or examples",
        "Logical coherence",
        "Persuasiveness and impact",
    ]

    def __init__(self, adapter: LLMAdapter):
        self.adapter = adapter

    def evaluate(self, topic: str, argument_a: str, argument_b: str) -> Dict[str, object]:
        rubric_text = "\n".join(f"- {item}" for item in self.RUBRIC)
        prompt = (
            "You are serving as an impartial debate judge.\n"
            f"Topic: {topic}\n\n"
            f"Argument from Debater A:\n{argument_a}\n\n"
            f"Argument from Debater B:\n{argument_b}\n\n"
            "Score each debater from 0-10 on the following criteria:\n"
            f"{rubric_text}\n\n"
            "Return a JSON object with this exact structure:\n"
            "{\n"
            '  "winner": "A" or "B" or "Tie",\n'
            '  "reason": "brief justification",\n'
            '  "scores": [\n'
            '    {"criterion": "...", "debater_a": 0-10, "debater_b": 0-10, "notes": "optional"}\n'
            "  ]\n"
            "}\n"
            "Ensure the JSON is valid."
        )
        raw = self.adapter.complete(prompt, max_tokens=400, json_mode=True)
        try:
            data = json.loads(raw)
            if "scores" not in data:
                raise ValueError("scores missing")
            return data
        except Exception:
            # Fallback: wrap raw text if parsing fails.
            return {"winner": "Unknown", "reason": raw, "scores": []}


class Reporter:
    def __init__(self, adapter: LLMAdapter):
        self.adapter = adapter

    def summarize(
        self,
        topic: str,
        argument_a: str,
        argument_b: str,
        judge_result: Dict[str, object],
    ) -> str:
        prompt = (
            f"Summarize a single-round debate on '{topic}'.\n\n"
            f"Debater A argued:\n{argument_a}\n\n"
            f"Debater B argued:\n{argument_b}\n\n"
            f"Judge verdict: {json.dumps(judge_result, ensure_ascii=False)}\n\n"
            "Provide a short journalistic summary (max 200 words) highlighting "
            "each side's key points and the judge's decision. Use neutral tone."
        )
        response = self.adapter.client.chat.completions.create(
            model=self.adapter.config.model,
            messages=[
                {"role": "system", "content": "You are an impartial debate reporter."},
                {"role": "user", "content": prompt},
            ],
            temperature=self.adapter.config.temperature,
            max_tokens=300,
            **(
                {"extra_body": {"reasoning_split": True}}
                if getattr(self.adapter.client, "base_url", None)
                and "minimax" in str(self.adapter.client.base_url).lower()
                else {}
            ),
        )
        return extract_text(response)


# ---------------------------------------------------------------------------
# Debate pipeline + UI
# ---------------------------------------------------------------------------


debater_a = Debater(LLMAdapter(DEBATER_A_CONFIG), stance_label="supports")
debater_b = Debater(LLMAdapter(DEBATER_B_CONFIG), stance_label="opposes")
judge = Judge(LLMAdapter(JUDGE_CONFIG))
reporter = Reporter(LLMAdapter(REPORTER_CONFIG))


def format_transcript(transcript: List[Tuple[str, str]]) -> str:
    """Return markdown-formatted transcript."""
    lines = []
    for speaker, message in transcript:
        lines.append(f"### {speaker}\n\n{message}\n")
    return "\n".join(lines)


def run_debate(
    topic: str, stance_a: str, stance_b: str
) -> Generator[Tuple[str, str, List[List[object]], str, str], None, None]:
    """Generator for Gradio to stream debate progress."""
    if not topic.strip():
        warning = "⚠️ Please enter a debate topic to get started."
        yield warning, "", [], "", ""
        return

    state = DebateState(topic=topic.strip(), stance_a=stance_a, stance_b=stance_b)

    state.transcript.append(
        ("Moderator", f"Welcome to the debate on **{state.topic}**!")
    )
    yield format_transcript(state.transcript), "Waiting for judge...", [], "", ""

    argument_a = debater_a.argue(state.topic)
    state.transcript.append((f"Debater A ({state.stance_a})", argument_a))
    yield format_transcript(state.transcript), "Collecting arguments...", [], "", ""

    argument_b = debater_b.argue(state.topic)
    state.transcript.append((f"Debater B ({state.stance_b})", argument_b))
    yield format_transcript(state.transcript), "Judge deliberating...", [], "", ""

    judge_result = judge.evaluate(state.topic, argument_a, argument_b)
    verdict_text = (
        f"Winner: {judge_result.get('winner', 'Unknown')}\nReason: "
        f"{judge_result.get('reason', 'No explanation provided.')}"
    )
    score_rows = [
        [
            entry.get("criterion", ""),
            entry.get("debater_a", ""),
            entry.get("debater_b", ""),
            entry.get("notes", ""),
        ]
        for entry in judge_result.get("scores", [])
    ]
    judge_report_md = (
        f"**Judge Verdict:** {judge_result.get('winner', 'Unknown')}\n\n"
        f"{judge_result.get('reason', '')}"
    )
    yield (
        format_transcript(state.transcript),
        judge_report_md,
        score_rows,
        verdict_text,
        format_transcript(state.transcript),
    )

    reporter_summary = reporter.summarize(
        state.topic, argument_a, argument_b, judge_result
    )

    final_markdown = (
        f"{judge_report_md}\n\n---\n\n"
        f"**Reporter Summary**\n\n{reporter_summary}"
    )
    yield (
        format_transcript(state.transcript),
        final_markdown,
        score_rows,
        verdict_text,
        format_transcript(state.transcript),
    )
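
# Manual smoke test (bypasses Gradio; assumes the API keys above are set):
#   for transcript, report, rows, verdict, copy_text in run_debate(
#       "Remote work should be the default", "Supports the statement", "Opposes the statement"
#   ):
#       print(report)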


# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------


with gr.Blocks(
    title="LLM Debate Arena",
    fill_width=True,
    theme=THEME,
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown(
        "# 🔁 LLM Debate Arena\n"
        "Configure two debating agents, watch their arguments in real time, and "
        "review the judge's verdict plus a reporter summary."
    )

    with gr.Row():
        topic_input = gr.Textbox(
            label="Debate Topic",
            placeholder="e.g., Should autonomous delivery robots be allowed in city centers?",
        )
    with gr.Row():
        stance_a_input = gr.Textbox(
            label="Debater A Stance",
            value="Supports the statement",
        )
        stance_b_input = gr.Textbox(
            label="Debater B Stance",
            value="Opposes the statement",
        )

    run_button = gr.Button("Start Debate", variant="primary")

    with gr.Tab("Live Debate"):
        transcript_md = gr.Markdown(
            "### Waiting for the debate to start...",
            elem_id="live-debate-panel",
        )

    with gr.Tab("Judge's Report"):
        judge_md = gr.Markdown("Judge verdict will appear here.")
        score_table = gr.Dataframe(
            headers=["Criterion", "Debater A", "Debater B", "Notes"],
            datatype=["str", "number", "number", "str"],
            interactive=False,
        )
        verdict_box = gr.Textbox(
            label="Verdict Detail",
            interactive=False,
        )
        transcript_box = gr.Textbox(
            label="Full Transcript (for copying)",
            interactive=False,
            lines=10,
        )

    run_button.click(
        fn=run_debate,
        inputs=[topic_input, stance_a_input, stance_b_input],
        outputs=[transcript_md, judge_md, score_table, verdict_box, transcript_box],
        queue=True,
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=4).launch()