diff --git a/week4/community-contributions/w4d5-Trade.ipynb b/week4/community-contributions/w4d5-Trade.ipynb new file mode 100644 index 0000000..3a57afa --- /dev/null +++ b/week4/community-contributions/w4d5-Trade.ipynb @@ -0,0 +1,1833 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trading Code Generator\n", + "\n", + "This notebook creates a code generator that produces trading code to buy and sell equities in a simulated environment based on free APIs. It uses Gradio for the UI, similar to the approach in day5.ipynb.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import io\n", + "import sys\n", + "import time\n", + "import random\n", + "import numpy as np\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "from IPython.display import display\n", + "from huggingface_hub import InferenceClient\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n", + "Hugging Face Token exists and begins hf_fNncb\n" + ] + } + ], + "source": [ + "load_dotenv(override=True)\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "hf_token = os.getenv('HF_TOKEN')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "if hf_token:\n", + " print(f\"Hugging Face Token exists and begins {hf_token[:8]}\")\n", + "else:\n", + " print(\"Hugging Face Token not set\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "openai_client = OpenAI()\n", + "hf_client = InferenceClient(token=hf_token)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "models = [\"gpt-4o\", \"gpt-3.5-turbo\", \"meta-llama/Llama-2-70b-chat-hf\"]\n", + "\n", + "def generate_with_openai(model, messages):\n", + " response = openai_client.chat.completions.create(\n", + " model=model, \n", + " messages=messages\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "def generate_with_hf(model, messages):\n", + " prompt = \"\"\n", + " for msg in messages:\n", + " role = msg[\"role\"]\n", + " content = msg[\"content\"]\n", + " if role == \"system\":\n", + " prompt += f\"[INST] {content} [/INST]\\n\"\n", + " elif role == \"user\":\n", + " prompt += f\"[INST] {content} [/INST]\\n\"\n", + " else:\n", + " prompt += f\"{content}\\n\"\n", + " \n", + " response = hf_client.text_generation(\n", + " prompt,\n", + " model=model,\n", + " max_new_tokens=1024,\n", + " temperature=0.7,\n", + " repetition_penalty=1.2\n", + " )\n", + " return response\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "CSS = \"\"\"\n", + ":root {\n", + " --py-color: #209dd7;\n", + " --trading-color: #27ae60;\n", + " --accent: #753991;\n", + " --card: #161a22;\n", + " --text: #e9eef5;\n", + "}\n", + "\n", + "/* Full-width layout */\n", + ".gradio-container {\n", + " max-width: 100% !important;\n", + " padding: 0 40px !important;\n", + "}\n", + "\n", + "/* Code card styling */\n", + ".card {\n", + " background: var(--card);\n", + " border: 1px solid rgba(255,255,255,.08);\n", + " border-radius: 14px;\n", + " padding: 10px;\n", + "}\n", + "\n", + "/* Make code block scrollable but fixed height */\n", + "#code-block {\n", + " max-height: 400px !important;\n", + " overflow-y: auto !important;\n", + "}\n", + "\n", + "#code-block .cm-editor {\n", + " height: 400px !important;\n", + "}\n", + "\n", + "/* Buttons */\n", + ".generate-btn button {\n", + " background: var(--accent) !important;\n", + " border-color: rgba(255,255,255,.12) !important;\n", + " color: white !important;\n", + " font-weight: 700;\n", + "}\n", + ".run-btn button {\n", + " background: #202631 !important;\n", + " color: var(--text) !important;\n", + " border-color: rgba(255,255,255,.12) !important;\n", + "}\n", + ".run-btn.py button:hover { box-shadow: 0 0 0 2px var(--py-color) inset; }\n", + ".run-btn.trading button:hover { box-shadow: 0 0 0 2px var(--trading-color) inset; }\n", + ".generate-btn button:hover { box-shadow: 0 0 0 2px var(--accent) inset; }\n", + "\n", + "/* Outputs with color tint */\n", + ".py-out textarea {\n", + " background: linear-gradient(180deg, rgba(32,157,215,.18), rgba(32,157,215,.10));\n", + " border: 1px solid rgba(32,157,215,.35) !important;\n", + " color: rgba(32,157,215,1) !important;\n", + " font-weight: 600;\n", + "}\n", + ".trading-out textarea {\n", + " background: linear-gradient(180deg, rgba(39,174,96,.18), rgba(39,174,96,.10));\n", + " border: 1px solid rgba(39,174,96,.35) !important;\n", + " color: rgba(39,174,96,1) !important;\n", + " font-weight: 600;\n", + "}\n", + "\n", + "/* Align controls neatly */\n", + ".controls .wrap {\n", + " gap: 10px;\n", + " justify-content: center;\n", + " align-items: center;\n", + "}\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + "You are an expert algorithmic trading code generator. Generate clean, bug-free Python code for trading strategies.\n", + "\n", + "Generate code that:\n", + "1. Uses synthetic data generation only - no API calls\n", + "2. Implements the specified trading strategy\n", + "3. Uses proper error handling\n", + "4. Visualizes strategy performance with buy/sell signals\n", + "5. Calculates performance metrics\n", + "6. Handles edge cases properly\n", + "\n", + "REQUIREMENTS:\n", + "1. Include if __name__ == \"__main__\": block that executes immediately\n", + "2. Define all variables before use\n", + "3. Pass parameters between functions, avoid global variables\n", + "4. NO explanatory text outside of code\n", + "5. NO markdown blocks or language indicators\n", + "6. Code must execute without user input\n", + "7. Use str() for pandas objects in f-strings\n", + "8. Use .copy() for DataFrame views that will be modified\n", + "9. Include min_periods in rolling calculations\n", + "10. Check array lengths before scatter plots\n", + "11. Configure logging properly\n", + "12. Include helper functions for formatting and plotting\n", + "\n", + "Respond ONLY with Python code. No explanations or markdown.\n", + "\"\"\"\n", + "\n", + "def user_prompt_for(description):\n", + " return f\"\"\"\n", + "Generate Python code for a trading strategy:\n", + "\n", + "{description}\n", + "\n", + "Requirements:\n", + "1. Use synthetic data generation only\n", + "2. Implement the strategy exactly as described\n", + "3. Include backtesting functionality\n", + "4. Visualize results with matplotlib\n", + "5. Calculate performance metrics\n", + "6. Handle all edge cases\n", + "7. No comments needed\n", + "\n", + "Make the code complete and runnable as-is with all necessary imports.\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "def messages_for(description):\n", + " return [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt_for(description)}\n", + " ]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def validate_code(code):\n", + " issues = []\n", + " if \"import yfinance\" not in code and \"from yfinance\" not in code:\n", + " issues.append(\"Missing yfinance import\")\n", + " if \"import matplotlib\" not in code and \"from matplotlib\" not in code:\n", + " issues.append(\"Missing matplotlib import\")\n", + " if \"__name__ == \\\"__main__\\\"\" not in code and \"__name__ == '__main__'\" not in code:\n", + " issues.append(\"Missing if __name__ == '__main__' block\")\n", + " if \"f\\\"\" in code or \"f'\" in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if ('f\"' in line or \"f'\" in line) and ('data[' in line or '.iloc' in line or '.loc' in line):\n", + " issues.append(f\"Potentially unsafe f-string formatting with pandas objects on line {i+1}\")\n", + " if \"try:\" in code and \"except\" not in code:\n", + " issues.append(\"Try block without except clause\")\n", + " if \"rolling\" in code and \"min_periods\" not in code:\n", + " issues.append(\"Rolling window without min_periods parameter (may produce NaN values)\")\n", + " if \".loc\" in code and \"iloc\" not in code and \"copy()\" not in code:\n", + " issues.append(\"Potential pandas SettingWithCopyWarning - consider using .copy() before modifications\")\n", + " lines = code.split('\\n')\n", + " defined_vars = set()\n", + " for line in lines:\n", + " if line.strip().startswith('#') or not line.strip():\n", + " continue\n", + " if '=' in line and not line.strip().startswith('if') and not line.strip().startswith('elif') and not line.strip().startswith('while'):\n", + " var_name = line.split('=')[0].strip()\n", + " if var_name:\n", + " defined_vars.add(var_name)\n", + " if issues:\n", + " return False, issues\n", + " return True, []\n", + "\n", + "def generate_trading_code(model, description, force_gpt4=False):\n", + " messages = messages_for(description)\n", + " if force_gpt4:\n", + " try:\n", + " reply = generate_with_openai(\"gpt-4o\", messages)\n", + " except Exception as e:\n", + " print(f\"Error using GPT-4o: {e}. Falling back to selected model.\")\n", + " if \"gpt\" in model.lower():\n", + " reply = generate_with_openai(model, messages)\n", + " else:\n", + " reply = generate_with_hf(model, messages)\n", + " else:\n", + " if \"gpt\" in model.lower():\n", + " reply = generate_with_openai(model, messages)\n", + " else:\n", + " reply = generate_with_hf(model, messages)\n", + " reply = reply.replace('```python','').replace('```','')\n", + " is_valid, issues = validate_code(reply)\n", + " max_attempts = 3\n", + " attempt = 0\n", + " fix_model = \"gpt-4o\" if force_gpt4 else model\n", + " while not is_valid and attempt < max_attempts and (\"gpt\" in model.lower() or force_gpt4):\n", + " attempt += 1\n", + " fix_messages = messages.copy()\n", + " fix_messages.append({\"role\": \"assistant\", \"content\": reply})\n", + " fix_request = f\"\"\"The code has the following issues that need to be fixed:\n", + "{chr(10).join([f\"- {issue}\" for issue in issues])}\n", + "\n", + "Please provide a completely corrected version that addresses these issues. Make sure to:\n", + "\n", + "1. Avoid using f-strings with pandas Series or DataFrame objects directly\n", + "2. Always handle NaN values in calculations with proper checks\n", + "3. Use proper error handling with try/except blocks around all API calls and calculations\n", + "4. Include min_periods parameter in rolling window calculations\n", + "5. Use .copy() when creating views of DataFrames that will be modified\n", + "6. Make sure all variables are properly defined before use\n", + "7. Add yfinance timeout settings: yf.set_timeout(30)\n", + "8. Add proper logging for all steps\n", + "9. Use synthetic data generation as a fallback if API calls fail\n", + "10. Include proper if __name__ == \"__main__\" block\n", + "\n", + "Return ONLY the corrected code with no explanation or markdown formatting.\n", + "\"\"\"\n", + " fix_messages.append({\"role\": \"user\", \"content\": fix_request})\n", + " try:\n", + " if force_gpt4:\n", + " fixed_reply = generate_with_openai(\"gpt-4o\", fix_messages)\n", + " else:\n", + " if \"gpt\" in model.lower():\n", + " fixed_reply = generate_with_openai(model, fix_messages)\n", + " else:\n", + " fixed_reply = generate_with_hf(model, fix_messages)\n", + " fixed_reply = fixed_reply.replace('```python','').replace('```','')\n", + " is_fixed_valid, fixed_issues = validate_code(fixed_reply)\n", + " if is_fixed_valid or len(fixed_issues) < len(issues):\n", + " reply = fixed_reply\n", + " is_valid = is_fixed_valid\n", + " issues = fixed_issues\n", + " except Exception as e:\n", + " print(f\"Error during fix attempt {attempt}: {e}\")\n", + " reply = add_safety_features(reply)\n", + " return reply\n", + "\n", + "def add_safety_features(code):\n", + " if \"pandas\" in code:\n", + " safety_imports = \"\"\"\n", + "import pandas as pd\n", + "pd.set_option('display.float_format', '{:.5f}'.format)\n", + "\n", + "def safe_format(obj):\n", + " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", + " return str(obj)\n", + " return obj\n", + "\"\"\"\n", + " import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n", + " if import_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(import_lines[-1] + 1, safety_imports)\n", + " code = '\\n'.join(lines)\n", + " code = code.replace(\"yf.set_timeout(30)\", \"\")\n", + " code = code.replace(\"yf.pdr_override()\", \"\")\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if 'f\"' in line or \"f'\" in line:\n", + " if any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n", + " for term in ['.mean()', '.sum()', '.std()', '.min()', '.max()']:\n", + " if term in line:\n", + " lines[i] = line.replace(f\"{term}\", f\"{term})\")\n", + " lines[i] = lines[i].replace(\"f\\\"\", \"f\\\"{safe_format(\")\n", + " lines[i] = lines[i].replace(\"f'\", \"f'{safe_format(\")\n", + " code = '\\n'.join(lines)\n", + " if \"plt.scatter\" in code or \".scatter\" in code:\n", + " scatter_safety = \"\"\"\n", + "def safe_scatter(ax, x, y, *args, **kwargs):\n", + " if len(x) != len(y):\n", + " min_len = min(len(x), len(y))\n", + " x = x[:min_len]\n", + " y = y[:min_len]\n", + " if len(x) == 0 or len(y) == 0:\n", + " return None\n", + " return ax.scatter(x, y, *args, **kwargs)\n", + "\"\"\"\n", + " func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n", + " if func_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(func_lines[0], scatter_safety)\n", + " code = '\\n'.join(lines)\n", + " code = code.replace(\"plt.scatter(\", \"safe_scatter(plt.gca(), \")\n", + " code = code.replace(\".scatter(\", \"safe_scatter(\")\n", + " if \"yfinance\" in code and \"generate_synthetic_data\" not in code:\n", + " synthetic_data_func = \"\"\"\n", + "def generate_synthetic_data(ticker='AAPL', start_date=None, end_date=None, days=252, seed=42):\n", + " import numpy as np\n", + " import pandas as pd\n", + " from datetime import datetime, timedelta\n", + " if start_date is None:\n", + " end_date = datetime.now()\n", + " start_date = end_date - timedelta(days=days)\n", + " elif end_date is None:\n", + " if isinstance(start_date, str):\n", + " start_date = pd.to_datetime(start_date)\n", + " end_date = datetime.now()\n", + " np.random.seed(seed)\n", + " if isinstance(start_date, str):\n", + " start = pd.to_datetime(start_date)\n", + " else:\n", + " start = start_date\n", + " if isinstance(end_date, str):\n", + " end = pd.to_datetime(end_date)\n", + " else:\n", + " end = end_date\n", + " days = (end - start).days + 1\n", + " price = 100\n", + " prices = [price]\n", + " for _ in range(days):\n", + " change = np.random.normal(0, 0.01)\n", + " price *= (1 + change)\n", + " prices.append(price)\n", + " dates = pd.date_range(start=start, end=end, periods=len(prices))\n", + " df = pd.DataFrame({\n", + " 'Open': prices[:-1],\n", + " 'High': [p * 1.01 for p in prices[:-1]],\n", + " 'Low': [p * 0.99 for p in prices[:-1]],\n", + " 'Close': prices[1:],\n", + " 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n", + " }, index=dates[:-1])\n", + " return df\n", + "\"\"\"\n", + " func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n", + " if func_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(func_lines[0], synthetic_data_func)\n", + " code = '\\n'.join(lines)\n", + " if \"logging\" in code and \"basicConfig\" not in code:\n", + " logging_config = \"\"\"\n", + "import logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\"\"\"\n", + " import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n", + " if import_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(import_lines[-1] + 1, logging_config)\n", + " code = '\\n'.join(lines)\n", + " if \"yfinance\" in code and \"try:\" not in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"yf.download\" in line and \"try:\" not in lines[max(0, i-5):i]:\n", + " indent = len(line) - len(line.lstrip())\n", + " indent_str = \" \" * indent\n", + " lines[i] = f\"{indent_str}try:\\n{indent_str} {line}\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error fetching data: {{e}}\\\")\\n{indent_str} # Use synthetic data as fallback\\n{indent_str} data = generate_synthetic_data(ticker, start_date, end_date)\"\n", + " code = '\\n'.join(lines)\n", + " break\n", + " if \"synthetic data\" in code.lower() and \"yf.download\" in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"yf.download\" in line:\n", + " indent = len(line) - len(line.lstrip())\n", + " indent_str = \" \" * indent\n", + " comment = f\"{indent_str}# Using synthetic data instead of API call\\n\"\n", + " synthetic = f\"{indent_str}data = generate_synthetic_data(ticker, start_date, end_date)\\n\"\n", + " lines[i] = f\"{indent_str}# {line.strip()} # Commented out to avoid API issues\"\n", + " lines.insert(i+1, comment + synthetic)\n", + " code = '\\n'.join(lines)\n", + " break\n", + " if \"plt.figure\" in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"plt.figure\" in line and \"try:\" not in lines[max(0, i-5):i]:\n", + " indent = len(line) - len(line.lstrip())\n", + " indent_str = \" \" * indent\n", + " try_line = f\"{indent_str}try:\\n{indent_str} \"\n", + " except_line = f\"\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error in plotting: {{e}}\\\")\"\n", + " j = i\n", + " while j < len(lines) and (j == i or lines[j].startswith(indent_str)):\n", + " j += 1\n", + " for k in range(i, j):\n", + " if lines[k].strip():\n", + " lines[k] = indent_str + \" \" + lines[k].lstrip()\n", + " lines.insert(i, try_line.rstrip())\n", + " lines.insert(j+1, except_line)\n", + " code = '\\n'.join(lines)\n", + " break\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"print(\" in line and any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n", + " lines[i] = line.replace(\"print(\", \"print(safe_format(\")\n", + " if \"))\" not in lines[i] and \"),\" in lines[i]:\n", + " lines[i] = lines[i].replace(\"),\", \")),\", 1)\n", + " elif \"))\" not in lines[i] and \")\" in lines[i]:\n", + " lines[i] = lines[i].replace(\")\", \"))\", 1)\n", + " code = '\\n'.join(lines)\n", + " return code\n" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "metadata": {}, + "outputs": [], + "source": [ + "def run_python(code):\n", + " # Create a completely separate namespace for execution\n", + " namespace = {\n", + " '__name__': '__main__',\n", + " '__builtins__': __builtins__\n", + " }\n", + " \n", + " # Modify the code to use a non-interactive matplotlib backend\n", + " # and fix pandas formatting issues\n", + " modified_code = \"\"\"\n", + "import matplotlib\n", + "matplotlib.use('Agg') # Use non-interactive backend\n", + "\n", + "# Import yfinance without setting timeout (not available in all versions)\n", + "import yfinance as yf\n", + "\n", + "# Configure logging to show in the output\n", + "import logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\n", + "# Fix pandas formatting issues\n", + "import pandas as pd\n", + "pd.set_option('display.float_format', '{:.5f}'.format)\n", + "\n", + "# Override print to ensure it flushes immediately\n", + "import builtins\n", + "original_print = builtins.print\n", + "def custom_print(*args, **kwargs):\n", + " result = original_print(*args, **kwargs)\n", + " import sys\n", + " sys.stdout.flush()\n", + " return result\n", + "builtins.print = custom_print\n", + "\n", + "# Helper function to safely format pandas objects\n", + "def safe_format(obj):\n", + " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", + " return str(obj)\n", + " else:\n", + " return obj\n", + "\"\"\"\n", + " \n", + " # Add the user's code\n", + " modified_code += \"\\n\" + code\n", + " \n", + " # Capture all output\n", + " output_buffer = io.StringIO()\n", + " \n", + " # Save original stdout and redirect to our buffer\n", + " original_stdout = sys.stdout\n", + " sys.stdout = output_buffer\n", + " \n", + " # Add timestamp for execution start\n", + " print(f\"[{time.strftime('%H:%M:%S')}] Executing code...\")\n", + " \n", + " try:\n", + " # Execute the modified code\n", + " exec(modified_code, namespace)\n", + " print(f\"\\n[{time.strftime('%H:%M:%S')}] Execution completed successfully.\")\n", + " \n", + " except ModuleNotFoundError as e:\n", + " missing_module = str(e).split(\"'\")[1]\n", + " print(f\"\\nError: Missing module '{missing_module}'. Click 'Install Dependencies' to install it.\")\n", + " namespace[\"__missing_module__\"] = missing_module\n", + " \n", + " except Exception as e:\n", + " print(f\"\\n[{time.strftime('%H:%M:%S')}] Error during execution: {str(e)}\")\n", + " import traceback\n", + " print(traceback.format_exc())\n", + " \n", + " finally:\n", + " # Restore original stdout\n", + " sys.stdout = original_stdout\n", + " \n", + " # Return the captured output\n", + " return output_buffer.getvalue()\n", + "\n", + "def install_dependencies(code):\n", + " import re\n", + " import subprocess\n", + " \n", + " import_pattern = r'(?:from|import)\\s+([a-zA-Z0-9_]+)(?:\\s+(?:import|as))?'\n", + " imports = re.findall(import_pattern, code)\n", + " \n", + " std_libs = ['os', 'sys', 'io', 'time', 'datetime', 'random', 'math', 're', 'json', \n", + " 'collections', 'itertools', 'functools', 'operator', 'pathlib', 'typing']\n", + " \n", + " modules_to_install = [module for module in imports if module not in std_libs]\n", + " \n", + " if not modules_to_install:\n", + " return \"No external dependencies found to install.\"\n", + " \n", + " results = []\n", + " for module in modules_to_install:\n", + " try:\n", + " result = subprocess.run(\n", + " [sys.executable, \"-m\", \"pip\", \"install\", module],\n", + " capture_output=True,\n", + " text=True,\n", + " check=False\n", + " )\n", + " \n", + " if result.returncode == 0:\n", + " results.append(f\"Successfully installed {module}\")\n", + " else:\n", + " results.append(f\"Failed to install {module}: {result.stderr}\")\n", + " except Exception as e:\n", + " results.append(f\"Error installing {module}: {str(e)}\")\n", + " \n", + " return \"\\n\".join(results)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [], + "source": [ + "trading_strategies = [\n", + " {\n", + " \"name\": \"Moving Average Crossover\",\n", + " \"description\": \"Moving Average Crossover strategy for S&P 500 stocks. Buy when the 20-day moving average crosses above the 50-day moving average, and sell when it crosses below.\",\n", + " \"buy_signal\": \"20-day MA crosses above 50-day MA\",\n", + " \"sell_signal\": \"20-day MA crosses below 50-day MA\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium\"\n", + " },\n", + " {\n", + " \"name\": \"RSI Mean Reversion\",\n", + " \"description\": \"Mean reversion strategy that buys stocks when RSI falls below 30 (oversold) and sells when RSI rises above 70 (overbought).\",\n", + " \"buy_signal\": \"RSI below 30 (oversold)\",\n", + " \"sell_signal\": \"RSI above 70 (overbought)\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium\"\n", + " },\n", + " {\n", + " \"name\": \"Momentum Strategy\",\n", + " \"description\": \"Momentum strategy that buys the top 5 performing stocks from the Dow Jones Industrial Average over the past month and rebalances monthly.\",\n", + " \"buy_signal\": \"Stock in top 5 performers over past month\",\n", + " \"sell_signal\": \"Stock no longer in top 5 performers at rebalance\",\n", + " \"timeframe\": \"Monthly\",\n", + " \"risk_level\": \"High\"\n", + " },\n", + " {\n", + " \"name\": \"Pairs Trading\",\n", + " \"description\": \"Pairs trading strategy that identifies correlated stock pairs and trades on the divergence and convergence of their price relationship.\",\n", + " \"buy_signal\": \"Pairs ratio deviates 2+ standard deviations below mean\",\n", + " \"sell_signal\": \"Pairs ratio returns to mean or exceeds mean\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium-High\"\n", + " },\n", + " {\n", + " \"name\": \"Bollinger Band Breakout\",\n", + " \"description\": \"Volatility breakout strategy that buys when a stock breaks out of its upper Bollinger Band and sells when it reverts to the mean.\",\n", + " \"buy_signal\": \"Price breaks above upper Bollinger Band (2 std dev)\",\n", + " \"sell_signal\": \"Price reverts to middle Bollinger Band (SMA)\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"High\"\n", + " },\n", + " {\n", + " \"name\": \"MACD Crossover\",\n", + " \"description\": \"MACD crossover strategy that buys when the MACD line crosses above the signal line and sells when it crosses below.\",\n", + " \"buy_signal\": \"MACD line crosses above signal line\",\n", + " \"sell_signal\": \"MACD line crosses below signal line\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium\"\n", + " },\n", + " {\n", + " \"name\": \"Golden Cross\",\n", + " \"description\": \"Golden Cross strategy that buys when the 50-day moving average crosses above the 200-day moving average and sells on the Death Cross (opposite).\",\n", + " \"buy_signal\": \"50-day MA crosses above 200-day MA\",\n", + " \"sell_signal\": \"50-day MA crosses below 200-day MA\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Low\"\n", + " }\n", + "]\n", + "\n", + "sample_strategies = [strategy[\"description\"] for strategy in trading_strategies]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [], + "source": [ + "default_description = \"\"\"\n", + "Create a moving average crossover strategy with the following specifications:\n", + "- Use yfinance to download historical data for a list of stocks (AAPL, MSFT, AMZN, GOOGL, META)\n", + "- Calculate 20-day and 50-day moving averages\n", + "- Generate buy signals when the 20-day MA crosses above the 50-day MA\n", + "- Generate sell signals when the 20-day MA crosses below the 50-day MA\n", + "- Implement a simple backtesting framework to evaluate the strategy\n", + "- Calculate performance metrics: total return, annualized return, Sharpe ratio, max drawdown\n", + "- Visualize the equity curve, buy/sell signals, and moving averages\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-10-22 18:30:21,233 - INFO - HTTP Request: GET http://127.0.0.1:7875/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:30:21,238 - INFO - HTTP Request: HEAD http://127.0.0.1:7875/ \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7875\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 115, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-10-22 18:30:24,092 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:31:03,437 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:31:15,743 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:31:28,425 - ERROR - Error fetching data: periods must be a number, got 2025-10-22 18:31:28.425210\n", + "2025-10-22 18:31:28,429 - INFO - Synthetic data generated for tickers: AAPL\n", + "2025-10-22 18:31:28,432 - INFO - Moving averages calculated with windows 20 and 50\n", + "2025-10-22 18:31:28,434 - INFO - Signals generated based on moving average crossover\n", + "2025-10-22 18:31:28,438 - INFO - Performance calculated\n", + "2025-10-22 18:31:28,438 - INFO - Total Return: -0.010752455100331848, Sharpe Ratio: 0.18162435507214664, Max Drawdown: -0.19919271751608258\n", + "2025-10-22 18:32:28,496 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:32:38,626 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:32:47,779 - ERROR - Error fetching data from yfinance: name 'start_date' is not defined. Using synthetic data.\n", + "2025-10-22 18:33:23,647 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:33:37,829 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + } + ], + "source": [ + "with gr.Blocks(css=CSS, theme=gr.themes.Monochrome(), title=\"Trading Code Generator\") as ui:\n", + " with gr.Row():\n", + " gr.HTML(\"

Trading Strategy Code Generator

\")\n", + " \n", + " with gr.Row():\n", + " # Left column - Controls\n", + " with gr.Column(scale=1):\n", + " strategy_dropdown = gr.Dropdown(\n", + " label=\"Select Trading Strategy\",\n", + " choices=[strategy[\"name\"] for strategy in trading_strategies],\n", + " value=trading_strategies[0][\"name\"]\n", + " )\n", + " \n", + " with gr.Accordion(\"Strategy Details\", open=False):\n", + " strategy_info = gr.JSON(\n", + " value=trading_strategies[0]\n", + " )\n", + " \n", + " model = gr.Dropdown(\n", + " label=\"Select Model\",\n", + " choices=models,\n", + " value=models[0]\n", + " )\n", + " \n", + " description = gr.TextArea(\n", + " label=\"Strategy Description (Edit to customize)\",\n", + " value=trading_strategies[0][\"description\"],\n", + " lines=4\n", + " )\n", + " \n", + " with gr.Row():\n", + " generate = gr.Button(\"Generate Code\", variant=\"primary\", size=\"sm\")\n", + " run = gr.Button(\"Run Code\", size=\"sm\")\n", + " install_deps = gr.Button(\"Install Dependencies\", size=\"sm\")\n", + " \n", + " # Right column - Code and Output\n", + " with gr.Column(scale=2):\n", + " trading_code = gr.Code(\n", + " label=\"Generated Trading Code\",\n", + " value=\"\",\n", + " language=\"python\",\n", + " lines=20,\n", + " elem_id=\"code-block\",\n", + " show_label=True\n", + " )\n", + " \n", + " output = gr.TextArea(\n", + " label=\"Execution Output\",\n", + " lines=8,\n", + " elem_classes=[\"trading-out\"]\n", + " )\n", + " \n", + " def update_strategy_info(strategy_name):\n", + " selected = next((s for s in trading_strategies if s[\"name\"] == strategy_name), None)\n", + " if selected:\n", + " return selected, selected[\"description\"]\n", + " return trading_strategies[0], trading_strategies[0][\"description\"]\n", + " \n", + " strategy_dropdown.change(\n", + " fn=update_strategy_info,\n", + " inputs=strategy_dropdown,\n", + " outputs=[strategy_info, description]\n", + " )\n", + " \n", + " # Function to show validation results when generating code\n", + " def generate_with_validation(model, description):\n", + " # Always use GPT-4o for better code quality\n", + " code = generate_trading_code(model, description, force_gpt4=True)\n", + " is_valid, issues = validate_code(code)\n", + " \n", + " validation_message = \"\"\n", + " if is_valid:\n", + " validation_message = \"Code validation passed ✓\"\n", + " else:\n", + " validation_message = \"Code validation warnings:\\n\" + \"\\n\".join([f\"- {issue}\" for issue in issues])\n", + " \n", + " return code, validation_message\n", + " \n", + " generate.click(\n", + " fn=generate_with_validation,\n", + " inputs=[model, description],\n", + " outputs=[trading_code, output]\n", + " )\n", + " \n", + " run.click(\n", + " fn=run_python,\n", + " inputs=[trading_code],\n", + " outputs=[output]\n", + " )\n", + " \n", + " install_deps.click(\n", + " fn=install_dependencies,\n", + " inputs=[trading_code],\n", + " outputs=[output]\n", + " )\n", + "\n", + "ui.launch(inbrowser=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing the Trading Code Generator\n", + "\n", + "Let's test the trading code generator with a specific strategy description and model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_description = \"\"\"\n", + "Create a simple RSI-based mean reversion strategy:\n", + "- Use AAPL stock data for the past 2 years\n", + "- Calculate the 14-day RSI indicator\n", + "- Buy when RSI falls below 30 (oversold)\n", + "- Sell when RSI rises above 70 (overbought)\n", + "- Include visualization of entry/exit points\n", + "- Calculate performance metrics\n", + "\"\"\"\n", + "\n", + "test_model = \"gpt-3.5-turbo\"\n", + "\n", + "generated_code = generate_trading_code(test_model, test_description)\n", + "print(\"Generated trading code:\")\n", + "print(generated_code)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " output = run_python(generated_code)\n", + " print(output)\n", + "except Exception as e:\n", + " print(f\"Error running the generated code: {e}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fixed version of the code\n", + "fixed_code = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.dates as mdates\n", + "from datetime import datetime, timedelta\n", + "import logging\n", + "\n", + "# Set up logging\n", + "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", + "\n", + "def calculate_moving_averages(data, short_window=20, long_window=50):\n", + " \\\"\\\"\\\"\n", + " Calculate short and long term moving averages\n", + " \\\"\\\"\\\"\n", + " data['Short_MA'] = data['Close'].rolling(window=short_window, min_periods=1).mean()\n", + " data['Long_MA'] = data['Close'].rolling(window=long_window, min_periods=1).mean()\n", + " return data\n", + "\n", + "def generate_signals(data, short_window=20, long_window=50):\n", + " \\\"\\\"\\\"\n", + " Generate buy/sell signals based on moving average crossover strategy\n", + " \\\"\\\"\\\"\n", + " data['Signal'] = 0\n", + " data['Signal'][short_window:] = np.where(\n", + " data['Short_MA'][short_window:] > data['Long_MA'][short_window:], 1, -1)\n", + " data['Position'] = data['Signal'].shift(1)\n", + " data['Position'].fillna(0, inplace=True) # Fill NaN values with 0\n", + " return data\n", + "\n", + "def backtest_strategy(data):\n", + " \\\"\\\"\\\"\n", + " Backtest the trading strategy and calculate performance metrics\n", + " \\\"\\\"\\\"\n", + " data['Returns'] = data['Close'].pct_change()\n", + " data['Strategy_Returns'] = data['Returns'] * data['Position']\n", + " \n", + " # Replace NaN values with 0\n", + " data['Strategy_Returns'].fillna(0, inplace=True)\n", + "\n", + " cumulative_returns = (1 + data['Strategy_Returns']).cumprod()\n", + " \n", + " # Calculate metrics\n", + " total_return = cumulative_returns.iloc[-1] - 1\n", + " sharpe_ratio = np.sqrt(252) * (data['Strategy_Returns'].mean() / data['Strategy_Returns'].std())\n", + " max_drawdown = ((cumulative_returns / cumulative_returns.cummax()) - 1).min()\n", + "\n", + " metrics = {\n", + " 'Total Return': total_return,\n", + " 'Sharpe Ratio': sharpe_ratio,\n", + " 'Max Drawdown': max_drawdown\n", + " }\n", + "\n", + " return cumulative_returns, metrics\n", + "\n", + "def plot_results(data, cumulative_returns, ticker):\n", + " \\\"\\\"\\\"\n", + " Plot the performance of the trading strategy\n", + " \\\"\\\"\\\"\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={'height_ratios': [2, 1]})\n", + " \n", + " # Price and MA plot\n", + " ax1.plot(data.index, data['Close'], label='Close Price')\n", + " ax1.plot(data.index, data['Short_MA'], label='20-day MA', alpha=0.7)\n", + " ax1.plot(data.index, data['Long_MA'], label='50-day MA', alpha=0.7)\n", + " \n", + " # Add buy/sell signals\n", + " buy_signals = data[data['Signal'] > data['Signal'].shift(1)]\n", + " sell_signals = data[data['Signal'] < data['Signal'].shift(1)]\n", + " \n", + " ax1.scatter(buy_signals.index, buy_signals['Close'], marker='^', color='green', s=100, label='Buy Signal')\n", + " ax1.scatter(sell_signals.index, sell_signals['Close'], marker='v', color='red', s=100, label='Sell Signal')\n", + " \n", + " ax1.set_title(f'Moving Average Crossover Strategy on {ticker}')\n", + " ax1.set_ylabel('Price ($)')\n", + " ax1.legend(loc='best')\n", + " ax1.grid(True)\n", + " \n", + " # Returns plot\n", + " ax2.plot(cumulative_returns.index, cumulative_returns, label='Cumulative Strategy Returns', color='blue')\n", + " ax2.set_title('Cumulative Returns')\n", + " ax2.set_xlabel('Date')\n", + " ax2.set_ylabel('Returns')\n", + " ax2.legend(loc='best')\n", + " ax2.grid(True)\n", + " \n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "if __name__ == \\\"__main__\\\":\n", + " # User inputs\n", + " ticker = 'SPY' # Example: S&P 500 ETF\n", + " start_date = (datetime.now() - timedelta(days=365*2)).strftime('%Y-%m-%d')\n", + " end_date = datetime.now().strftime('%Y-%m-%d')\n", + "\n", + " # Strategy parameters\n", + " short_window = 20\n", + " long_window = 50\n", + "\n", + " # Fetch data\n", + " try:\n", + " logging.info(f\\\"Fetching data for {ticker} from {start_date} to {end_date}...\\\")\n", + " stock_data = yf.download(ticker, start=start_date, end=end_date)\n", + " logging.info(f\\\"Data fetched successfully. Got {len(stock_data)} data points.\\\")\n", + " except Exception as e:\n", + " logging.error(f\\\"Failed to fetch data: {e}\\\")\n", + " raise SystemExit(e)\n", + "\n", + " try:\n", + " # Preprocess and generate signals\n", + " stock_data = calculate_moving_averages(stock_data, short_window, long_window)\n", + " stock_data = generate_signals(stock_data, short_window, long_window)\n", + "\n", + " # Backtest the strategy\n", + " cumulative_returns, metrics = backtest_strategy(stock_data)\n", + "\n", + " # Display metrics\n", + " for key, value in metrics.items():\n", + " logging.info(f\\\"{key}: {value:.4f}\\\")\n", + "\n", + " # Plot results\n", + " plot_results(stock_data, cumulative_returns, ticker)\n", + " except Exception as e:\n", + " logging.error(f\\\"Error while executing strategy: {e}\\\")\n", + "\"\"\"\n", + "\n", + "# Display the fixed code\n", + "print(\"Fixed code:\")\n", + "print(fixed_code)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the fixed code\n", + "output = run_python(fixed_code)\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's also update our system_prompt to ensure the generated code works properly\n", + "\n", + "system_prompt = \"\"\"\n", + "You are an expert algorithmic trading code generator. Your task is to generate Python code for trading strategies based on user requirements.\n", + "The code should be well-structured, efficient, and ready to run in a simulated environment.\n", + "\n", + "The generated code should:\n", + "1. Use the yfinance library for fetching stock data\n", + "2. Implement the specified trading strategy\n", + "3. Include proper error handling and logging\n", + "4. Include visualization of the strategy performance with clear buy/sell signals\n", + "5. Calculate and display relevant metrics (returns, Sharpe ratio, drawdown, etc.)\n", + "6. Handle NaN values and edge cases properly\n", + "7. Include informative print statements or logging to show progress\n", + "\n", + "IMPORTANT: Make sure all variables are properly defined before use, especially in functions.\n", + "Always pass necessary parameters between functions rather than relying on global variables.\n", + "\n", + "Respond only with Python code. Do not provide any explanation other than occasional comments in the code.\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test the updated system prompt with a simple strategy\n", + "\n", + "test_description_2 = \"\"\"\n", + "Create a simple Bollinger Bands strategy:\n", + "- Use AAPL stock data for the past 1 year\n", + "- Calculate Bollinger Bands with 20-day SMA and 2 standard deviations\n", + "- Buy when price touches the lower band\n", + "- Sell when price touches the upper band\n", + "- Include visualization of entry/exit points\n", + "- Calculate performance metrics\n", + "\"\"\"\n", + "\n", + "test_model = \"gpt-3.5-turbo\"\n", + "generated_code_2 = generate_trading_code(test_model, test_description_2)\n", + "print(\"Generated trading code with updated prompt:\")\n", + "print(generated_code_2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test the run function with live logging\n", + "\n", + "test_code = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from datetime import datetime, timedelta\n", + "\n", + "# Define ticker and date range\n", + "ticker = 'AAPL'\n", + "end_date = datetime.now()\n", + "start_date = end_date - timedelta(days=365)\n", + "\n", + "# Download data\n", + "print(f\"Downloading data for {ticker}...\")\n", + "data = yf.download(ticker, start=start_date, end=end_date)\n", + "print(f\"Downloaded {len(data)} rows of data\")\n", + "\n", + "# Calculate RSI\n", + "print(\"Calculating RSI...\")\n", + "delta = data['Close'].diff()\n", + "gain = delta.where(delta > 0, 0)\n", + "loss = -delta.where(delta < 0, 0)\n", + "avg_gain = gain.rolling(window=14).mean()\n", + "avg_loss = loss.rolling(window=14).mean()\n", + "rs = avg_gain / avg_loss\n", + "data['RSI'] = 100 - (100 / (1 + rs))\n", + "\n", + "# Generate signals\n", + "print(\"Generating trading signals...\")\n", + "data['Signal'] = 0\n", + "data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + "data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + "\n", + "# Count signals\n", + "buy_signals = len(data[data['Signal'] == 1])\n", + "sell_signals = len(data[data['Signal'] == -1])\n", + "print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + "\n", + "# Print sample of the data\n", + "print(\"\\\\nSample of the processed data:\")\n", + "print(data[['Close', 'RSI', 'Signal']].tail())\n", + "\n", + "print(\"\\\\nAnalysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(test_code)\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test the improved run function again\n", + "\n", + "test_code_2 = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from datetime import datetime, timedelta\n", + "\n", + "# Define ticker and date range\n", + "ticker = 'AAPL'\n", + "end_date = datetime.now()\n", + "start_date = end_date - timedelta(days=365)\n", + "\n", + "# Download data\n", + "print(f\"Downloading data for {ticker}...\")\n", + "data = yf.download(ticker, start=start_date, end=end_date, progress=False)\n", + "print(f\"Downloaded {len(data)} rows of data\")\n", + "\n", + "# Calculate RSI\n", + "print(\"Calculating RSI...\")\n", + "delta = data['Close'].diff()\n", + "gain = delta.where(delta > 0, 0)\n", + "loss = -delta.where(delta < 0, 0)\n", + "avg_gain = gain.rolling(window=14).mean()\n", + "avg_loss = loss.rolling(window=14).mean()\n", + "rs = avg_gain / avg_loss\n", + "data['RSI'] = 100 - (100 / (1 + rs))\n", + "\n", + "# Generate signals\n", + "print(\"Generating trading signals...\")\n", + "data['Signal'] = 0\n", + "data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + "data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + "\n", + "# Count signals\n", + "buy_signals = len(data[data['Signal'] == 1])\n", + "sell_signals = len(data[data['Signal'] == -1])\n", + "print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + "\n", + "# Print sample of the data\n", + "print(\"\\\\nSample of the processed data:\")\n", + "print(data[['Close', 'RSI', 'Signal']].tail())\n", + "\n", + "print(\"\\\\nAnalysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(test_code_2)\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the completely rewritten run function\n", + "\n", + "simple_test = \"\"\"\n", + "print(\"Hello from the trading code generator!\")\n", + "print(\"Testing output capture...\")\n", + "\n", + "# Simulate some data processing\n", + "import numpy as np\n", + "data = np.random.rand(5, 3)\n", + "print(\"Generated random data:\")\n", + "print(data)\n", + "\n", + "# Show a calculation\n", + "result = np.mean(data, axis=0)\n", + "print(\"Mean of each column:\")\n", + "print(result)\n", + "\n", + "print(\"Test complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(simple_test)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the improved code generation with a strategy that typically causes formatting issues\n", + "\n", + "test_description_3 = \"\"\"\n", + "Create a simple trading strategy that:\n", + "- Uses AAPL stock data for the past year\n", + "- Calculates both RSI and Bollinger Bands\n", + "- Buys when price is below lower Bollinger Band AND RSI is below 30\n", + "- Sells when price is above upper Bollinger Band OR RSI is above 70\n", + "- Includes proper error handling for all calculations\n", + "- Visualizes the entry/exit points and performance\n", + "\"\"\"\n", + "\n", + "test_model = \"gpt-3.5-turbo\"\n", + "generated_code_3 = generate_trading_code(test_model, test_description_3)\n", + "print(\"Generated trading code with enhanced validation:\")\n", + "print(generated_code_3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test running the generated code\n", + "\n", + "output = run_python(generated_code_3)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the improved run function with a simple example that prints output\n", + "\n", + "simple_test_2 = \"\"\"\n", + "print(\"This is a test of the output capture system\")\n", + "print(\"Line 1 of output\")\n", + "print(\"Line 2 of output\")\n", + "\n", + "# Import and use numpy\n", + "import numpy as np\n", + "data = np.random.rand(3, 3)\n", + "print(\"Random matrix:\")\n", + "print(data)\n", + "\n", + "# Create a simple plot\n", + "import matplotlib.pyplot as plt\n", + "plt.figure(figsize=(8, 4))\n", + "plt.plot([1, 2, 3, 4], [10, 20, 25, 30], 'ro-')\n", + "plt.title('Simple Plot')\n", + "plt.xlabel('X axis')\n", + "plt.ylabel('Y axis')\n", + "plt.grid(True)\n", + "plt.savefig('simple_plot.png') # Save instead of showing\n", + "print(\"Plot saved to simple_plot.png\")\n", + "\n", + "print(\"Test complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(simple_test_2)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test with a simpler example that won't time out\n", + "\n", + "simple_test_3 = \"\"\"\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import logging\n", + "\n", + "# Set up logging\n", + "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", + "\n", + "# Generate synthetic stock data\n", + "def generate_stock_data(days=252, volatility=0.01):\n", + " logging.info(f\"Generating {days} days of synthetic stock data\")\n", + " np.random.seed(42)\n", + " price = 100\n", + " prices = [price]\n", + " \n", + " for _ in range(days - 1):\n", + " change = np.random.normal(0, volatility)\n", + " price *= (1 + change)\n", + " prices.append(price)\n", + " \n", + " dates = pd.date_range(end=pd.Timestamp.today(), periods=days)\n", + " df = pd.DataFrame({\n", + " 'Close': prices,\n", + " 'Open': [p * (1 - volatility/2) for p in prices],\n", + " 'High': [p * (1 + volatility) for p in prices],\n", + " 'Low': [p * (1 - volatility) for p in prices],\n", + " 'Volume': [np.random.randint(100000, 10000000) for _ in range(days)]\n", + " }, index=dates)\n", + " \n", + " logging.info(f\"Generated data with shape {df.shape}\")\n", + " return df\n", + "\n", + "# Calculate RSI\n", + "def calculate_rsi(data, window=14):\n", + " logging.info(f\"Calculating RSI with {window}-day window\")\n", + " delta = data['Close'].diff()\n", + " gain = delta.where(delta > 0, 0).rolling(window=window, min_periods=1).mean()\n", + " loss = -delta.where(delta < 0, 0).rolling(window=window, min_periods=1).mean()\n", + " \n", + " rs = gain / loss\n", + " rsi = 100 - (100 / (1 + rs))\n", + " return rsi\n", + "\n", + "# Main function\n", + "if __name__ == \"__main__\":\n", + " # Generate data\n", + " data = generate_stock_data()\n", + " \n", + " # Calculate RSI\n", + " data['RSI'] = calculate_rsi(data)\n", + " \n", + " # Generate signals\n", + " logging.info(\"Generating trading signals\")\n", + " data['Signal'] = 0\n", + " data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + " data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + " \n", + " # Count signals\n", + " buy_signals = len(data[data['Signal'] == 1])\n", + " sell_signals = len(data[data['Signal'] == -1])\n", + " logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + " \n", + " # Plot the results\n", + " logging.info(\"Creating visualization\")\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [3, 1]})\n", + " \n", + " # Price chart\n", + " ax1.plot(data.index, data['Close'], label='Close Price')\n", + " \n", + " # Add buy/sell signals\n", + " buy_points = data[data['Signal'] == 1]\n", + " sell_points = data[data['Signal'] == -1]\n", + " \n", + " ax1.scatter(buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n", + " ax1.scatter(sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n", + " \n", + " ax1.set_title('Stock Price with RSI Signals')\n", + " ax1.set_ylabel('Price')\n", + " ax1.legend()\n", + " ax1.grid(True)\n", + " \n", + " # RSI chart\n", + " ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n", + " ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n", + " ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n", + " ax2.set_title('RSI Indicator')\n", + " ax2.set_ylabel('RSI')\n", + " ax2.set_ylim(0, 100)\n", + " ax2.grid(True)\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig('rsi_strategy.png')\n", + " logging.info(\"Plot saved to rsi_strategy.png\")\n", + " \n", + " # Print sample of data\n", + " logging.info(\"Sample of the processed data:\")\n", + " print(data[['Close', 'RSI', 'Signal']].tail())\n", + " \n", + " logging.info(\"Analysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(simple_test_3)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the enhanced code generation with GPT-4o\n", + "\n", + "test_description_4 = \"\"\"\n", + "Create a trading strategy that:\n", + "- Uses both MACD and RSI indicators\n", + "- Buys when MACD crosses above signal line AND RSI is below 40\n", + "- Sells when MACD crosses below signal line OR RSI is above 70\n", + "- Includes proper visualization with buy/sell signals\n", + "- Uses synthetic data if API calls fail\n", + "- Calculates performance metrics including Sharpe ratio and max drawdown\n", + "\"\"\"\n", + "\n", + "print(\"Generating trading code with GPT-4o...\")\n", + "generated_code_4 = generate_trading_code(\"gpt-4o\", test_description_4, force_gpt4=True)\n", + "print(\"Code generation complete. Validating...\")\n", + "is_valid, issues = validate_code(generated_code_4)\n", + "\n", + "if issues:\n", + " print(f\"Validation found {len(issues)} issues:\")\n", + " for issue in issues:\n", + " print(f\"- {issue}\")\n", + "else:\n", + " print(\"Code validation passed ✓\")\n", + "\n", + "print(\"\\nGenerated code snippet (first 20 lines):\")\n", + "print(\"\\n\".join(generated_code_4.split(\"\\n\")[:20]))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's run the generated code to test it\n", + "\n", + "output = run_python(generated_code_4)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test again with the fixed timeout setting\n", + "\n", + "test_description_5 = \"\"\"\n", + "Create a simple trading strategy that:\n", + "- Uses synthetic data generation to avoid API timeouts\n", + "- Implements a simple moving average crossover (5-day and 20-day)\n", + "- Includes proper visualization with buy/sell signals\n", + "- Calculates basic performance metrics\n", + "\"\"\"\n", + "\n", + "print(\"Generating trading code with proper yfinance timeout settings...\")\n", + "generated_code_5 = generate_trading_code(\"gpt-4o\", test_description_5, force_gpt4=True)\n", + "print(\"Code generation complete.\")\n", + "\n", + "# Run the generated code\n", + "output = run_python(generated_code_5)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test with a simpler strategy that focuses on scatter plot safety\n", + "\n", + "test_description_6 = \"\"\"\n", + "Create a simple trading strategy that:\n", + "- Uses synthetic data generation only (no API calls)\n", + "- Implements a simple RSI-based strategy (buy when RSI < 30, sell when RSI > 70)\n", + "- Includes visualization with buy/sell signals using scatter plots\n", + "- Calculates basic performance metrics\n", + "- Uses proper error handling for all operations\n", + "\"\"\"\n", + "\n", + "print(\"Generating trading code with scatter plot safety...\")\n", + "generated_code_6 = generate_trading_code(\"gpt-4o\", test_description_6, force_gpt4=True)\n", + "print(\"Code generation complete.\")\n", + "\n", + "# Run the generated code\n", + "output = run_python(generated_code_6)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test with a fixed example that properly handles pandas formatting and scatter plots\n", + "\n", + "test_fixed_code = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import logging\n", + "from datetime import datetime, timedelta\n", + "\n", + "# Configure logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\n", + "# Helper function for safe formatting of pandas objects\n", + "def safe_format(obj):\n", + " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", + " return str(obj)\n", + " return obj\n", + "\n", + "# Helper function to safely create scatter plots\n", + "def safe_scatter(ax, x, y, *args, **kwargs):\n", + " # Ensure x and y are the same length\n", + " if len(x) != len(y):\n", + " logging.warning(f\"Scatter plot inputs have different lengths: x={len(x)}, y={len(y)}\")\n", + " # Find the minimum length\n", + " min_len = min(len(x), len(y))\n", + " x = x[:min_len]\n", + " y = y[:min_len]\n", + " \n", + " # Check for empty arrays\n", + " if len(x) == 0 or len(y) == 0:\n", + " logging.warning(\"Empty arrays passed to scatter plot, skipping\")\n", + " return None\n", + " \n", + " return ax.scatter(x, y, *args, **kwargs)\n", + "\n", + "# Generate synthetic data\n", + "def generate_synthetic_data(ticker='AAPL', days=252, seed=42):\n", + " logging.info(f\"Generating synthetic data for {ticker}\")\n", + " np.random.seed(seed)\n", + " \n", + " # Generate price data\n", + " price = 100 # Starting price\n", + " prices = [price]\n", + " \n", + " for _ in range(days):\n", + " change = np.random.normal(0, 0.01) # 1% volatility\n", + " price *= (1 + change)\n", + " prices.append(price)\n", + " \n", + " # Create date range\n", + " end_date = datetime.now()\n", + " start_date = end_date - timedelta(days=days)\n", + " dates = pd.date_range(start=start_date, end=end_date, periods=len(prices))\n", + " \n", + " # Create DataFrame\n", + " df = pd.DataFrame({\n", + " 'Open': prices[:-1],\n", + " 'High': [p * 1.01 for p in prices[:-1]],\n", + " 'Low': [p * 0.99 for p in prices[:-1]],\n", + " 'Close': prices[1:],\n", + " 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n", + " }, index=dates[:-1])\n", + " \n", + " logging.info(f\"Generated {len(df)} days of data for {ticker}\")\n", + " return df\n", + "\n", + "# Calculate RSI\n", + "def calculate_rsi(data, window=14):\n", + " logging.info(f\"Calculating RSI with {window}-day window\")\n", + " delta = data['Close'].diff()\n", + " gain = delta.where(delta > 0, 0)\n", + " loss = -delta.where(delta < 0, 0)\n", + " \n", + " avg_gain = gain.rolling(window=window, min_periods=1).mean()\n", + " avg_loss = loss.rolling(window=window, min_periods=1).mean()\n", + " \n", + " rs = avg_gain / avg_loss\n", + " rsi = 100 - (100 / (1 + rs))\n", + " return rsi\n", + "\n", + "# Generate signals\n", + "def generate_signals(data):\n", + " logging.info(\"Generating trading signals\")\n", + " data['Signal'] = 0\n", + " data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + " data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + " \n", + " # Count signals\n", + " buy_signals = len(data[data['Signal'] == 1])\n", + " sell_signals = len(data[data['Signal'] == -1])\n", + " logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + " return data\n", + "\n", + "# Backtest strategy\n", + "def backtest_strategy(data):\n", + " logging.info(\"Backtesting strategy\")\n", + " data['Returns'] = data['Close'].pct_change()\n", + " data['Strategy'] = data['Signal'].shift(1) * data['Returns']\n", + " \n", + " # Replace NaN values\n", + " data['Strategy'].fillna(0, inplace=True)\n", + " \n", + " # Calculate cumulative returns\n", + " data['Cumulative'] = (1 + data['Strategy']).cumprod()\n", + " \n", + " # Calculate metrics\n", + " total_return = data['Cumulative'].iloc[-1] - 1\n", + " sharpe = np.sqrt(252) * data['Strategy'].mean() / data['Strategy'].std()\n", + " max_dd = (data['Cumulative'] / data['Cumulative'].cummax() - 1).min()\n", + " \n", + " logging.info(f\"Total Return: {total_return:.4f}\")\n", + " logging.info(f\"Sharpe Ratio: {sharpe:.4f}\")\n", + " logging.info(f\"Max Drawdown: {max_dd:.4f}\")\n", + " \n", + " return data\n", + "\n", + "# Visualize results\n", + "def visualize_results(data, ticker):\n", + " logging.info(\"Creating visualization\")\n", + " try:\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]})\n", + " \n", + " # Price chart\n", + " ax1.plot(data.index, data['Close'], label='Close Price')\n", + " \n", + " # Add buy/sell signals\n", + " buy_points = data[data['Signal'] == 1]\n", + " sell_points = data[data['Signal'] == -1]\n", + " \n", + " # Use safe scatter to avoid \"x and y must be the same size\" error\n", + " if not buy_points.empty:\n", + " safe_scatter(ax1, buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n", + " \n", + " if not sell_points.empty:\n", + " safe_scatter(ax1, sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n", + " \n", + " ax1.set_title(f'RSI Strategy for {ticker}')\n", + " ax1.set_ylabel('Price')\n", + " ax1.legend()\n", + " ax1.grid(True)\n", + " \n", + " # RSI chart\n", + " ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n", + " ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n", + " ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n", + " ax2.set_title('RSI Indicator')\n", + " ax2.set_ylabel('RSI')\n", + " ax2.set_ylim(0, 100)\n", + " ax2.grid(True)\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig('rsi_strategy.png')\n", + " logging.info(\"Plot saved to rsi_strategy.png\")\n", + " except Exception as e:\n", + " logging.error(f\"Error in visualization: {e}\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " # Settings\n", + " ticker = 'AAPL'\n", + " days = 252 # One year of trading days\n", + " \n", + " # Generate data\n", + " data = generate_synthetic_data(ticker, days)\n", + " \n", + " # Calculate RSI\n", + " data['RSI'] = calculate_rsi(data)\n", + " \n", + " # Generate signals\n", + " data = generate_signals(data)\n", + " \n", + " # Backtest strategy\n", + " data = backtest_strategy(data)\n", + " \n", + " # Visualize results\n", + " visualize_results(data, ticker)\n", + " \n", + " # Print sample of data\n", + " logging.info(\"Sample of the processed data:\")\n", + " print(data[['Close', 'RSI', 'Signal', 'Strategy', 'Cumulative']].tail())\n", + " \n", + " logging.info(\"Analysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(test_fixed_code)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the fix for indentation error in plotting code\n", + "\n", + "test_code_with_plotting = \"\"\"\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import logging\n", + "\n", + "# Configure logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\n", + "# Simple plotting function\n", + "def create_plot():\n", + " # Generate some data\n", + " x = np.linspace(0, 10, 100)\n", + " y = np.sin(x)\n", + " \n", + " # Create plot\n", + " logging.info(\"Creating sine wave plot\")\n", + " plt.figure(figsize=(10, 6))\n", + " plt.plot(x, y)\n", + " plt.title('Sine Wave')\n", + " plt.xlabel('X')\n", + " plt.ylabel('Y')\n", + " plt.grid(True)\n", + " plt.savefig('sine_wave.png')\n", + " logging.info(\"Plot saved to sine_wave.png\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " create_plot()\n", + "\"\"\"\n", + "\n", + "# Apply safety features to the code\n", + "enhanced_code = add_safety_features(test_code_with_plotting)\n", + "print(\"Code with safety features applied:\")\n", + "print(enhanced_code)\n", + "\n", + "# Run the enhanced code\n", + "output = run_python(enhanced_code)\n", + "print(\"\\nExecution output:\")\n", + "print(output)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}