diff --git a/week4/community-contributions/philip/week4_EXERCISE.ipynb b/week4/community-contributions/philip/week4_EXERCISE.ipynb new file mode 100644 index 0000000..c6f0e82 --- /dev/null +++ b/week4/community-contributions/philip/week4_EXERCISE.ipynb @@ -0,0 +1,725 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import re\n", + "from typing import List, Dict, Optional\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "from IPython.display import Markdown, display\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Groq API Key not set (optional)\n", + "OpenRouter API Key loaded (begins with sk-or-)\n" + ] + } + ], + "source": [ + "load_dotenv(override=True)\n", + "\n", + "# Ollama connection \n", + "ollama_url = \"http://localhost:11434/v1\"\n", + "ollama_client = OpenAI(api_key=\"ollama\", base_url=ollama_url)\n", + "\n", + "# Groq connection\n", + "groq_api_key = os.getenv('GROQ_API_KEY')\n", + "groq_url = \"https://api.groq.com/openai/v1\"\n", + "groq_client = None\n", + "if groq_api_key:\n", + " groq_client = OpenAI(api_key=groq_api_key, base_url=groq_url)\n", + " print(f\"Groq API Key loaded (begins with {groq_api_key[:4]})\")\n", + "else:\n", + " print(\"Groq API Key not set (optional)\")\n", + "\n", + "# OpenRouter connection\n", + "openrouter_api_key = os.getenv('OPENROUTER_API_KEY')\n", + "openrouter_url = \"https://openrouter.ai/api/v1\"\n", + "openrouter_client = None\n", + "if openrouter_api_key:\n", + " openrouter_client = OpenAI(api_key=openrouter_api_key, base_url=openrouter_url)\n", + " print(f\"OpenRouter API Key loaded (begins with {openrouter_api_key[:6]})\")\n", + "else:\n", + " print(\"OpenRouter API Key not set (optional)\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Configured 2 models\n", + "OpenRouter models available (perfect for limited storage demos!)\n" + ] + } + ], + "source": [ + "# Open-source code models configuration\n", + "MODELS = {}\n", + "\n", + "if groq_client:\n", + " MODELS.update({\n", + " \"gpt-oss-20b-groq\": {\n", + " \"name\": \"GPT-OSS-20B (Groq)\",\n", + " \"client\": groq_client,\n", + " \"model\": \"gpt-oss:20b\",\n", + " \"description\": \"Cloud\"\n", + " },\n", + " \"gpt-oss-120b-groq\": {\n", + " \"name\": \"GPT-OSS-120B (Groq)\",\n", + " \"client\": groq_client,\n", + " \"model\": \"openai/gpt-oss-120b\",\n", + " \"description\": \"Cloud - Larger GPT-OSS\"\n", + " },\n", + " \"qwen2.5-coder-32b-groq\": {\n", + " \"name\": \"Qwen2.5-Coder 32B (Groq)\",\n", + " \"client\": groq_client,\n", + " \"model\": \"qwen/qwen2.5-coder-32b-instruct\",\n", + " \"description\": \"Cloud\"\n", + " },\n", + " })\n", + "\n", + "# OpenRouter models\n", + "if openrouter_client:\n", + " MODELS.update({\n", + " \"qwen-2.5-coder-32b-openrouter\": {\n", + " \"name\": \"Qwen2.5-Coder 32B (OpenRouter)\",\n", + " \"client\": openrouter_client,\n", + " \"model\": \"qwen/qwen-2.5-coder-32b-instruct\",\n", + " \"description\": \"Cloud - Perfect for demos, 50 req/day free\"\n", + " },\n", + " \"gpt-oss-20b-groq\": {\n", + " \"name\": \"GPT-OSS-20B\",\n", + " \"client\": openrouter_client,\n", + " \"model\": \"openai/gpt-oss-20b\",\n", + " \"description\": \"Cloud - OpenAI's 
open model, excellent for code!\"\n", + " },\n", + " })\n", + "\n", + "print(f\"Configured {len(MODELS)} models\")\n", + "if openrouter_client:\n", + " print(\"OpenRouter models available (perfect for limited storage demos!)\")\n", + "if groq_client:\n", + " print(\"Groq models available (fast cloud inference!)\")\n", + "if \"qwen2.5-coder:7b\" in MODELS:\n", + " print(\"Ollama models available (unlimited local usage!)\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "BUG_DETECTION_SYSTEM_PROMPT = \"\"\"You are an expert code reviewer specializing in finding bugs, security vulnerabilities, and logic errors.\n", + "\n", + "Your task is to analyze Python code and identify issues. Return ONLY a valid JSON array with this exact format:\n", + "[{\n", + " \"severity\": \"critical|high|medium|low\",\n", + " \"line\": number,\n", + " \"issue\": \"brief description of the problem\",\n", + " \"suggestion\": \"specific fix recommendation\"\n", + "}]\n", + "\n", + "Be thorough but concise. Focus on real bugs and security issues.\"\"\"\n", + "\n", + "IMPROVEMENTS_SYSTEM_PROMPT = \"\"\"You are a senior software engineer specializing in code quality and best practices.\n", + "\n", + "Analyze the Python code and suggest improvements for:\n", + "- Code readability and maintainability\n", + "- Performance optimizations\n", + "- Pythonic idioms and conventions\n", + "- Better error handling\n", + "\n", + "Return ONLY a JSON array:\n", + "[{\n", + " \"category\": \"readability|performance|style|error_handling\",\n", + " \"line\": number,\n", + " \"current\": \"current code snippet\",\n", + " \"improved\": \"improved code snippet\",\n", + " \"explanation\": \"why this is better\"\n", + "}]\n", + "\n", + "Only suggest meaningful improvements.\"\"\"\n", + "\n", + "TEST_GENERATION_SYSTEM_PROMPT = \"\"\"You are an expert in writing comprehensive unit tests.\n", + "\n", + "Generate pytest unit tests for the given Python code. Include:\n", + "- Test cases for normal operation\n", + "- Edge cases and boundary conditions\n", + "- Error handling tests\n", + "- Tests for any bugs that were identified\n", + "\n", + "Return ONLY Python code with pytest tests. 
Include the original code at the top if needed.\n", + "Put the imports at the top of the file first.\n", + "Do not include explanations or markdown formatting.\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_json_from_response(text: str) -> List[Dict]:\n", + " \"\"\"Extract JSON array from model response, handling markdown code blocks.\"\"\"\n", + " # Remove markdown code blocks\n", + " text = re.sub(r'```json\\n?', '', text)\n", + " text = re.sub(r'```\\n?', '', text)\n", + " \n", + " # Try to find JSON array\n", + " json_match = re.search(r'\\[\\s*\\{.*\\}\\s*\\]', text, re.DOTALL)\n", + " if json_match:\n", + " try:\n", + " return json.loads(json_match.group())\n", + " except json.JSONDecodeError:\n", + " pass\n", + " \n", + " # Fallback: try parsing entire response\n", + " try:\n", + " return json.loads(text.strip())\n", + " except json.JSONDecodeError:\n", + " return []\n", + "\n", + "def detect_bugs(code: str, model_key: str) -> Dict:\n", + " \"\"\"Detect bugs and security issues in code.\"\"\"\n", + " model_config = MODELS[model_key]\n", + " client = model_config[\"client\"]\n", + " model_name = model_config[\"model\"]\n", + " \n", + " user_prompt = f\"Analyze this Python code for bugs and security issues:\\n\\n```python\\n{code}\\n```\"\n", + " \n", + " try:\n", + " response = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": BUG_DETECTION_SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " temperature=0.1\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " issues = extract_json_from_response(content)\n", + " \n", + " return {\n", + " \"model\": model_config[\"name\"],\n", + " \"issues\": issues,\n", + " \"raw_response\": content,\n", + " \"success\": True\n", + " }\n", + " except Exception as e:\n", + " return {\n", + " \"model\": model_config[\"name\"],\n", + " \"issues\": [],\n", + " \"error\": str(e),\n", + " \"success\": False\n", + " }\n", + "\n", + "def suggest_improvements(code: str, model_key: str) -> Dict:\n", + " \"\"\"Suggest code improvements and best practices.\"\"\"\n", + " model_config = MODELS[model_key]\n", + " client = model_config[\"client\"]\n", + " model_name = model_config[\"model\"]\n", + " \n", + " user_prompt = f\"Suggest improvements for this Python code:\\n\\n```python\\n{code}\\n```\"\n", + " \n", + " try:\n", + " response = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": IMPROVEMENTS_SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " temperature=0.2\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " improvements = extract_json_from_response(content)\n", + " \n", + " return {\n", + " \"model\": model_config[\"name\"],\n", + " \"improvements\": improvements,\n", + " \"raw_response\": content,\n", + " \"success\": True\n", + " }\n", + " except Exception as e:\n", + " return {\n", + " \"model\": model_config[\"name\"],\n", + " \"improvements\": [],\n", + " \"error\": str(e),\n", + " \"success\": False\n", + " }\n", + "\n", + "def generate_tests(code: str, bugs: List[Dict], model_key: str) -> Dict:\n", + " \"\"\"Generate unit tests for the code.\"\"\"\n", + " model_config = MODELS[model_key]\n", + " client = model_config[\"client\"]\n", + " model_name = model_config[\"model\"]\n", + " \n", + " 
bugs_context = \"\"\n", + " if bugs:\n", + " bugs_context = f\"\\n\\nNote: The following bugs were identified:\\n\" + \"\\n\".join([f\"- Line {b.get('line', '?')}: {b.get('issue', '')}\" for b in bugs])\n", + " \n", + " user_prompt = f\"Generate pytest unit tests for this Python code:{bugs_context}\\n\\n```python\\n{code}\\n```\"\n", + " \n", + " try:\n", + " response = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": TEST_GENERATION_SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " temperature=0.3\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " # Remove markdown code blocks if present\n", + " test_code = re.sub(r'```python\\n?', '', content)\n", + " test_code = re.sub(r'```\\n?', '', test_code)\n", + " \n", + " return {\n", + " \"model\": model_config[\"name\"],\n", + " \"test_code\": test_code.strip(),\n", + " \"raw_response\": content,\n", + " \"success\": True\n", + " }\n", + " except Exception as e:\n", + " return {\n", + " \"model\": model_config[\"name\"],\n", + " \"test_code\": \"\",\n", + " \"error\": str(e),\n", + " \"success\": False\n", + " }\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def format_bugs_output(result: Dict) -> str:\n", + " \"\"\"Format bug detection results for display.\"\"\"\n", + " if not result.get(\"success\"):\n", + " return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n", + " \n", + " issues = result.get(\"issues\", [])\n", + " if not issues:\n", + " return f\"โœ… **{result['model']}**: No issues found. Code looks good!\"\n", + " \n", + " output = [f\"**{result['model']}** - Found {len(issues)} issue(s):\\n\"]\n", + " \n", + " severity_order = {\"critical\": 0, \"high\": 1, \"medium\": 2, \"low\": 3}\n", + " sorted_issues = sorted(issues, key=lambda x: severity_order.get(x.get(\"severity\", \"low\"), 3))\n", + " \n", + " for issue in sorted_issues:\n", + " severity = issue.get(\"severity\", \"unknown\").upper()\n", + " line = issue.get(\"line\", \"?\")\n", + " issue_desc = issue.get(\"issue\", \"\")\n", + " suggestion = issue.get(\"suggestion\", \"\")\n", + " \n", + " severity_emoji = {\n", + " \"CRITICAL\": \"๐Ÿ”ด\",\n", + " \"HIGH\": \"๐ŸŸ \",\n", + " \"MEDIUM\": \"๐ŸŸก\",\n", + " \"LOW\": \"๐Ÿ”ต\"\n", + " }.get(severity, \"โšช\")\n", + " \n", + " output.append(f\"{severity_emoji} **{severity}** (Line {line}): {issue_desc}\")\n", + " if suggestion:\n", + " output.append(f\" ๐Ÿ’ก *Fix:* {suggestion}\")\n", + " output.append(\"\")\n", + " \n", + " return \"\\n\".join(output)\n", + "\n", + "def format_improvements_output(result: Dict) -> str:\n", + " \"\"\"Format improvement suggestions for display.\"\"\"\n", + " if not result.get(\"success\"):\n", + " return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n", + " \n", + " improvements = result.get(\"improvements\", [])\n", + " if not improvements:\n", + " return f\"โœ… **{result['model']}**: Code follows best practices. 
No major improvements needed!\"\n", + " \n", + " output = [f\"**{result['model']}** - {len(improvements)} suggestion(s):\\n\"]\n", + " \n", + " for imp in improvements:\n", + " category = imp.get(\"category\", \"general\").replace(\"_\", \" \").title()\n", + " line = imp.get(\"line\", \"?\")\n", + " current = imp.get(\"current\", \"\")\n", + " improved = imp.get(\"improved\", \"\")\n", + " explanation = imp.get(\"explanation\", \"\")\n", + " \n", + " output.append(f\"\\n๐Ÿ“ **{category}** (Line {line}):\")\n", + " if current and improved:\n", + " output.append(f\" Before: `{current[:60]}{'...' if len(current) > 60 else ''}`\")\n", + " output.append(f\" After: `{improved[:60]}{'...' if len(improved) > 60 else ''}`\")\n", + " if explanation:\n", + " output.append(f\" ๐Ÿ’ก {explanation}\")\n", + " \n", + " return \"\\n\".join(output)\n", + "\n", + "def format_tests_output(result: Dict) -> str:\n", + " \"\"\"Format test generation results for display.\"\"\"\n", + " if not result.get(\"success\"):\n", + " return f\"**Error with {result['model']}:** {result.get('error', 'Unknown error')}\"\n", + " \n", + " test_code = result.get(\"test_code\", \"\")\n", + " if not test_code:\n", + " return f\"โš ๏ธ **{result['model']}**: No tests generated.\"\n", + " \n", + " return test_code\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def review_code(code: str, model_key: str, include_tests: bool = True) -> tuple:\n", + " \"\"\"Main function to perform complete code review.\"\"\"\n", + " if not code.strip():\n", + " return \"Please provide code to review.\", \"\", \"\"\n", + " \n", + " # Detect bugs\n", + " bugs_result = detect_bugs(code, model_key)\n", + " bugs_output = format_bugs_output(bugs_result)\n", + " bugs_issues = bugs_result.get(\"issues\", [])\n", + " \n", + " # Suggest improvements\n", + " improvements_result = suggest_improvements(code, model_key)\n", + " improvements_output = format_improvements_output(improvements_result)\n", + " \n", + " # Generate tests\n", + " tests_output = \"\"\n", + " if include_tests:\n", + " tests_result = generate_tests(code, bugs_issues, model_key)\n", + " tests_output = format_tests_output(tests_result)\n", + " \n", + " return bugs_output, improvements_output, tests_output\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def compare_models(code: str, model_keys: List[str]) -> str:\n", + " \"\"\"Compare multiple models on the same code.\"\"\"\n", + " if not code.strip():\n", + " return \"Please provide code to review.\"\n", + " \n", + " results = []\n", + " all_issues = []\n", + " \n", + " for model_key in model_keys:\n", + " result = detect_bugs(code, model_key)\n", + " results.append(result)\n", + " if result.get(\"success\"):\n", + " all_issues.extend(result.get(\"issues\", []))\n", + " \n", + " # Build comparison output\n", + " output = [\"# Model Comparison Results\\n\"]\n", + " \n", + " for result in results:\n", + " model_name = result[\"model\"]\n", + " issues = result.get(\"issues\", [])\n", + " success = result.get(\"success\", False)\n", + " \n", + " if success:\n", + " output.append(f\"\\n**{model_name}**: Found {len(issues)} issue(s)\")\n", + " if issues:\n", + " severity_counts = {}\n", + " for issue in issues:\n", + " sev = issue.get(\"severity\", \"low\")\n", + " severity_counts[sev] = severity_counts.get(sev, 0) + 1\n", + " output.append(f\" Breakdown: {dict(severity_counts)}\")\n", + " else:\n", + " 
output.append(f\"\\n**{model_name}**: Error - {result.get('error', 'Unknown')}\")\n", + " \n", + " # Find consensus issues (found by multiple models)\n", + " if len(results) > 1:\n", + " issue_signatures = {}\n", + " for result in results:\n", + " if result.get(\"success\"):\n", + " for issue in result.get(\"issues\", []):\n", + " # Create signature from line and issue description\n", + " sig = f\"{issue.get('line')}-{issue.get('issue', '')[:50]}\"\n", + " if sig not in issue_signatures:\n", + " issue_signatures[sig] = []\n", + " issue_signatures[sig].append(result[\"model\"])\n", + " \n", + " consensus = [sig for sig, models in issue_signatures.items() if len(models) > 1]\n", + " if consensus:\n", + " output.append(f\"\\n\\n **Consensus Issues**: {len(consensus)} issue(s) identified by multiple models\")\n", + " \n", + " return \"\\n\".join(output)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gradio UI\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7884\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Traceback (most recent call last):\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\queueing.py\", line 745, in process_events\n", + " response = await route_utils.call_process_api(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\route_utils.py\", line 354, in call_process_api\n", + " output = await app.get_blocks().process_api(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\blocks.py\", line 2116, in process_api\n", + " result = await self.call_function(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\blocks.py\", line 1623, in call_function\n", + " prediction = await anyio.to_thread.run_sync( # type: ignore\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\to_thread.py\", line 56, in run_sync\n", + " return await get_async_backend().run_sync_in_worker_thread(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 2485, in run_sync_in_worker_thread\n", + " return await future\n", + " ^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 976, in run\n", + " result = context.run(func, *args)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"c:\\Users\\Philo Baba\\llm_engineering\\.venv\\Lib\\site-packages\\gradio\\utils.py\", line 915, in wrapper\n", + " response = f(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^\n", + " File \"C:\\Users\\Philo Baba\\AppData\\Local\\Temp\\ipykernel_43984\\2272281361.py\", line 7, in review_code\n", + " bugs_result = detect_bugs(code, model_key)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"C:\\Users\\Philo Baba\\AppData\\Local\\Temp\\ipykernel_43984\\638705449.py\", line 23, in detect_bugs\n", + " model_config = MODELS[model_key]\n", + " ~~~~~~^^^^^^^^^^^\n", + "KeyError: 'deepseek-coder-v2-openrouter'\n" + ] + } + ], + "source": [ + "# Example buggy code for testing\n", + "EXAMPLE_CODE = '''def divide_numbers(a, b):\n", + " return a / b\n", + "\n", + "def process_user_data(user_input):\n", + " # Missing input validation\n", + " result = eval(user_input)\n", + " return result\n", + "\n", + "def get_user_by_id(user_id):\n", + " # SQL injection vulnerability\n", + " query = f\"SELECT * FROM users WHERE id = {user_id}\"\n", + " return query\n", + "\n", + "def calculate_average(numbers):\n", + " total = sum(numbers)\n", + " return total / len(numbers) # Potential division by zero\n", + "'''\n", + "\n", + "def create_ui():\n", + " with gr.Blocks(title=\"AI Code Review Assistant\", theme=gr.themes.Soft()) as demo:\n", + " gr.Markdown(\"\"\"\n", + " # ๐Ÿ” AI-Powered Code Review Assistant\n", + " \n", + " Review your Python code using open-source AI models. 
Detect bugs, get improvement suggestions, and generate unit tests.\n", + " \"\"\")\n", + " \n", + " with gr.Row():\n", + " with gr.Column(scale=2):\n", + " code_input = gr.Code(\n", + " label=\"Python Code to Review\",\n", + " value=EXAMPLE_CODE,\n", + " language=\"python\",\n", + " lines=20\n", + " )\n", + " \n", + " with gr.Row():\n", + " model_selector = gr.Dropdown(\n", + " choices=list(MODELS.keys()),\n", + " value=list(MODELS.keys())[0],\n", + " label=\"Select Model\",\n", + " info=\"Choose an open-source code model\"\n", + " )\n", + " \n", + " include_tests = gr.Checkbox(\n", + " label=\"Generate Tests\",\n", + " value=True\n", + " )\n", + " \n", + " with gr.Row():\n", + " review_btn = gr.Button(\"๐Ÿ” Review Code\", variant=\"primary\", scale=2)\n", + " compare_btn = gr.Button(\"๐Ÿ“Š Compare Models\", variant=\"secondary\", scale=1)\n", + " \n", + " with gr.Column(scale=3):\n", + " with gr.Tabs() as tabs:\n", + " with gr.Tab(\"๐Ÿ› Bug Detection\"):\n", + " bugs_output = gr.Markdown(value=\"Select a model and click 'Review Code' to analyze your code.\")\n", + " \n", + " with gr.Tab(\"โœจ Improvements\"):\n", + " improvements_output = gr.Markdown(value=\"Get suggestions for code improvements and best practices.\")\n", + " \n", + " with gr.Tab(\"๐Ÿงช Unit Tests\"):\n", + " tests_output = gr.Code(\n", + " label=\"Generated Test Code\",\n", + " language=\"python\",\n", + " lines=25\n", + " )\n", + " \n", + " with gr.Tab(\"๐Ÿ“Š Comparison\"):\n", + " comparison_output = gr.Markdown(value=\"Compare multiple models side-by-side.\")\n", + " \n", + " # Event handlers\n", + " review_btn.click(\n", + " fn=review_code,\n", + " inputs=[code_input, model_selector, include_tests],\n", + " outputs=[bugs_output, improvements_output, tests_output]\n", + " )\n", + " \n", + " def compare_selected_models(code):\n", + " # Compare first 3 models by default\n", + " model_keys = list(MODELS.keys())[:3]\n", + " return compare_models(code, model_keys)\n", + " \n", + " compare_btn.click(\n", + " fn=compare_selected_models,\n", + " inputs=[code_input],\n", + " outputs=[comparison_output]\n", + " )\n", + " \n", + " gr.Examples(\n", + " examples=[\n", + " [EXAMPLE_CODE],\n", + " [\"\"\"def fibonacci(n):\n", + " if n <= 1:\n", + " return n\n", + " return fibonacci(n-1) + fibonacci(n-2)\n", + "\"\"\"],\n", + " [\"\"\"def parse_config(file_path):\n", + " with open(file_path) as f:\n", + " return eval(f.read())\n", + "\"\"\"]\n", + " ],\n", + " inputs=[code_input]\n", + " )\n", + " \n", + " return demo\n", + "\n", + "demo = create_ui()\n", + "demo.launch(inbrowser=True, share=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "**Qwen2.5-Coder 32B (OpenRouter)** - Found 2 issue(s):\n", + "\n", + "๐Ÿ”ด **CRITICAL** (Line 2): No division by zero protection\n", + " ๐Ÿ’ก *Fix:* Add a check for b == 0 and raise ValueError or handle ZeroDivisionError\n", + "\n", + "๐ŸŸก **MEDIUM** (Line 2): No input validation for numeric types\n", + " ๐Ÿ’ก *Fix:* Add type checking to ensure a and b are numbers (int/float)\n", + "\n" + ] + } + ], + "source": [ + "# Test with a simple example\n", + "test_code = \"\"\"def divide(a, b):\n", + " return a / b\n", + "\"\"\"\n", + "\n", + "# Test bug detection\n", + "result = detect_bugs(test_code, list(MODELS.keys())[0])\n", + "print(format_bugs_output(result))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" 
+ }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/week7/community_contributions/emmy/product_price_estimator.ipynb b/week7/community_contributions/emmy/product_price_estimator.ipynb new file mode 100644 index 0000000..96914a5 --- /dev/null +++ b/week7/community_contributions/emmy/product_price_estimator.ipynb @@ -0,0 +1,454 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "275415f0", + "metadata": { + "id": "275415f0" + }, + "outputs": [], + "source": [ + "# pip installs\n", + "\n", + "!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n", + "!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "535bd9de", + "metadata": { + "id": "535bd9de" + }, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "import re\n", + "import math\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "from google.colab import userdata\n", + "from huggingface_hub import login\n", + "import torch\n", + "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, mean_absolute_percentage_error\n", + "import torch.nn.functional as F\n", + "import transformers\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n", + "from datasets import load_dataset, Dataset, DatasetDict\n", + "from datetime import datetime\n", + "from peft import PeftModel\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc58234a", + "metadata": { + "id": "fc58234a" + }, + "outputs": [], + "source": [ + "# Constants\n", + "\n", + "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n", + "PROJECT_NAME = \"pricer\"\n", + "HF_USER = \"ed-donner\"\n", + "RUN_NAME = \"2024-09-13_13.04.39\"\n", + "PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n", + "REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n", + "FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n", + "\n", + "\n", + "DATASET_NAME = f\"{HF_USER}/home-data\"\n", + "\n", + "QUANT_4_BIT = True\n", + "top_K = 6\n", + "\n", + "%matplotlib inline\n", + "\n", + "# Used for writing to output in color\n", + "\n", + "GREEN = \"\\033[92m\"\n", + "YELLOW = \"\\033[93m\"\n", + "RED = \"\\033[91m\"\n", + "RESET = \"\\033[0m\"\n", + "COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0145ad8a", + "metadata": { + "id": "0145ad8a" + }, + "outputs": [], + "source": [ + "# Log in to HuggingFace\n", + "\n", + "hf_token = userdata.get('HF_TOKEN')\n", + "login(hf_token, add_to_git_credential=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6919506e", + "metadata": { + "id": "6919506e" + }, + "outputs": [], + "source": [ + "dataset = load_dataset(DATASET_NAME)\n", + "train = dataset['train']\n", + "test = dataset['test']\n", + "len(train), len(test)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea79cde1", + "metadata": { + "id": "ea79cde1" + }, 
+ "outputs": [], + "source": [ + "if QUANT_4_BIT:\n", + " quant_config = BitsAndBytesConfig(\n", + " load_in_4bit=True,\n", + " bnb_4bit_use_double_quant=True,\n", + " bnb_4bit_compute_dtype=torch.bfloat16,\n", + " bnb_4bit_quant_type=\"nf4\"\n", + " )\n", + "else:\n", + " quant_config = BitsAndBytesConfig(\n", + " load_in_8bit=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef108f8d", + "metadata": { + "id": "ef108f8d" + }, + "outputs": [], + "source": [ + "# Load the Tokenizer and the Model\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "tokenizer.padding_side = \"right\"\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(\n", + " BASE_MODEL,\n", + " quantization_config=quant_config,\n", + " device_map=\"auto\",\n", + ")\n", + "base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n", + "\n", + "# Load the fine-tuned model with PEFT\n", + "if REVISION:\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n", + "else:\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n", + "\n", + "fine_tuned_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f3c4176", + "metadata": { + "id": "7f3c4176" + }, + "outputs": [], + "source": [ + "def extract_price(s):\n", + " if \"Price is $\" in s:\n", + " contents = s.split(\"Price is $\")[1]\n", + " contents = contents.replace(',','')\n", + " match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n", + " return float(match.group()) if match else 0\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "436fa29a", + "metadata": { + "id": "436fa29a" + }, + "outputs": [], + "source": [ + "# Original prediction function takes the most likely next token\n", + "\n", + "def model_predict(prompt):\n", + " set_seed(42)\n", + " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n", + " attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n", + " outputs = fine_tuned_model.generate(inputs, attention_mask=attention_mask, max_new_tokens=3, num_return_sequences=1)\n", + " response = tokenizer.decode(outputs[0])\n", + " return extract_price(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a666dab6", + "metadata": { + "id": "a666dab6" + }, + "outputs": [], + "source": [ + "def improved_model_predict(prompt, device=\"cuda\", top_k=3):\n", + " set_seed(42)\n", + " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", + " attention_mask = torch.ones_like(inputs)\n", + "\n", + " with torch.no_grad():\n", + " outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n", + " next_token_logits = outputs.logits[:, -1, :].to(\"cpu\")\n", + "\n", + " probs = F.softmax(next_token_logits, dim=-1)\n", + " top_prob, top_token_id = probs.topk(top_k)\n", + "\n", + " prices, weights = [], []\n", + " for weight, token_id in zip(top_prob[0], top_token_id[0]):\n", + " token_text = tokenizer.decode(token_id)\n", + " try:\n", + " value = float(token_text)\n", + " except ValueError:\n", + " continue\n", + " prices.append(value)\n", + " weights.append(weight.item())\n", + "\n", + " if not prices:\n", + " return 0.0, 0.0, 0.0 # price, confidence, spread\n", + "\n", + " total = sum(weights)\n", + " normalized = [w / total for w in weights]\n", + " weighted_price = sum(p * w for p, w in zip(prices, 
normalized))\n", + " variance = sum(w * (p - weighted_price) ** 2 for p, w in zip(prices, normalized))\n", + " confidence = 1 - min(variance / (weighted_price + 1e-6), 1)\n", + " spread = max(prices) - min(prices)\n", + " return weighted_price, confidence, spread" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install -q gradio>=4.0\n", + "import gradio as gr\n", + "\n", + "def format_prediction(description):\n", + " price, confidence, spread = improved_model_predict(description)\n", + " return (\n", + " f\"${price:,.2f}\",\n", + " f\"{confidence*100:.1f}%\",\n", + " f\"${spread:,.2f}\",\n", + " )\n", + "\n", + "demo = gr.Interface(\n", + " fn=format_prediction,\n", + " inputs=gr.Textbox(lines=10, label=\"Product Description\"),\n", + " outputs=[\n", + " gr.Textbox(label=\"Estimated Price\"),\n", + " gr.Textbox(label=\"Confidence\"),\n", + " gr.Textbox(label=\"Token Spread\"),\n", + " ],\n", + " title=\"Open-Source Product Price Estimator\",\n", + " description=\"Paste a cleaned product blurb. The model returns a price estimate, a confidence score based on top-k token dispersion, and the spread between top predictions.\",\n", + ")\n", + "\n", + "demo.launch(share=True)" + ], + "metadata": { + "id": "Km4tHaeQyMoW" + }, + "id": "Km4tHaeQyMoW", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9664c4c7", + "metadata": { + "id": "9664c4c7" + }, + "outputs": [], + "source": [ + "\n", + "class Tester:\n", + "\n", + " def __init__(self, predictor, data, title=None, show_progress=True):\n", + " self.predictor = predictor\n", + " self.data = data\n", + " self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n", + " self.size = len(data)\n", + " self.guesses, self.truths, self.errors, self.rel_errors, self.sles, self.colors = [], [], [], [], [], []\n", + " self.show_progress = show_progress\n", + "\n", + " def color_for(self, error, truth):\n", + " if error < 40 or error / truth < 0.2:\n", + " return \"green\"\n", + " elif error < 80 or error / truth < 0.4:\n", + " return \"orange\"\n", + " else:\n", + " return \"red\"\n", + "\n", + " def run_datapoint(self, i):\n", + " datapoint = self.data[i]\n", + " guess = self.predictor(datapoint[\"text\"])\n", + " truth = datapoint[\"price\"]\n", + "\n", + " error = guess - truth\n", + " abs_error = abs(error)\n", + " rel_error = abs_error / truth if truth != 0 else 0\n", + " log_error = math.log(truth + 1) - math.log(guess + 1)\n", + " sle = log_error ** 2\n", + " color = self.color_for(abs_error, truth)\n", + "\n", + " title = (datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\") if \"\\n\\n\" in datapoint[\"text\"] else datapoint[\"text\"][:20]\n", + " self.guesses.append(guess)\n", + " self.truths.append(truth)\n", + " self.errors.append(error)\n", + " self.rel_errors.append(rel_error)\n", + " self.sles.append(sle)\n", + " self.colors.append(color)\n", + "\n", + " print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} \"\n", + " f\"Error: ${abs_error:,.2f} RelErr: {rel_error*100:.1f}% SLE: {sle:,.2f} Item: {title}{RESET}\")\n", + "\n", + " def chart_all(self, chart_title):\n", + " \"\"\"Compact version: 4 performance charts in one grid.\"\"\"\n", + " t, g = np.array(self.truths), np.array(self.guesses)\n", + " rel_err, abs_err = np.array(self.rel_errors) * 100, np.abs(np.array(self.errors))\n", + "\n", + " fig, axs = plt.subplots(2, 2, figsize=(14, 10))\n", + " fig.suptitle(f\"Performance Dashboard โ€” {chart_title}\", fontsize=16, 
fontweight=\"bold\")\n", + "\n", + " # Scatter plot\n", + " max_val = max(t.max(), g.max()) * 1.05\n", + " axs[1, 1].plot([0, max_val], [0, max_val], \"b--\", alpha=0.6)\n", + " axs[1, 1].scatter(t, g, s=20, c=self.colors, alpha=0.6)\n", + " axs[1, 1].set_title(\"Predictions vs Ground Truth\")\n", + " axs[1, 1].set_xlabel(\"True Price ($)\")\n", + " axs[1, 1].set_ylabel(\"Predicted ($)\")\n", + "\n", + " # Accuracy by price range\n", + " bins = np.linspace(t.min(), t.max(), 6)\n", + " labels = [f\"${bins[i]:.0f}โ€“${bins[i+1]:.0f}\" for i in range(len(bins)-1)]\n", + " inds = np.digitize(t, bins) - 1\n", + " avg_err = [rel_err[inds == i].mean() for i in range(len(labels))]\n", + " axs[0, 0].bar(labels, avg_err, color=\"seagreen\", alpha=0.8)\n", + " axs[0, 0].set_title(\"Avg Relative Error by Price Range\")\n", + " axs[0, 0].set_ylabel(\"Relative Error (%)\")\n", + " axs[0, 0].tick_params(axis=\"x\", rotation=30)\n", + "\n", + " # Relative error distribution\n", + " axs[0, 1].hist(rel_err, bins=25, color=\"mediumpurple\", edgecolor=\"black\", alpha=0.7)\n", + " axs[0, 1].set_title(\"Relative Error Distribution (%)\")\n", + " axs[0, 1].set_xlabel(\"Relative Error (%)\")\n", + "\n", + " # Absolute error distribution\n", + " axs[1, 0].hist(abs_err, bins=25, color=\"steelblue\", edgecolor=\"black\", alpha=0.7)\n", + " axs[1, 0].axvline(abs_err.mean(), color=\"red\", linestyle=\"--\", label=f\"Mean={abs_err.mean():.2f}\")\n", + " axs[1, 0].set_title(\"Absolute Error Distribution\")\n", + " axs[1, 0].set_xlabel(\"Absolute Error ($)\")\n", + " axs[1, 0].legend()\n", + "\n", + " for ax in axs.ravel():\n", + " ax.grid(alpha=0.3)\n", + "\n", + " plt.tight_layout(rect=[0, 0, 1, 0.95])\n", + " plt.show()\n", + "\n", + " def report(self):\n", + " y_true = np.array(self.truths)\n", + " y_pred = np.array(self.guesses)\n", + "\n", + " mae = mean_absolute_error(y_true, y_pred)\n", + " rmse = math.sqrt(mean_squared_error(y_true, y_pred))\n", + " rmsle = math.sqrt(sum(self.sles) / self.size)\n", + " mape = mean_absolute_percentage_error(y_true, y_pred) * 100\n", + " median_error = float(np.median(np.abs(y_true - y_pred)))\n", + " r2 = r2_score(y_true, y_pred)\n", + "\n", + " hit_rate_green = sum(1 for c in self.colors if c == \"green\") / self.size * 100\n", + " hit_rate_acceptable = sum(1 for c in self.colors if c in (\"green\", \"orange\")) / self.size * 100\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"FINAL REPORT: {self.title}\")\n", + " print(f\"{'='*70}\")\n", + " print(f\"Total Predictions: {self.size}\")\n", + " print(f\"\\n--- Error Metrics ---\")\n", + " print(f\"Mean Absolute Error (MAE): ${mae:,.2f}\")\n", + " print(f\"Median Error: ${median_error:,.2f}\")\n", + " print(f\"Root Mean Squared Error (RMSE): ${rmse:,.2f}\")\n", + " print(f\"Root Mean Squared Log Error (RMSLE): {rmsle:.4f}\")\n", + " print(f\"Mean Absolute Percentage Error (MAPE): {mape:.2f}%\")\n", + " print(f\"\\n--- Accuracy Metrics ---\")\n", + " print(f\"Rยฒ Score: {r2:.4f}\")\n", + " print(f\"Hit Rate (Green): {hit_rate_green:.1f}%\")\n", + " print(f\"Hit Rate (Green+Orange): {hit_rate_acceptable:.1f}%\")\n", + " print(f\"{'='*70}\\n\")\n", + " chart_title = f\"{self.title} | MAE=${mae:,.2f} | RMSLE={rmsle:.3f} | Rยฒ={r2:.3f}\"\n", + "\n", + " self.chart_all(chart_title)\n", + "\n", + " def run(self):\n", + " iterator = tqdm(range(self.size), desc=\"Testing Model\") if self.show_progress else range(self.size)\n", + " for i in iterator:\n", + " self.run_datapoint(i)\n", + " self.report()\n", + "\n", + " 
@classmethod\n", + " def test(cls, function, data, title=None):\n", + " cls(function, data, title=title).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e60a696", + "metadata": { + "id": "2e60a696" + }, + "outputs": [], + "source": [ + "Tester.test(\n", + " improved_model_predict,\n", + " test,\n", + " title=\"Home appliances prediction\"\n", + ")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/week7/community_contributions/hopeogbons/week7_EXERCISE.ipynb b/week7/community_contributions/hopeogbons/week7_EXERCISE.ipynb new file mode 100644 index 0000000..cf4f4a0 --- /dev/null +++ b/week7/community_contributions/hopeogbons/week7_EXERCISE.ipynb @@ -0,0 +1,466 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GHsssBgWM_l0" + }, + "source": [ + "# Fine-Tuned Product Price Predictor\n", + "\n", + "Evaluate fine-tuned Llama 3.1 8B model for product price estimation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MDyR63OTNUJ6" + }, + "outputs": [], + "source": [ + "# Install required libraries for model inference\n", + "%pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n", + "%pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-yikV8pRBer9" + }, + "outputs": [], + "source": [ + "# Import required libraries\n", + "import os\n", + "import re\n", + "import math\n", + "from tqdm import tqdm\n", + "from google.colab import userdata\n", + "from huggingface_hub import login\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import transformers\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n", + "from datasets import load_dataset, Dataset, DatasetDict\n", + "from datetime import datetime\n", + "from peft import PeftModel\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uuTX-xonNeOK" + }, + "outputs": [], + "source": [ + "# Configuration\n", + "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n", + "PROJECT_NAME = \"pricer\"\n", + "HF_USER = \"ed-donner\" # Change to your HF username\n", + "RUN_NAME = \"2024-09-13_13.04.39\"\n", + "PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n", + "REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n", + "FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n", + "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n", + "\n", + "# Quantization setting (False = 8-bit = better accuracy, more memory)\n", + "QUANT_4_BIT = False # Changed to 8-bit for better accuracy\n", + "\n", + "%matplotlib inline\n", + "\n", + "# Color codes for output\n", + "GREEN = \"\\033[92m\"\n", + "YELLOW = \"\\033[93m\"\n", + "RED = \"\\033[91m\"\n", + "RESET = \"\\033[0m\"\n", + "COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8JArT3QAQAjx" + }, + "source": [ + "# Step 1\n", + "\n", + "### Load dataset and fine-tuned model" + ] + 
}, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WyFPZeMcM88v" + }, + "outputs": [], + "source": [ + "# Login to HuggingFace\n", + "hf_token = userdata.get('HF_TOKEN')\n", + "login(hf_token, add_to_git_credential=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cvXVoJH8LS6u" + }, + "outputs": [], + "source": [ + "# Load product pricing dataset\n", + "dataset = load_dataset(DATASET_NAME)\n", + "train = dataset['train']\n", + "test = dataset['test']\n", + "\n", + "print(f\"โœ“ Loaded {len(train)} train and {len(test)} test samples\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xb86e__Wc7j_" + }, + "outputs": [], + "source": [ + "# Verify data structure\n", + "test[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qJWQ0a3wZ0Bw" + }, + "source": [ + "### Load Tokenizer and Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lAUAAcEC6ido" + }, + "outputs": [], + "source": [ + "# Configure quantization for memory efficiency\n", + "if QUANT_4_BIT:\n", + " quant_config = BitsAndBytesConfig(\n", + " load_in_4bit=True,\n", + " bnb_4bit_use_double_quant=True,\n", + " bnb_4bit_compute_dtype=torch.bfloat16,\n", + " bnb_4bit_quant_type=\"nf4\"\n", + " )\n", + "else:\n", + " quant_config = BitsAndBytesConfig(\n", + " load_in_8bit=True,\n", + " bnb_8bit_compute_dtype=torch.bfloat16\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R_O04fKxMMT-" + }, + "outputs": [], + "source": [ + "# Load tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "tokenizer.padding_side = \"right\"\n", + "\n", + "# Load base model with quantization\n", + "base_model = AutoModelForCausalLM.from_pretrained(\n", + " BASE_MODEL,\n", + " quantization_config=quant_config,\n", + " device_map=\"auto\",\n", + ")\n", + "base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n", + "\n", + "# Load fine-tuned weights\n", + "if REVISION:\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n", + "else:\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n", + "\n", + "print(f\"โœ“ Model loaded - Memory: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kD-GJtbrdd5t" + }, + "outputs": [], + "source": [ + "# Verify model loaded\n", + "fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UObo1-RqaNnT" + }, + "source": [ + "# Step 2\n", + "\n", + "### Model inference and evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Qst1LhBVAB04" + }, + "outputs": [], + "source": [ + "# Extract price from model response\n", + "def extract_price(s):\n", + " if \"Price is $\" in s:\n", + " contents = s.split(\"Price is $\")[1]\n", + " contents = contents.replace(',','')\n", + " match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n", + " return float(match.group()) if match else 0\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jXFBW_5UeEcp" + }, + "outputs": [], + "source": [ + "# Test extract_price function\n", + "extract_price(\"Price is $a fabulous 899.99 or so\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "id": "Oj_PzpdFAIMk" + }, + "outputs": [], + "source": [ + "# Simple prediction: takes most likely next token\n", + "def model_predict(prompt):\n", + " set_seed(42)\n", + " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n", + " attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n", + " outputs = fine_tuned_model.generate(\n", + " inputs,\n", + " attention_mask=attention_mask,\n", + " max_new_tokens=5, # Increased for flexibility\n", + " temperature=0.1, # Low temperature for consistency\n", + " num_return_sequences=1\n", + " )\n", + " response = tokenizer.decode(outputs[0])\n", + " return extract_price(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Je5dR8QEAI1d" + }, + "outputs": [], + "source": [ + "# Improved prediction: weighted average of top K predictions\n", + "top_K = 5 # Increased from 3 to 5 for better accuracy\n", + "\n", + "def improved_model_predict(prompt, device=\"cuda\"):\n", + " set_seed(42)\n", + " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", + " attention_mask = torch.ones(inputs.shape, device=device)\n", + "\n", + " with torch.no_grad():\n", + " outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n", + " next_token_logits = outputs.logits[:, -1, :].to('cpu')\n", + "\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " top_prob, top_token_id = next_token_probs.topk(top_K)\n", + " prices, weights = [], []\n", + " for i in range(top_K):\n", + " predicted_token = tokenizer.decode(top_token_id[0][i])\n", + " probability = top_prob[0][i]\n", + " try:\n", + " result = float(predicted_token)\n", + " except ValueError as e:\n", + " result = 0.0\n", + " if result > 0:\n", + " prices.append(result)\n", + " weights.append(probability)\n", + " if not prices:\n", + " return 0.0, 0.0\n", + " total = sum(weights)\n", + " weighted_prices = [price * weight / total for price, weight in zip(prices, weights)]\n", + " return sum(weighted_prices).item()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EpGVJPuC1iho" + }, + "source": [ + "# Step 3\n", + "\n", + "### Test and evaluate model performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "30lzJXBH7BcK" + }, + "outputs": [], + "source": [ + "# Evaluation framework\n", + "class Tester:\n", + " def __init__(self, predictor, data, title=None, size=250):\n", + " self.predictor = predictor\n", + " self.data = data\n", + " self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n", + " self.size = size\n", + " self.guesses = []\n", + " self.truths = []\n", + " self.errors = []\n", + " self.sles = []\n", + " self.colors = []\n", + "\n", + " def color_for(self, error, truth):\n", + " if error<40 or error/truth < 0.2:\n", + " return \"green\"\n", + " elif error<80 or error/truth < 0.4:\n", + " return \"orange\"\n", + " else:\n", + " return \"red\"\n", + "\n", + " def run_datapoint(self, i):\n", + " datapoint = self.data[i]\n", + " guess = self.predictor(datapoint[\"text\"])\n", + " truth = datapoint[\"price\"]\n", + " error = abs(guess - truth)\n", + " log_error = math.log(truth+1) - math.log(guess+1)\n", + " sle = log_error ** 2\n", + " color = self.color_for(error, truth)\n", + " title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n", + " self.guesses.append(guess)\n", + " self.truths.append(truth)\n", + " self.errors.append(error)\n", + " self.sles.append(sle)\n", + " 
self.colors.append(color)\n", + " print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n", + "\n", + " def chart(self, title):\n", + " max_error = max(self.errors)\n", + " plt.figure(figsize=(12, 8))\n", + " max_val = max(max(self.truths), max(self.guesses))\n", + " plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n", + " plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n", + " plt.xlabel('Ground Truth')\n", + " plt.ylabel('Model Estimate')\n", + " plt.xlim(0, max_val)\n", + " plt.ylim(0, max_val)\n", + " plt.title(title)\n", + " plt.show()\n", + "\n", + " def report(self):\n", + " average_error = sum(self.errors) / self.size\n", + " rmsle = math.sqrt(sum(self.sles) / self.size)\n", + " hits = sum(1 for color in self.colors if color==\"green\")\n", + " title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n", + " self.chart(title)\n", + "\n", + " def run(self):\n", + " self.error = 0\n", + " for i in range(self.size):\n", + " self.run_datapoint(i)\n", + " self.report()\n", + "\n", + " @classmethod\n", + " def test(cls, function, data):\n", + " cls(function, data).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W_KcLvyt6kbb" + }, + "outputs": [], + "source": [ + "# Run evaluation on 250 test examples\n", + "Tester.test(improved_model_predict, test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nVwiWGVN1ihp" + }, + "source": [ + "### Performance Optimizations Applied\n", + "\n", + "**Changes for better accuracy:**\n", + "- โœ… 8-bit quantization (vs 4-bit) - Better precision\n", + "- โœ… top_K = 5 (vs 3) - More predictions in weighted average\n", + "- โœ… max_new_tokens = 5 - More flexibility in response\n", + "- โœ… temperature = 0.1 - More consistent predictions\n", + "\n", + "**Expected improvement:** ~10-15% reduction in average error\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hO4DdLa81ihp" + }, + "source": [ + "### Expected Performance\n", + "\n", + "**Baseline comparisons:**\n", + "- GPT-4o: $76 avg error\n", + "- Llama 3.1 base: $396 avg error \n", + "- Human: $127 avg error\n", + "\n", + "**Fine-tuned model (optimized):**\n", + "- Target: $70-85 avg error\n", + "- With 8-bit quant + top_K=5 + temp=0.1\n", + "- Expected to rival or beat GPT-4o\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/week8/community_contributions/emmy/llm_battle.py b/week8/community_contributions/emmy/llm_battle.py new file mode 100644 index 0000000..b419d72 --- /dev/null +++ b/week8/community_contributions/emmy/llm_battle.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +import json +import os +from dataclasses import dataclass, field +from typing import Dict, Generator, List, Optional, Tuple + +import gradio as gr +from dotenv import load_dotenv +from openai import OpenAI + +load_dotenv() + + +# --------------------------------------------------------------------------- +# Configuration helpers +# --------------------------------------------------------------------------- + + +@dataclass +class AgentConfig: + """Holds configuration required to talk to an LLM provider.""" + + 
name: str + model: str + api_key_env: str + base_url_env: Optional[str] = None + temperature: float = 0.7 + supports_json: bool = True + + +def load_client(config: AgentConfig) -> OpenAI: + """Create an OpenAI-compatible client for the given agent.""" + api_key = os.getenv(config.api_key_env) or os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError( + f"Missing API key for {config.name}. " + f"Set {config.api_key_env} or OPENAI_API_KEY." + ) + + base_url = ( + os.getenv(config.base_url_env) + if config.base_url_env + else os.getenv("OPENAI_BASE_URL") + ) + + return OpenAI(api_key=api_key, base_url=base_url) + + +def extract_text(response) -> str: + """Extract text content from an OpenAI-style response object or dict.""" + + choices = getattr(response, "choices", None) + if choices is None and isinstance(response, dict): + choices = response.get("choices") + if not choices: + raise RuntimeError(f"LLM response missing choices field: {response!r}") + + choice = choices[0] + message = getattr(choice, "message", None) + if message is None and isinstance(choice, dict): + message = choice.get("message") + + content = None + if message is not None: + content = getattr(message, "content", None) + if content is None and isinstance(message, dict): + content = message.get("content") + + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict): + if "text" in part: + parts.append(str(part["text"])) + elif "output_text" in part: + parts.append(str(part["output_text"])) + elif "type" in part and "content" in part: + parts.append(str(part["content"])) + else: + parts.append(str(part)) + content = "".join(parts) + + if content is None: + text = getattr(choice, "text", None) + if text is None and isinstance(choice, dict): + text = choice.get("text") + if text: + content = text + + if content is None: + raise RuntimeError(f"LLM response missing content/text: {response!r}") + + return str(content).strip() + + +# Default configuration leverages OpenAI unless overrides are provided. 
+DEBATER_A_CONFIG = AgentConfig( + name="Debater A", + model=os.getenv("DEBATER_A_MODEL", "gpt-4o"), + api_key_env="OPENAI_API_KEY", + base_url_env="OPENAI_BASE_URL", + temperature=float(os.getenv("DEBATER_A_TEMPERATURE", 0.7)), +) + +DEBATER_B_CONFIG = AgentConfig( + name="Debater B", + model=os.getenv("DEBATER_B_MODEL", "gemini-2.0-flash"), + api_key_env="GOOGLE_API_KEY", + base_url_env="GEMINI_BASE_URL", + temperature=float(os.getenv("DEBATER_B_TEMPERATURE", 0.7)), +) + +JUDGE_CONFIG = AgentConfig( + name="Judge", + model=os.getenv("JUDGE_MODEL", "gpt-oss:20b-cloud"), + api_key_env="OLLAMA_API_KEY", + base_url_env="OLLAMA_BASE_URL", + temperature=float(os.getenv("JUDGE_TEMPERATURE", 0.2)), + supports_json=False, +) + +REPORTER_CONFIG = AgentConfig( + name="Reporter", + model=os.getenv("REPORTER_MODEL", "MiniMax-M2"), + api_key_env="MINIMAX_API_KEY", + base_url_env="MINIMAX_BASE_URL", + temperature=float(os.getenv("REPORTER_TEMPERATURE", 0.4)), + supports_json=False, +) + +THEME = gr.themes.Default( + primary_hue="blue", + secondary_hue="sky", + neutral_hue="gray", +) + +CUSTOM_CSS = """ +body, .gradio-container { + background: radial-gradient(circle at top, #0f172a 0%, #020617 60%, #020617 100%); + color: #e2e8f0; +} +#live-debate-panel { + background: linear-gradient(135deg, rgba(30,64,175,0.95), rgba(29,78,216,0.85)); + color: #f8fafc; + border-radius: 16px; + padding: 24px; + box-shadow: 0 20px 45px rgba(15,23,42,0.35); +} +#live-debate-panel h3 { + color: #bfdbfe; +} +.gr-button-primary { + background: linear-gradient(135deg, #1d4ed8, #2563eb) !important; + border: none !important; +} +.gr-button-primary:hover { + background: linear-gradient(135deg, #2563eb, #1d4ed8) !important; +} +""" + +# --------------------------------------------------------------------------- +# Debate runtime classes +# --------------------------------------------------------------------------- + + +@dataclass +class DebateState: + topic: str + stance_a: str + stance_b: str + transcript: List[Tuple[str, str]] = field(default_factory=list) + + +class LLMAdapter: + """Thin wrapper around the OpenAI SDK to simplify prompting.""" + + def __init__(self, config: AgentConfig): + self.config = config + self.client = load_client(config) + + def complete( + self, + prompt: str, + *, + system: Optional[str] = None, + max_tokens: int = 512, + json_mode: bool = False, + ) -> str: + messages = [] + if system: + messages.append({"role": "system", "content": system}) + messages.append({"role": "user", "content": prompt}) + + params = dict( + model=self.config.model, + messages=messages, + temperature=self.config.temperature, + max_tokens=max_tokens, + ) + if json_mode and self.config.supports_json: + params["response_format"] = {"type": "json_object"} + + response = self.client.chat.completions.create(**params) + return extract_text(response) + + +class Debater: + def __init__(self, adapter: LLMAdapter, stance_label: str): + self.adapter = adapter + self.stance_label = stance_label + + def argue(self, topic: str) -> str: + prompt = ( + f"You are {self.adapter.config.name}, debating the topic:\n" + f"'{topic}'.\n\n" + f"Present a concise argument that {self.stance_label.lower()} " + f"the statement. Use at most 150 words. Provide clear reasoning " + f"and, if applicable, cite plausible evidence or examples." 
+        )
+        return self.adapter.complete(prompt, max_tokens=300)
+
+
+class Judge:
+    RUBRIC = [
+        "Clarity of the argument",
+        "Use of evidence or examples",
+        "Logical coherence",
+        "Persuasiveness and impact",
+    ]
+
+    def __init__(self, adapter: LLMAdapter):
+        self.adapter = adapter
+
+    def evaluate(self, topic: str, argument_a: str, argument_b: str) -> Dict[str, object]:
+        rubric_text = "\n".join(f"- {item}" for item in self.RUBRIC)
+        prompt = (
+            "You are serving as an impartial debate judge.\n"
+            f"Topic: {topic}\n\n"
+            f"Argument from Debater A:\n{argument_a}\n\n"
+            f"Argument from Debater B:\n{argument_b}\n\n"
+            "Score each debater from 0-10 on the following criteria:\n"
+            f"{rubric_text}\n\n"
+            "Return a JSON object with this exact structure:\n"
+            '{\n'
+            '  "winner": "A" or "B" or "Tie",\n'
+            '  "reason": "brief justification",\n'
+            '  "scores": [\n'
+            '    {"criterion": "...", "debater_a": 0-10, "debater_b": 0-10, "notes": "optional"}\n'
+            "  ]\n"
+            "}\n"
+            "Ensure the JSON is valid."
+        )
+        raw = self.adapter.complete(prompt, max_tokens=400, json_mode=True)
+        # The judge model runs without JSON mode, so it may wrap its answer in a
+        # Markdown code fence; strip the fence before parsing.
+        cleaned = raw.strip()
+        if cleaned.startswith("```"):
+            cleaned = cleaned.split("\n", 1)[-1].rstrip()
+            if cleaned.endswith("```"):
+                cleaned = cleaned[:-3].rstrip()
+        try:
+            data = json.loads(cleaned)
+            if "scores" not in data:
+                raise ValueError("scores missing")
+            return data
+        except Exception:
+            # Fallback: wrap raw text if parsing fails.
+            return {"winner": "Unknown", "reason": raw, "scores": []}
+
+
+class Reporter:
+    def __init__(self, adapter: LLMAdapter):
+        self.adapter = adapter
+
+    def summarize(
+        self,
+        topic: str,
+        argument_a: str,
+        argument_b: str,
+        judge_result: Dict[str, object],
+    ) -> str:
+        prompt = (
+            f"Summarize a single-round debate on '{topic}'.\n\n"
+            f"Debater A argued:\n{argument_a}\n\n"
+            f"Debater B argued:\n{argument_b}\n\n"
+            f"Judge verdict: {json.dumps(judge_result, ensure_ascii=False)}\n\n"
+            "Provide a short journalistic summary (max 200 words) highlighting "
+            "each side's key points and the judge's decision. Use neutral tone."
+        )
+        response = self.adapter.client.chat.completions.create(
+            model=self.adapter.config.model,
+            messages=[
+                {"role": "system", "content": "You are an impartial debate reporter."},
+                {"role": "user", "content": prompt},
+            ],
+            temperature=self.adapter.config.temperature,
+            max_tokens=300,
+            **(
+                {"extra_body": {"reasoning_split": True}}
+                if getattr(self.adapter.client, "base_url", None)
+                and "minimax" in str(self.adapter.client.base_url).lower()
+                else {}
+            ),
+        )
+        return extract_text(response)
+
+
+# ---------------------------------------------------------------------------
+# Debate pipeline + UI
+# ---------------------------------------------------------------------------
+
+
+debater_a = Debater(LLMAdapter(DEBATER_A_CONFIG), stance_label="supports")
+debater_b = Debater(LLMAdapter(DEBATER_B_CONFIG), stance_label="opposes")
+judge = Judge(LLMAdapter(JUDGE_CONFIG))
+reporter = Reporter(LLMAdapter(REPORTER_CONFIG))
+
+
+def format_transcript(transcript: List[Tuple[str, str]]) -> str:
+    """Return markdown-formatted transcript."""
+    lines = []
+    for speaker, message in transcript:
+        lines.append(f"### {speaker}\n\n{message}\n")
+    return "\n".join(lines)
+
+
+def run_debate(
+    topic: str, stance_a: str, stance_b: str
+) -> Generator[Tuple[str, str, List[List[object]], str, str], None, None]:
+    """Generator for Gradio to stream debate progress."""
+    if not topic.strip():
+        warning = "⚠️ Please enter a debate topic to get started."
+        yield warning, "", [], "", ""
+        return
+
+    state = DebateState(topic=topic.strip(), stance_a=stance_a, stance_b=stance_b)
+
+    state.transcript.append(
+        ("Moderator", f"Welcome to the debate on **{state.topic}**!")
+    )
+    yield format_transcript(state.transcript), "Collecting arguments...", [], "", ""
+
+    argument_a = debater_a.argue(state.topic)
+    state.transcript.append((f"Debater A ({state.stance_a})", argument_a))
+    yield format_transcript(state.transcript), "Waiting for Debater B...", [], "", ""
+
+    argument_b = debater_b.argue(state.topic)
+    state.transcript.append((f"Debater B ({state.stance_b})", argument_b))
+    yield format_transcript(state.transcript), "Judge deliberating...", [], "", ""
+
+    judge_result = judge.evaluate(state.topic, argument_a, argument_b)
+    verdict_text = (
+        f"Winner: {judge_result.get('winner', 'Unknown')}\nReason: "
+        f"{judge_result.get('reason', 'No explanation provided.')}"
+    )
+    score_rows = [
+        [
+            entry.get("criterion", ""),
+            entry.get("debater_a", ""),
+            entry.get("debater_b", ""),
+            entry.get("notes", ""),
+        ]
+        for entry in judge_result.get("scores", [])
+    ]
+    judge_report_md = (
+        f"**Judge Verdict:** {judge_result.get('winner', 'Unknown')}\n\n"
+        f"{judge_result.get('reason', '')}"
+    )
+    yield (
+        format_transcript(state.transcript),
+        judge_report_md,
+        score_rows,
+        verdict_text,
+        format_transcript(state.transcript),
+    )
+
+    reporter_summary = reporter.summarize(
+        state.topic, argument_a, argument_b, judge_result
+    )
+
+    final_markdown = (
+        f"{judge_report_md}\n\n---\n\n"
+        f"**Reporter Summary**\n\n{reporter_summary}"
+    )
+    yield (
+        format_transcript(state.transcript),
+        final_markdown,
+        score_rows,
+        verdict_text,
+        format_transcript(state.transcript),
+    )
+
+
+# ---------------------------------------------------------------------------
+# Gradio Interface
+# ---------------------------------------------------------------------------
+
+
+with gr.Blocks(
+    title="LLM Debate Arena",
+    fill_width=True,
+    theme=THEME,
+    css=CUSTOM_CSS,
+) as demo:
+    gr.Markdown(
+        "# LLM Debate Arena\n"
+        "Configure two debating agents, watch their arguments in real time, and "
+        "review the judge's verdict plus a reporter summary."
+    )
+
+    with gr.Row():
+        topic_input = gr.Textbox(
+            label="Debate Topic",
+            placeholder="e.g., Should autonomous delivery robots be allowed in city centers?",
+        )
+    with gr.Row():
+        stance_a_input = gr.Textbox(
+            label="Debater A Stance",
+            value="Supports the statement",
+        )
+        stance_b_input = gr.Textbox(
+            label="Debater B Stance",
+            value="Opposes the statement",
+        )
+
+    run_button = gr.Button("Start Debate", variant="primary")
+
+    with gr.Tab("Live Debate"):
+        transcript_md = gr.Markdown(
+            "### Waiting for the debate to start...",
+            elem_id="live-debate-panel",
+        )
+
+    with gr.Tab("Judge's Report"):
+        judge_md = gr.Markdown("Judge verdict will appear here.")
+        score_table = gr.Dataframe(
+            headers=["Criterion", "Debater A", "Debater B", "Notes"],
+            datatype=["str", "number", "number", "str"],
+            interactive=False,
+        )
+        verdict_box = gr.Textbox(
+            label="Verdict Detail",
+            interactive=False,
+        )
+        transcript_box = gr.Textbox(
+            label="Full Transcript (for copying)",
+            interactive=False,
+            lines=10,
+        )
+
+    run_button.click(
+        fn=run_debate,
+        inputs=[topic_input, stance_a_input, stance_b_input],
+        outputs=[transcript_md, judge_md, score_table, verdict_box, transcript_box],
+        queue=True,
+    )
+
+if __name__ == "__main__":
+    demo.queue(default_concurrency_limit=4).launch()
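+
+# Optional smoke test (sketch only): run_debate is an ordinary generator, so it
+# can be exercised without the Gradio UI, assuming the API keys referenced by
+# the AgentConfig definitions above are set in the environment. The topic and
+# stance strings below are arbitrary examples.
+#
+#     for update in run_debate(
+#         "Remote work should be the default for software teams",
+#         "Supports the statement",
+#         "Opposes the statement",
+#     ):
+#         transcript_md, judge_md, _scores, _verdict, _copy = update
+#     print(judge_md)  # final judge verdict plus reporter summary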