{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Trading Code Generator\n", "\n", "This notebook creates a code generator that produces trading code to buy and sell equities in a simulated environment based on free APIs. It uses Gradio for the UI, similar to the approach in day5.ipynb.\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import os\n", "import io\n", "import sys\n", "import time\n", "import random\n", "import numpy as np\n", "from dotenv import load_dotenv\n", "from openai import OpenAI\n", "import gradio as gr\n", "from IPython.display import display\n", "from huggingface_hub import InferenceClient\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OpenAI API Key exists and begins sk-proj-\n", "Hugging Face Token exists and begins hf_fNncb\n" ] } ], "source": [ "load_dotenv(override=True)\n", "openai_api_key = os.getenv('OPENAI_API_KEY')\n", "hf_token = os.getenv('HF_TOKEN')\n", "\n", "if openai_api_key:\n", " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", "else:\n", " print(\"OpenAI API Key not set\")\n", " \n", "if hf_token:\n", " print(f\"Hugging Face Token exists and begins {hf_token[:8]}\")\n", "else:\n", " print(\"Hugging Face Token not set\")\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "openai_client = OpenAI()\n", "hf_client = InferenceClient(token=hf_token)\n" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "models = [\"gpt-4o\", \"gpt-3.5-turbo\", \"meta-llama/Llama-2-70b-chat-hf\"]\n", "\n", "def generate_with_openai(model, messages):\n", " response = openai_client.chat.completions.create(\n", " model=model, \n", " messages=messages\n", " )\n", " return response.choices[0].message.content\n", "\n", "def generate_with_hf(model, messages):\n", " prompt = \"\"\n", " for msg in messages:\n", " role = msg[\"role\"]\n", " content = msg[\"content\"]\n", " if role == \"system\":\n", " prompt += f\"[INST] {content} [/INST]\\n\"\n", " elif role == \"user\":\n", " prompt += f\"[INST] {content} [/INST]\\n\"\n", " else:\n", " prompt += f\"{content}\\n\"\n", " \n", " response = hf_client.text_generation(\n", " prompt,\n", " model=model,\n", " max_new_tokens=1024,\n", " temperature=0.7,\n", " repetition_penalty=1.2\n", " )\n", " return response\n" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "CSS = \"\"\"\n", ":root {\n", " --py-color: #209dd7;\n", " --trading-color: #27ae60;\n", " --accent: #753991;\n", " --card: #161a22;\n", " --text: #e9eef5;\n", "}\n", "\n", "/* Full-width layout */\n", ".gradio-container {\n", " max-width: 100% !important;\n", " padding: 0 40px !important;\n", "}\n", "\n", "/* Code card styling */\n", ".card {\n", " background: var(--card);\n", " border: 1px solid rgba(255,255,255,.08);\n", " border-radius: 14px;\n", " padding: 10px;\n", "}\n", "\n", "/* Make code block scrollable but fixed height */\n", "#code-block {\n", " max-height: 400px !important;\n", " overflow-y: auto !important;\n", "}\n", "\n", "#code-block .cm-editor {\n", " height: 400px !important;\n", "}\n", "\n", "/* Buttons */\n", ".generate-btn button {\n", " background: var(--accent) !important;\n", " border-color: rgba(255,255,255,.12) !important;\n", " color: white !important;\n", " font-weight: 700;\n", "}\n", ".run-btn button {\n", " background: #202631 !important;\n", " color: var(--text) !important;\n", " border-color: rgba(255,255,255,.12) !important;\n", "}\n", ".run-btn.py button:hover { box-shadow: 0 0 0 2px var(--py-color) inset; }\n", ".run-btn.trading button:hover { box-shadow: 0 0 0 2px var(--trading-color) inset; }\n", ".generate-btn button:hover { box-shadow: 0 0 0 2px var(--accent) inset; }\n", "\n", "/* Outputs with color tint */\n", ".py-out textarea {\n", " background: linear-gradient(180deg, rgba(32,157,215,.18), rgba(32,157,215,.10));\n", " border: 1px solid rgba(32,157,215,.35) !important;\n", " color: rgba(32,157,215,1) !important;\n", " font-weight: 600;\n", "}\n", ".trading-out textarea {\n", " background: linear-gradient(180deg, rgba(39,174,96,.18), rgba(39,174,96,.10));\n", " border: 1px solid rgba(39,174,96,.35) !important;\n", " color: rgba(39,174,96,1) !important;\n", " font-weight: 600;\n", "}\n", "\n", "/* Align controls neatly */\n", ".controls .wrap {\n", " gap: 10px;\n", " justify-content: center;\n", " align-items: center;\n", "}\n", "\"\"\"\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "system_prompt = \"\"\"\n", "You are an expert algorithmic trading code generator. Generate clean, bug-free Python code for trading strategies.\n", "\n", "Generate code that:\n", "1. Uses synthetic data generation only - no API calls\n", "2. Implements the specified trading strategy\n", "3. Uses proper error handling\n", "4. Visualizes strategy performance with buy/sell signals\n", "5. Calculates performance metrics\n", "6. Handles edge cases properly\n", "\n", "REQUIREMENTS:\n", "1. Include if __name__ == \"__main__\": block that executes immediately\n", "2. Define all variables before use\n", "3. Pass parameters between functions, avoid global variables\n", "4. NO explanatory text outside of code\n", "5. NO markdown blocks or language indicators\n", "6. Code must execute without user input\n", "7. Use str() for pandas objects in f-strings\n", "8. Use .copy() for DataFrame views that will be modified\n", "9. Include min_periods in rolling calculations\n", "10. Check array lengths before scatter plots\n", "11. Configure logging properly\n", "12. Include helper functions for formatting and plotting\n", "\n", "Respond ONLY with Python code. No explanations or markdown.\n", "\"\"\"\n", "\n", "def user_prompt_for(description):\n", " return f\"\"\"\n", "Generate Python code for a trading strategy:\n", "\n", "{description}\n", "\n", "Requirements:\n", "1. Use synthetic data generation only\n", "2. Implement the strategy exactly as described\n", "3. Include backtesting functionality\n", "4. Visualize results with matplotlib\n", "5. Calculate performance metrics\n", "6. Handle all edge cases\n", "7. No comments needed\n", "\n", "Make the code complete and runnable as-is with all necessary imports.\n", "\"\"\"\n" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "def messages_for(description):\n", " return [\n", " {\"role\": \"system\", \"content\": system_prompt},\n", " {\"role\": \"user\", \"content\": user_prompt_for(description)}\n", " ]\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def validate_code(code):\n", " issues = []\n", " if \"import yfinance\" not in code and \"from yfinance\" not in code:\n", " issues.append(\"Missing yfinance import\")\n", " if \"import matplotlib\" not in code and \"from matplotlib\" not in code:\n", " issues.append(\"Missing matplotlib import\")\n", " if \"__name__ == \\\"__main__\\\"\" not in code and \"__name__ == '__main__'\" not in code:\n", " issues.append(\"Missing if __name__ == '__main__' block\")\n", " if \"f\\\"\" in code or \"f'\" in code:\n", " lines = code.split('\\n')\n", " for i, line in enumerate(lines):\n", " if ('f\"' in line or \"f'\" in line) and ('data[' in line or '.iloc' in line or '.loc' in line):\n", " issues.append(f\"Potentially unsafe f-string formatting with pandas objects on line {i+1}\")\n", " if \"try:\" in code and \"except\" not in code:\n", " issues.append(\"Try block without except clause\")\n", " if \"rolling\" in code and \"min_periods\" not in code:\n", " issues.append(\"Rolling window without min_periods parameter (may produce NaN values)\")\n", " if \".loc\" in code and \"iloc\" not in code and \"copy()\" not in code:\n", " issues.append(\"Potential pandas SettingWithCopyWarning - consider using .copy() before modifications\")\n", " lines = code.split('\\n')\n", " defined_vars = set()\n", " for line in lines:\n", " if line.strip().startswith('#') or not line.strip():\n", " continue\n", " if '=' in line and not line.strip().startswith('if') and not line.strip().startswith('elif') and not line.strip().startswith('while'):\n", " var_name = line.split('=')[0].strip()\n", " if var_name:\n", " defined_vars.add(var_name)\n", " if issues:\n", " return False, issues\n", " return True, []\n", "\n", "def generate_trading_code(model, description, force_gpt4=False):\n", " messages = messages_for(description)\n", " if force_gpt4:\n", " try:\n", " reply = generate_with_openai(\"gpt-4o\", messages)\n", " except Exception as e:\n", " print(f\"Error using GPT-4o: {e}. Falling back to selected model.\")\n", " if \"gpt\" in model.lower():\n", " reply = generate_with_openai(model, messages)\n", " else:\n", " reply = generate_with_hf(model, messages)\n", " else:\n", " if \"gpt\" in model.lower():\n", " reply = generate_with_openai(model, messages)\n", " else:\n", " reply = generate_with_hf(model, messages)\n", " reply = reply.replace('```python','').replace('```','')\n", " is_valid, issues = validate_code(reply)\n", " max_attempts = 3\n", " attempt = 0\n", " fix_model = \"gpt-4o\" if force_gpt4 else model\n", " while not is_valid and attempt < max_attempts and (\"gpt\" in model.lower() or force_gpt4):\n", " attempt += 1\n", " fix_messages = messages.copy()\n", " fix_messages.append({\"role\": \"assistant\", \"content\": reply})\n", " fix_request = f\"\"\"The code has the following issues that need to be fixed:\n", "{chr(10).join([f\"- {issue}\" for issue in issues])}\n", "\n", "Please provide a completely corrected version that addresses these issues. Make sure to:\n", "\n", "1. Avoid using f-strings with pandas Series or DataFrame objects directly\n", "2. Always handle NaN values in calculations with proper checks\n", "3. Use proper error handling with try/except blocks around all API calls and calculations\n", "4. Include min_periods parameter in rolling window calculations\n", "5. Use .copy() when creating views of DataFrames that will be modified\n", "6. Make sure all variables are properly defined before use\n", "7. Add yfinance timeout settings: yf.set_timeout(30)\n", "8. Add proper logging for all steps\n", "9. Use synthetic data generation as a fallback if API calls fail\n", "10. Include proper if __name__ == \"__main__\" block\n", "\n", "Return ONLY the corrected code with no explanation or markdown formatting.\n", "\"\"\"\n", " fix_messages.append({\"role\": \"user\", \"content\": fix_request})\n", " try:\n", " if force_gpt4:\n", " fixed_reply = generate_with_openai(\"gpt-4o\", fix_messages)\n", " else:\n", " if \"gpt\" in model.lower():\n", " fixed_reply = generate_with_openai(model, fix_messages)\n", " else:\n", " fixed_reply = generate_with_hf(model, fix_messages)\n", " fixed_reply = fixed_reply.replace('```python','').replace('```','')\n", " is_fixed_valid, fixed_issues = validate_code(fixed_reply)\n", " if is_fixed_valid or len(fixed_issues) < len(issues):\n", " reply = fixed_reply\n", " is_valid = is_fixed_valid\n", " issues = fixed_issues\n", " except Exception as e:\n", " print(f\"Error during fix attempt {attempt}: {e}\")\n", " reply = add_safety_features(reply)\n", " return reply\n", "\n", "def add_safety_features(code):\n", " if \"pandas\" in code:\n", " safety_imports = \"\"\"\n", "import pandas as pd\n", "pd.set_option('display.float_format', '{:.5f}'.format)\n", "\n", "def safe_format(obj):\n", " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", " return str(obj)\n", " return obj\n", "\"\"\"\n", " import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n", " if import_lines:\n", " lines = code.split('\\n')\n", " lines.insert(import_lines[-1] + 1, safety_imports)\n", " code = '\\n'.join(lines)\n", " code = code.replace(\"yf.set_timeout(30)\", \"\")\n", " code = code.replace(\"yf.pdr_override()\", \"\")\n", " lines = code.split('\\n')\n", " for i, line in enumerate(lines):\n", " if 'f\"' in line or \"f'\" in line:\n", " if any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n", " for term in ['.mean()', '.sum()', '.std()', '.min()', '.max()']:\n", " if term in line:\n", " lines[i] = line.replace(f\"{term}\", f\"{term})\")\n", " lines[i] = lines[i].replace(\"f\\\"\", \"f\\\"{safe_format(\")\n", " lines[i] = lines[i].replace(\"f'\", \"f'{safe_format(\")\n", " code = '\\n'.join(lines)\n", " if \"plt.scatter\" in code or \".scatter\" in code:\n", " scatter_safety = \"\"\"\n", "def safe_scatter(ax, x, y, *args, **kwargs):\n", " if len(x) != len(y):\n", " min_len = min(len(x), len(y))\n", " x = x[:min_len]\n", " y = y[:min_len]\n", " if len(x) == 0 or len(y) == 0:\n", " return None\n", " return ax.scatter(x, y, *args, **kwargs)\n", "\"\"\"\n", " func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n", " if func_lines:\n", " lines = code.split('\\n')\n", " lines.insert(func_lines[0], scatter_safety)\n", " code = '\\n'.join(lines)\n", " code = code.replace(\"plt.scatter(\", \"safe_scatter(plt.gca(), \")\n", " code = code.replace(\".scatter(\", \"safe_scatter(\")\n", " if \"yfinance\" in code and \"generate_synthetic_data\" not in code:\n", " synthetic_data_func = \"\"\"\n", "def generate_synthetic_data(ticker='AAPL', start_date=None, end_date=None, days=252, seed=42):\n", " import numpy as np\n", " import pandas as pd\n", " from datetime import datetime, timedelta\n", " if start_date is None:\n", " end_date = datetime.now()\n", " start_date = end_date - timedelta(days=days)\n", " elif end_date is None:\n", " if isinstance(start_date, str):\n", " start_date = pd.to_datetime(start_date)\n", " end_date = datetime.now()\n", " np.random.seed(seed)\n", " if isinstance(start_date, str):\n", " start = pd.to_datetime(start_date)\n", " else:\n", " start = start_date\n", " if isinstance(end_date, str):\n", " end = pd.to_datetime(end_date)\n", " else:\n", " end = end_date\n", " days = (end - start).days + 1\n", " price = 100\n", " prices = [price]\n", " for _ in range(days):\n", " change = np.random.normal(0, 0.01)\n", " price *= (1 + change)\n", " prices.append(price)\n", " dates = pd.date_range(start=start, end=end, periods=len(prices))\n", " df = pd.DataFrame({\n", " 'Open': prices[:-1],\n", " 'High': [p * 1.01 for p in prices[:-1]],\n", " 'Low': [p * 0.99 for p in prices[:-1]],\n", " 'Close': prices[1:],\n", " 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n", " }, index=dates[:-1])\n", " return df\n", "\"\"\"\n", " func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n", " if func_lines:\n", " lines = code.split('\\n')\n", " lines.insert(func_lines[0], synthetic_data_func)\n", " code = '\\n'.join(lines)\n", " if \"logging\" in code and \"basicConfig\" not in code:\n", " logging_config = \"\"\"\n", "import logging\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format='[%(asctime)s] %(levelname)s: %(message)s',\n", " datefmt='%H:%M:%S'\n", ")\n", "\"\"\"\n", " import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n", " if import_lines:\n", " lines = code.split('\\n')\n", " lines.insert(import_lines[-1] + 1, logging_config)\n", " code = '\\n'.join(lines)\n", " if \"yfinance\" in code and \"try:\" not in code:\n", " lines = code.split('\\n')\n", " for i, line in enumerate(lines):\n", " if \"yf.download\" in line and \"try:\" not in lines[max(0, i-5):i]:\n", " indent = len(line) - len(line.lstrip())\n", " indent_str = \" \" * indent\n", " lines[i] = f\"{indent_str}try:\\n{indent_str} {line}\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error fetching data: {{e}}\\\")\\n{indent_str} # Use synthetic data as fallback\\n{indent_str} data = generate_synthetic_data(ticker, start_date, end_date)\"\n", " code = '\\n'.join(lines)\n", " break\n", " if \"synthetic data\" in code.lower() and \"yf.download\" in code:\n", " lines = code.split('\\n')\n", " for i, line in enumerate(lines):\n", " if \"yf.download\" in line:\n", " indent = len(line) - len(line.lstrip())\n", " indent_str = \" \" * indent\n", " comment = f\"{indent_str}# Using synthetic data instead of API call\\n\"\n", " synthetic = f\"{indent_str}data = generate_synthetic_data(ticker, start_date, end_date)\\n\"\n", " lines[i] = f\"{indent_str}# {line.strip()} # Commented out to avoid API issues\"\n", " lines.insert(i+1, comment + synthetic)\n", " code = '\\n'.join(lines)\n", " break\n", " if \"plt.figure\" in code:\n", " lines = code.split('\\n')\n", " for i, line in enumerate(lines):\n", " if \"plt.figure\" in line and \"try:\" not in lines[max(0, i-5):i]:\n", " indent = len(line) - len(line.lstrip())\n", " indent_str = \" \" * indent\n", " try_line = f\"{indent_str}try:\\n{indent_str} \"\n", " except_line = f\"\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error in plotting: {{e}}\\\")\"\n", " j = i\n", " while j < len(lines) and (j == i or lines[j].startswith(indent_str)):\n", " j += 1\n", " for k in range(i, j):\n", " if lines[k].strip():\n", " lines[k] = indent_str + \" \" + lines[k].lstrip()\n", " lines.insert(i, try_line.rstrip())\n", " lines.insert(j+1, except_line)\n", " code = '\\n'.join(lines)\n", " break\n", " lines = code.split('\\n')\n", " for i, line in enumerate(lines):\n", " if \"print(\" in line and any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n", " lines[i] = line.replace(\"print(\", \"print(safe_format(\")\n", " if \"))\" not in lines[i] and \"),\" in lines[i]:\n", " lines[i] = lines[i].replace(\"),\", \")),\", 1)\n", " elif \"))\" not in lines[i] and \")\" in lines[i]:\n", " lines[i] = lines[i].replace(\")\", \"))\", 1)\n", " code = '\\n'.join(lines)\n", " return code\n" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [], "source": [ "def run_python(code):\n", " # Create a completely separate namespace for execution\n", " namespace = {\n", " '__name__': '__main__',\n", " '__builtins__': __builtins__\n", " }\n", " \n", " # Modify the code to use a non-interactive matplotlib backend\n", " # and fix pandas formatting issues\n", " modified_code = \"\"\"\n", "import matplotlib\n", "matplotlib.use('Agg') # Use non-interactive backend\n", "\n", "# Import yfinance without setting timeout (not available in all versions)\n", "import yfinance as yf\n", "\n", "# Configure logging to show in the output\n", "import logging\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format='[%(asctime)s] %(levelname)s: %(message)s',\n", " datefmt='%H:%M:%S'\n", ")\n", "\n", "# Fix pandas formatting issues\n", "import pandas as pd\n", "pd.set_option('display.float_format', '{:.5f}'.format)\n", "\n", "# Override print to ensure it flushes immediately\n", "import builtins\n", "original_print = builtins.print\n", "def custom_print(*args, **kwargs):\n", " result = original_print(*args, **kwargs)\n", " import sys\n", " sys.stdout.flush()\n", " return result\n", "builtins.print = custom_print\n", "\n", "# Helper function to safely format pandas objects\n", "def safe_format(obj):\n", " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", " return str(obj)\n", " else:\n", " return obj\n", "\"\"\"\n", " \n", " # Add the user's code\n", " modified_code += \"\\n\" + code\n", " \n", " # Capture all output\n", " output_buffer = io.StringIO()\n", " \n", " # Save original stdout and redirect to our buffer\n", " original_stdout = sys.stdout\n", " sys.stdout = output_buffer\n", " \n", " # Add timestamp for execution start\n", " print(f\"[{time.strftime('%H:%M:%S')}] Executing code...\")\n", " \n", " try:\n", " # Execute the modified code\n", " exec(modified_code, namespace)\n", " print(f\"\\n[{time.strftime('%H:%M:%S')}] Execution completed successfully.\")\n", " \n", " except ModuleNotFoundError as e:\n", " missing_module = str(e).split(\"'\")[1]\n", " print(f\"\\nError: Missing module '{missing_module}'. Click 'Install Dependencies' to install it.\")\n", " namespace[\"__missing_module__\"] = missing_module\n", " \n", " except Exception as e:\n", " print(f\"\\n[{time.strftime('%H:%M:%S')}] Error during execution: {str(e)}\")\n", " import traceback\n", " print(traceback.format_exc())\n", " \n", " finally:\n", " # Restore original stdout\n", " sys.stdout = original_stdout\n", " \n", " # Return the captured output\n", " return output_buffer.getvalue()\n", "\n", "def install_dependencies(code):\n", " import re\n", " import subprocess\n", " \n", " import_pattern = r'(?:from|import)\\s+([a-zA-Z0-9_]+)(?:\\s+(?:import|as))?'\n", " imports = re.findall(import_pattern, code)\n", " \n", " std_libs = ['os', 'sys', 'io', 'time', 'datetime', 'random', 'math', 're', 'json', \n", " 'collections', 'itertools', 'functools', 'operator', 'pathlib', 'typing']\n", " \n", " modules_to_install = [module for module in imports if module not in std_libs]\n", " \n", " if not modules_to_install:\n", " return \"No external dependencies found to install.\"\n", " \n", " results = []\n", " for module in modules_to_install:\n", " try:\n", " result = subprocess.run(\n", " [sys.executable, \"-m\", \"pip\", \"install\", module],\n", " capture_output=True,\n", " text=True,\n", " check=False\n", " )\n", " \n", " if result.returncode == 0:\n", " results.append(f\"Successfully installed {module}\")\n", " else:\n", " results.append(f\"Failed to install {module}: {result.stderr}\")\n", " except Exception as e:\n", " results.append(f\"Error installing {module}: {str(e)}\")\n", " \n", " return \"\\n\".join(results)\n" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [], "source": [ "trading_strategies = [\n", " {\n", " \"name\": \"Moving Average Crossover\",\n", " \"description\": \"Moving Average Crossover strategy for S&P 500 stocks. Buy when the 20-day moving average crosses above the 50-day moving average, and sell when it crosses below.\",\n", " \"buy_signal\": \"20-day MA crosses above 50-day MA\",\n", " \"sell_signal\": \"20-day MA crosses below 50-day MA\",\n", " \"timeframe\": \"Daily\",\n", " \"risk_level\": \"Medium\"\n", " },\n", " {\n", " \"name\": \"RSI Mean Reversion\",\n", " \"description\": \"Mean reversion strategy that buys stocks when RSI falls below 30 (oversold) and sells when RSI rises above 70 (overbought).\",\n", " \"buy_signal\": \"RSI below 30 (oversold)\",\n", " \"sell_signal\": \"RSI above 70 (overbought)\",\n", " \"timeframe\": \"Daily\",\n", " \"risk_level\": \"Medium\"\n", " },\n", " {\n", " \"name\": \"Momentum Strategy\",\n", " \"description\": \"Momentum strategy that buys the top 5 performing stocks from the Dow Jones Industrial Average over the past month and rebalances monthly.\",\n", " \"buy_signal\": \"Stock in top 5 performers over past month\",\n", " \"sell_signal\": \"Stock no longer in top 5 performers at rebalance\",\n", " \"timeframe\": \"Monthly\",\n", " \"risk_level\": \"High\"\n", " },\n", " {\n", " \"name\": \"Pairs Trading\",\n", " \"description\": \"Pairs trading strategy that identifies correlated stock pairs and trades on the divergence and convergence of their price relationship.\",\n", " \"buy_signal\": \"Pairs ratio deviates 2+ standard deviations below mean\",\n", " \"sell_signal\": \"Pairs ratio returns to mean or exceeds mean\",\n", " \"timeframe\": \"Daily\",\n", " \"risk_level\": \"Medium-High\"\n", " },\n", " {\n", " \"name\": \"Bollinger Band Breakout\",\n", " \"description\": \"Volatility breakout strategy that buys when a stock breaks out of its upper Bollinger Band and sells when it reverts to the mean.\",\n", " \"buy_signal\": \"Price breaks above upper Bollinger Band (2 std dev)\",\n", " \"sell_signal\": \"Price reverts to middle Bollinger Band (SMA)\",\n", " \"timeframe\": \"Daily\",\n", " \"risk_level\": \"High\"\n", " },\n", " {\n", " \"name\": \"MACD Crossover\",\n", " \"description\": \"MACD crossover strategy that buys when the MACD line crosses above the signal line and sells when it crosses below.\",\n", " \"buy_signal\": \"MACD line crosses above signal line\",\n", " \"sell_signal\": \"MACD line crosses below signal line\",\n", " \"timeframe\": \"Daily\",\n", " \"risk_level\": \"Medium\"\n", " },\n", " {\n", " \"name\": \"Golden Cross\",\n", " \"description\": \"Golden Cross strategy that buys when the 50-day moving average crosses above the 200-day moving average and sells on the Death Cross (opposite).\",\n", " \"buy_signal\": \"50-day MA crosses above 200-day MA\",\n", " \"sell_signal\": \"50-day MA crosses below 200-day MA\",\n", " \"timeframe\": \"Daily\",\n", " \"risk_level\": \"Low\"\n", " }\n", "]\n", "\n", "sample_strategies = [strategy[\"description\"] for strategy in trading_strategies]\n" ] }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": [], "source": [ "default_description = \"\"\"\n", "Create a moving average crossover strategy with the following specifications:\n", "- Use yfinance to download historical data for a list of stocks (AAPL, MSFT, AMZN, GOOGL, META)\n", "- Calculate 20-day and 50-day moving averages\n", "- Generate buy signals when the 20-day MA crosses above the 50-day MA\n", "- Generate sell signals when the 20-day MA crosses below the 50-day MA\n", "- Implement a simple backtesting framework to evaluate the strategy\n", "- Calculate performance metrics: total return, annualized return, Sharpe ratio, max drawdown\n", "- Visualize the equity curve, buy/sell signals, and moving averages\n", "\"\"\"\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-10-22 18:30:21,233 - INFO - HTTP Request: GET http://127.0.0.1:7875/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:30:21,238 - INFO - HTTP Request: HEAD http://127.0.0.1:7875/ \"HTTP/1.1 200 OK\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7875\n", "* To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 115, "metadata": {}, "output_type": "execute_result" }, { "name": "stderr", "output_type": "stream", "text": [ "2025-10-22 18:30:24,092 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:31:03,437 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:31:15,743 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:31:28,425 - ERROR - Error fetching data: periods must be a number, got 2025-10-22 18:31:28.425210\n", "2025-10-22 18:31:28,429 - INFO - Synthetic data generated for tickers: AAPL\n", "2025-10-22 18:31:28,432 - INFO - Moving averages calculated with windows 20 and 50\n", "2025-10-22 18:31:28,434 - INFO - Signals generated based on moving average crossover\n", "2025-10-22 18:31:28,438 - INFO - Performance calculated\n", "2025-10-22 18:31:28,438 - INFO - Total Return: -0.010752455100331848, Sharpe Ratio: 0.18162435507214664, Max Drawdown: -0.19919271751608258\n", "2025-10-22 18:32:28,496 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:32:38,626 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:32:47,779 - ERROR - Error fetching data from yfinance: name 'start_date' is not defined. Using synthetic data.\n", "2025-10-22 18:33:23,647 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", "2025-10-22 18:33:37,829 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" ] } ], "source": [ "with gr.Blocks(css=CSS, theme=gr.themes.Monochrome(), title=\"Trading Code Generator\") as ui:\n", " with gr.Row():\n", " gr.HTML(\"

Trading Strategy Code Generator

\")\n", " \n", " with gr.Row():\n", " # Left column - Controls\n", " with gr.Column(scale=1):\n", " strategy_dropdown = gr.Dropdown(\n", " label=\"Select Trading Strategy\",\n", " choices=[strategy[\"name\"] for strategy in trading_strategies],\n", " value=trading_strategies[0][\"name\"]\n", " )\n", " \n", " with gr.Accordion(\"Strategy Details\", open=False):\n", " strategy_info = gr.JSON(\n", " value=trading_strategies[0]\n", " )\n", " \n", " model = gr.Dropdown(\n", " label=\"Select Model\",\n", " choices=models,\n", " value=models[0]\n", " )\n", " \n", " description = gr.TextArea(\n", " label=\"Strategy Description (Edit to customize)\",\n", " value=trading_strategies[0][\"description\"],\n", " lines=4\n", " )\n", " \n", " with gr.Row():\n", " generate = gr.Button(\"Generate Code\", variant=\"primary\", size=\"sm\")\n", " run = gr.Button(\"Run Code\", size=\"sm\")\n", " install_deps = gr.Button(\"Install Dependencies\", size=\"sm\")\n", " \n", " # Right column - Code and Output\n", " with gr.Column(scale=2):\n", " trading_code = gr.Code(\n", " label=\"Generated Trading Code\",\n", " value=\"\",\n", " language=\"python\",\n", " lines=20,\n", " elem_id=\"code-block\",\n", " show_label=True\n", " )\n", " \n", " output = gr.TextArea(\n", " label=\"Execution Output\",\n", " lines=8,\n", " elem_classes=[\"trading-out\"]\n", " )\n", " \n", " def update_strategy_info(strategy_name):\n", " selected = next((s for s in trading_strategies if s[\"name\"] == strategy_name), None)\n", " if selected:\n", " return selected, selected[\"description\"]\n", " return trading_strategies[0], trading_strategies[0][\"description\"]\n", " \n", " strategy_dropdown.change(\n", " fn=update_strategy_info,\n", " inputs=strategy_dropdown,\n", " outputs=[strategy_info, description]\n", " )\n", " \n", " # Function to show validation results when generating code\n", " def generate_with_validation(model, description):\n", " # Always use GPT-4o for better code quality\n", " code = generate_trading_code(model, description, force_gpt4=True)\n", " is_valid, issues = validate_code(code)\n", " \n", " validation_message = \"\"\n", " if is_valid:\n", " validation_message = \"Code validation passed ✓\"\n", " else:\n", " validation_message = \"Code validation warnings:\\n\" + \"\\n\".join([f\"- {issue}\" for issue in issues])\n", " \n", " return code, validation_message\n", " \n", " generate.click(\n", " fn=generate_with_validation,\n", " inputs=[model, description],\n", " outputs=[trading_code, output]\n", " )\n", " \n", " run.click(\n", " fn=run_python,\n", " inputs=[trading_code],\n", " outputs=[output]\n", " )\n", " \n", " install_deps.click(\n", " fn=install_dependencies,\n", " inputs=[trading_code],\n", " outputs=[output]\n", " )\n", "\n", "ui.launch(inbrowser=True)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testing the Trading Code Generator\n", "\n", "Let's test the trading code generator with a specific strategy description and model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_description = \"\"\"\n", "Create a simple RSI-based mean reversion strategy:\n", "- Use AAPL stock data for the past 2 years\n", "- Calculate the 14-day RSI indicator\n", "- Buy when RSI falls below 30 (oversold)\n", "- Sell when RSI rises above 70 (overbought)\n", "- Include visualization of entry/exit points\n", "- Calculate performance metrics\n", "\"\"\"\n", "\n", "test_model = \"gpt-3.5-turbo\"\n", "\n", "generated_code = generate_trading_code(test_model, test_description)\n", "print(\"Generated trading code:\")\n", "print(generated_code)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", " output = run_python(generated_code)\n", " print(output)\n", "except Exception as e:\n", " print(f\"Error running the generated code: {e}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fixed version of the code\n", "fixed_code = \"\"\"\n", "import yfinance as yf\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib.dates as mdates\n", "from datetime import datetime, timedelta\n", "import logging\n", "\n", "# Set up logging\n", "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", "\n", "def calculate_moving_averages(data, short_window=20, long_window=50):\n", " \\\"\\\"\\\"\n", " Calculate short and long term moving averages\n", " \\\"\\\"\\\"\n", " data['Short_MA'] = data['Close'].rolling(window=short_window, min_periods=1).mean()\n", " data['Long_MA'] = data['Close'].rolling(window=long_window, min_periods=1).mean()\n", " return data\n", "\n", "def generate_signals(data, short_window=20, long_window=50):\n", " \\\"\\\"\\\"\n", " Generate buy/sell signals based on moving average crossover strategy\n", " \\\"\\\"\\\"\n", " data['Signal'] = 0\n", " data['Signal'][short_window:] = np.where(\n", " data['Short_MA'][short_window:] > data['Long_MA'][short_window:], 1, -1)\n", " data['Position'] = data['Signal'].shift(1)\n", " data['Position'].fillna(0, inplace=True) # Fill NaN values with 0\n", " return data\n", "\n", "def backtest_strategy(data):\n", " \\\"\\\"\\\"\n", " Backtest the trading strategy and calculate performance metrics\n", " \\\"\\\"\\\"\n", " data['Returns'] = data['Close'].pct_change()\n", " data['Strategy_Returns'] = data['Returns'] * data['Position']\n", " \n", " # Replace NaN values with 0\n", " data['Strategy_Returns'].fillna(0, inplace=True)\n", "\n", " cumulative_returns = (1 + data['Strategy_Returns']).cumprod()\n", " \n", " # Calculate metrics\n", " total_return = cumulative_returns.iloc[-1] - 1\n", " sharpe_ratio = np.sqrt(252) * (data['Strategy_Returns'].mean() / data['Strategy_Returns'].std())\n", " max_drawdown = ((cumulative_returns / cumulative_returns.cummax()) - 1).min()\n", "\n", " metrics = {\n", " 'Total Return': total_return,\n", " 'Sharpe Ratio': sharpe_ratio,\n", " 'Max Drawdown': max_drawdown\n", " }\n", "\n", " return cumulative_returns, metrics\n", "\n", "def plot_results(data, cumulative_returns, ticker):\n", " \\\"\\\"\\\"\n", " Plot the performance of the trading strategy\n", " \\\"\\\"\\\"\n", " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={'height_ratios': [2, 1]})\n", " \n", " # Price and MA plot\n", " ax1.plot(data.index, data['Close'], label='Close Price')\n", " ax1.plot(data.index, data['Short_MA'], label='20-day MA', alpha=0.7)\n", " ax1.plot(data.index, data['Long_MA'], label='50-day MA', alpha=0.7)\n", " \n", " # Add buy/sell signals\n", " buy_signals = data[data['Signal'] > data['Signal'].shift(1)]\n", " sell_signals = data[data['Signal'] < data['Signal'].shift(1)]\n", " \n", " ax1.scatter(buy_signals.index, buy_signals['Close'], marker='^', color='green', s=100, label='Buy Signal')\n", " ax1.scatter(sell_signals.index, sell_signals['Close'], marker='v', color='red', s=100, label='Sell Signal')\n", " \n", " ax1.set_title(f'Moving Average Crossover Strategy on {ticker}')\n", " ax1.set_ylabel('Price ($)')\n", " ax1.legend(loc='best')\n", " ax1.grid(True)\n", " \n", " # Returns plot\n", " ax2.plot(cumulative_returns.index, cumulative_returns, label='Cumulative Strategy Returns', color='blue')\n", " ax2.set_title('Cumulative Returns')\n", " ax2.set_xlabel('Date')\n", " ax2.set_ylabel('Returns')\n", " ax2.legend(loc='best')\n", " ax2.grid(True)\n", " \n", " plt.tight_layout()\n", " plt.show()\n", "\n", "if __name__ == \\\"__main__\\\":\n", " # User inputs\n", " ticker = 'SPY' # Example: S&P 500 ETF\n", " start_date = (datetime.now() - timedelta(days=365*2)).strftime('%Y-%m-%d')\n", " end_date = datetime.now().strftime('%Y-%m-%d')\n", "\n", " # Strategy parameters\n", " short_window = 20\n", " long_window = 50\n", "\n", " # Fetch data\n", " try:\n", " logging.info(f\\\"Fetching data for {ticker} from {start_date} to {end_date}...\\\")\n", " stock_data = yf.download(ticker, start=start_date, end=end_date)\n", " logging.info(f\\\"Data fetched successfully. Got {len(stock_data)} data points.\\\")\n", " except Exception as e:\n", " logging.error(f\\\"Failed to fetch data: {e}\\\")\n", " raise SystemExit(e)\n", "\n", " try:\n", " # Preprocess and generate signals\n", " stock_data = calculate_moving_averages(stock_data, short_window, long_window)\n", " stock_data = generate_signals(stock_data, short_window, long_window)\n", "\n", " # Backtest the strategy\n", " cumulative_returns, metrics = backtest_strategy(stock_data)\n", "\n", " # Display metrics\n", " for key, value in metrics.items():\n", " logging.info(f\\\"{key}: {value:.4f}\\\")\n", "\n", " # Plot results\n", " plot_results(stock_data, cumulative_returns, ticker)\n", " except Exception as e:\n", " logging.error(f\\\"Error while executing strategy: {e}\\\")\n", "\"\"\"\n", "\n", "# Display the fixed code\n", "print(\"Fixed code:\")\n", "print(fixed_code)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run the fixed code\n", "output = run_python(fixed_code)\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's also update our system_prompt to ensure the generated code works properly\n", "\n", "system_prompt = \"\"\"\n", "You are an expert algorithmic trading code generator. Your task is to generate Python code for trading strategies based on user requirements.\n", "The code should be well-structured, efficient, and ready to run in a simulated environment.\n", "\n", "The generated code should:\n", "1. Use the yfinance library for fetching stock data\n", "2. Implement the specified trading strategy\n", "3. Include proper error handling and logging\n", "4. Include visualization of the strategy performance with clear buy/sell signals\n", "5. Calculate and display relevant metrics (returns, Sharpe ratio, drawdown, etc.)\n", "6. Handle NaN values and edge cases properly\n", "7. Include informative print statements or logging to show progress\n", "\n", "IMPORTANT: Make sure all variables are properly defined before use, especially in functions.\n", "Always pass necessary parameters between functions rather than relying on global variables.\n", "\n", "Respond only with Python code. Do not provide any explanation other than occasional comments in the code.\n", "\"\"\"\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's test the updated system prompt with a simple strategy\n", "\n", "test_description_2 = \"\"\"\n", "Create a simple Bollinger Bands strategy:\n", "- Use AAPL stock data for the past 1 year\n", "- Calculate Bollinger Bands with 20-day SMA and 2 standard deviations\n", "- Buy when price touches the lower band\n", "- Sell when price touches the upper band\n", "- Include visualization of entry/exit points\n", "- Calculate performance metrics\n", "\"\"\"\n", "\n", "test_model = \"gpt-3.5-turbo\"\n", "generated_code_2 = generate_trading_code(test_model, test_description_2)\n", "print(\"Generated trading code with updated prompt:\")\n", "print(generated_code_2)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's test the run function with live logging\n", "\n", "test_code = \"\"\"\n", "import yfinance as yf\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from datetime import datetime, timedelta\n", "\n", "# Define ticker and date range\n", "ticker = 'AAPL'\n", "end_date = datetime.now()\n", "start_date = end_date - timedelta(days=365)\n", "\n", "# Download data\n", "print(f\"Downloading data for {ticker}...\")\n", "data = yf.download(ticker, start=start_date, end=end_date)\n", "print(f\"Downloaded {len(data)} rows of data\")\n", "\n", "# Calculate RSI\n", "print(\"Calculating RSI...\")\n", "delta = data['Close'].diff()\n", "gain = delta.where(delta > 0, 0)\n", "loss = -delta.where(delta < 0, 0)\n", "avg_gain = gain.rolling(window=14).mean()\n", "avg_loss = loss.rolling(window=14).mean()\n", "rs = avg_gain / avg_loss\n", "data['RSI'] = 100 - (100 / (1 + rs))\n", "\n", "# Generate signals\n", "print(\"Generating trading signals...\")\n", "data['Signal'] = 0\n", "data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", "data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", "\n", "# Count signals\n", "buy_signals = len(data[data['Signal'] == 1])\n", "sell_signals = len(data[data['Signal'] == -1])\n", "print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", "\n", "# Print sample of the data\n", "print(\"\\\\nSample of the processed data:\")\n", "print(data[['Close', 'RSI', 'Signal']].tail())\n", "\n", "print(\"\\\\nAnalysis complete!\")\n", "\"\"\"\n", "\n", "output = run_python(test_code)\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's test the improved run function again\n", "\n", "test_code_2 = \"\"\"\n", "import yfinance as yf\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from datetime import datetime, timedelta\n", "\n", "# Define ticker and date range\n", "ticker = 'AAPL'\n", "end_date = datetime.now()\n", "start_date = end_date - timedelta(days=365)\n", "\n", "# Download data\n", "print(f\"Downloading data for {ticker}...\")\n", "data = yf.download(ticker, start=start_date, end=end_date, progress=False)\n", "print(f\"Downloaded {len(data)} rows of data\")\n", "\n", "# Calculate RSI\n", "print(\"Calculating RSI...\")\n", "delta = data['Close'].diff()\n", "gain = delta.where(delta > 0, 0)\n", "loss = -delta.where(delta < 0, 0)\n", "avg_gain = gain.rolling(window=14).mean()\n", "avg_loss = loss.rolling(window=14).mean()\n", "rs = avg_gain / avg_loss\n", "data['RSI'] = 100 - (100 / (1 + rs))\n", "\n", "# Generate signals\n", "print(\"Generating trading signals...\")\n", "data['Signal'] = 0\n", "data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", "data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", "\n", "# Count signals\n", "buy_signals = len(data[data['Signal'] == 1])\n", "sell_signals = len(data[data['Signal'] == -1])\n", "print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", "\n", "# Print sample of the data\n", "print(\"\\\\nSample of the processed data:\")\n", "print(data[['Close', 'RSI', 'Signal']].tail())\n", "\n", "print(\"\\\\nAnalysis complete!\")\n", "\"\"\"\n", "\n", "output = run_python(test_code_2)\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the completely rewritten run function\n", "\n", "simple_test = \"\"\"\n", "print(\"Hello from the trading code generator!\")\n", "print(\"Testing output capture...\")\n", "\n", "# Simulate some data processing\n", "import numpy as np\n", "data = np.random.rand(5, 3)\n", "print(\"Generated random data:\")\n", "print(data)\n", "\n", "# Show a calculation\n", "result = np.mean(data, axis=0)\n", "print(\"Mean of each column:\")\n", "print(result)\n", "\n", "print(\"Test complete!\")\n", "\"\"\"\n", "\n", "output = run_python(simple_test)\n", "print(\"Output from execution:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the improved code generation with a strategy that typically causes formatting issues\n", "\n", "test_description_3 = \"\"\"\n", "Create a simple trading strategy that:\n", "- Uses AAPL stock data for the past year\n", "- Calculates both RSI and Bollinger Bands\n", "- Buys when price is below lower Bollinger Band AND RSI is below 30\n", "- Sells when price is above upper Bollinger Band OR RSI is above 70\n", "- Includes proper error handling for all calculations\n", "- Visualizes the entry/exit points and performance\n", "\"\"\"\n", "\n", "test_model = \"gpt-3.5-turbo\"\n", "generated_code_3 = generate_trading_code(test_model, test_description_3)\n", "print(\"Generated trading code with enhanced validation:\")\n", "print(generated_code_3)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's test running the generated code\n", "\n", "output = run_python(generated_code_3)\n", "print(\"Execution output:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the improved run function with a simple example that prints output\n", "\n", "simple_test_2 = \"\"\"\n", "print(\"This is a test of the output capture system\")\n", "print(\"Line 1 of output\")\n", "print(\"Line 2 of output\")\n", "\n", "# Import and use numpy\n", "import numpy as np\n", "data = np.random.rand(3, 3)\n", "print(\"Random matrix:\")\n", "print(data)\n", "\n", "# Create a simple plot\n", "import matplotlib.pyplot as plt\n", "plt.figure(figsize=(8, 4))\n", "plt.plot([1, 2, 3, 4], [10, 20, 25, 30], 'ro-')\n", "plt.title('Simple Plot')\n", "plt.xlabel('X axis')\n", "plt.ylabel('Y axis')\n", "plt.grid(True)\n", "plt.savefig('simple_plot.png') # Save instead of showing\n", "print(\"Plot saved to simple_plot.png\")\n", "\n", "print(\"Test complete!\")\n", "\"\"\"\n", "\n", "output = run_python(simple_test_2)\n", "print(\"Output from execution:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test with a simpler example that won't time out\n", "\n", "simple_test_3 = \"\"\"\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import logging\n", "\n", "# Set up logging\n", "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", "\n", "# Generate synthetic stock data\n", "def generate_stock_data(days=252, volatility=0.01):\n", " logging.info(f\"Generating {days} days of synthetic stock data\")\n", " np.random.seed(42)\n", " price = 100\n", " prices = [price]\n", " \n", " for _ in range(days - 1):\n", " change = np.random.normal(0, volatility)\n", " price *= (1 + change)\n", " prices.append(price)\n", " \n", " dates = pd.date_range(end=pd.Timestamp.today(), periods=days)\n", " df = pd.DataFrame({\n", " 'Close': prices,\n", " 'Open': [p * (1 - volatility/2) for p in prices],\n", " 'High': [p * (1 + volatility) for p in prices],\n", " 'Low': [p * (1 - volatility) for p in prices],\n", " 'Volume': [np.random.randint(100000, 10000000) for _ in range(days)]\n", " }, index=dates)\n", " \n", " logging.info(f\"Generated data with shape {df.shape}\")\n", " return df\n", "\n", "# Calculate RSI\n", "def calculate_rsi(data, window=14):\n", " logging.info(f\"Calculating RSI with {window}-day window\")\n", " delta = data['Close'].diff()\n", " gain = delta.where(delta > 0, 0).rolling(window=window, min_periods=1).mean()\n", " loss = -delta.where(delta < 0, 0).rolling(window=window, min_periods=1).mean()\n", " \n", " rs = gain / loss\n", " rsi = 100 - (100 / (1 + rs))\n", " return rsi\n", "\n", "# Main function\n", "if __name__ == \"__main__\":\n", " # Generate data\n", " data = generate_stock_data()\n", " \n", " # Calculate RSI\n", " data['RSI'] = calculate_rsi(data)\n", " \n", " # Generate signals\n", " logging.info(\"Generating trading signals\")\n", " data['Signal'] = 0\n", " data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", " data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", " \n", " # Count signals\n", " buy_signals = len(data[data['Signal'] == 1])\n", " sell_signals = len(data[data['Signal'] == -1])\n", " logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", " \n", " # Plot the results\n", " logging.info(\"Creating visualization\")\n", " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [3, 1]})\n", " \n", " # Price chart\n", " ax1.plot(data.index, data['Close'], label='Close Price')\n", " \n", " # Add buy/sell signals\n", " buy_points = data[data['Signal'] == 1]\n", " sell_points = data[data['Signal'] == -1]\n", " \n", " ax1.scatter(buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n", " ax1.scatter(sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n", " \n", " ax1.set_title('Stock Price with RSI Signals')\n", " ax1.set_ylabel('Price')\n", " ax1.legend()\n", " ax1.grid(True)\n", " \n", " # RSI chart\n", " ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n", " ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n", " ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n", " ax2.set_title('RSI Indicator')\n", " ax2.set_ylabel('RSI')\n", " ax2.set_ylim(0, 100)\n", " ax2.grid(True)\n", " \n", " plt.tight_layout()\n", " plt.savefig('rsi_strategy.png')\n", " logging.info(\"Plot saved to rsi_strategy.png\")\n", " \n", " # Print sample of data\n", " logging.info(\"Sample of the processed data:\")\n", " print(data[['Close', 'RSI', 'Signal']].tail())\n", " \n", " logging.info(\"Analysis complete!\")\n", "\"\"\"\n", "\n", "output = run_python(simple_test_3)\n", "print(\"Output from execution:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the enhanced code generation with GPT-4o\n", "\n", "test_description_4 = \"\"\"\n", "Create a trading strategy that:\n", "- Uses both MACD and RSI indicators\n", "- Buys when MACD crosses above signal line AND RSI is below 40\n", "- Sells when MACD crosses below signal line OR RSI is above 70\n", "- Includes proper visualization with buy/sell signals\n", "- Uses synthetic data if API calls fail\n", "- Calculates performance metrics including Sharpe ratio and max drawdown\n", "\"\"\"\n", "\n", "print(\"Generating trading code with GPT-4o...\")\n", "generated_code_4 = generate_trading_code(\"gpt-4o\", test_description_4, force_gpt4=True)\n", "print(\"Code generation complete. Validating...\")\n", "is_valid, issues = validate_code(generated_code_4)\n", "\n", "if issues:\n", " print(f\"Validation found {len(issues)} issues:\")\n", " for issue in issues:\n", " print(f\"- {issue}\")\n", "else:\n", " print(\"Code validation passed ✓\")\n", "\n", "print(\"\\nGenerated code snippet (first 20 lines):\")\n", "print(\"\\n\".join(generated_code_4.split(\"\\n\")[:20]))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's run the generated code to test it\n", "\n", "output = run_python(generated_code_4)\n", "print(\"Execution output:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's test again with the fixed timeout setting\n", "\n", "test_description_5 = \"\"\"\n", "Create a simple trading strategy that:\n", "- Uses synthetic data generation to avoid API timeouts\n", "- Implements a simple moving average crossover (5-day and 20-day)\n", "- Includes proper visualization with buy/sell signals\n", "- Calculates basic performance metrics\n", "\"\"\"\n", "\n", "print(\"Generating trading code with proper yfinance timeout settings...\")\n", "generated_code_5 = generate_trading_code(\"gpt-4o\", test_description_5, force_gpt4=True)\n", "print(\"Code generation complete.\")\n", "\n", "# Run the generated code\n", "output = run_python(generated_code_5)\n", "print(\"Execution output:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test with a simpler strategy that focuses on scatter plot safety\n", "\n", "test_description_6 = \"\"\"\n", "Create a simple trading strategy that:\n", "- Uses synthetic data generation only (no API calls)\n", "- Implements a simple RSI-based strategy (buy when RSI < 30, sell when RSI > 70)\n", "- Includes visualization with buy/sell signals using scatter plots\n", "- Calculates basic performance metrics\n", "- Uses proper error handling for all operations\n", "\"\"\"\n", "\n", "print(\"Generating trading code with scatter plot safety...\")\n", "generated_code_6 = generate_trading_code(\"gpt-4o\", test_description_6, force_gpt4=True)\n", "print(\"Code generation complete.\")\n", "\n", "# Run the generated code\n", "output = run_python(generated_code_6)\n", "print(\"Execution output:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test with a fixed example that properly handles pandas formatting and scatter plots\n", "\n", "test_fixed_code = \"\"\"\n", "import yfinance as yf\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import logging\n", "from datetime import datetime, timedelta\n", "\n", "# Configure logging\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format='[%(asctime)s] %(levelname)s: %(message)s',\n", " datefmt='%H:%M:%S'\n", ")\n", "\n", "# Helper function for safe formatting of pandas objects\n", "def safe_format(obj):\n", " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", " return str(obj)\n", " return obj\n", "\n", "# Helper function to safely create scatter plots\n", "def safe_scatter(ax, x, y, *args, **kwargs):\n", " # Ensure x and y are the same length\n", " if len(x) != len(y):\n", " logging.warning(f\"Scatter plot inputs have different lengths: x={len(x)}, y={len(y)}\")\n", " # Find the minimum length\n", " min_len = min(len(x), len(y))\n", " x = x[:min_len]\n", " y = y[:min_len]\n", " \n", " # Check for empty arrays\n", " if len(x) == 0 or len(y) == 0:\n", " logging.warning(\"Empty arrays passed to scatter plot, skipping\")\n", " return None\n", " \n", " return ax.scatter(x, y, *args, **kwargs)\n", "\n", "# Generate synthetic data\n", "def generate_synthetic_data(ticker='AAPL', days=252, seed=42):\n", " logging.info(f\"Generating synthetic data for {ticker}\")\n", " np.random.seed(seed)\n", " \n", " # Generate price data\n", " price = 100 # Starting price\n", " prices = [price]\n", " \n", " for _ in range(days):\n", " change = np.random.normal(0, 0.01) # 1% volatility\n", " price *= (1 + change)\n", " prices.append(price)\n", " \n", " # Create date range\n", " end_date = datetime.now()\n", " start_date = end_date - timedelta(days=days)\n", " dates = pd.date_range(start=start_date, end=end_date, periods=len(prices))\n", " \n", " # Create DataFrame\n", " df = pd.DataFrame({\n", " 'Open': prices[:-1],\n", " 'High': [p * 1.01 for p in prices[:-1]],\n", " 'Low': [p * 0.99 for p in prices[:-1]],\n", " 'Close': prices[1:],\n", " 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n", " }, index=dates[:-1])\n", " \n", " logging.info(f\"Generated {len(df)} days of data for {ticker}\")\n", " return df\n", "\n", "# Calculate RSI\n", "def calculate_rsi(data, window=14):\n", " logging.info(f\"Calculating RSI with {window}-day window\")\n", " delta = data['Close'].diff()\n", " gain = delta.where(delta > 0, 0)\n", " loss = -delta.where(delta < 0, 0)\n", " \n", " avg_gain = gain.rolling(window=window, min_periods=1).mean()\n", " avg_loss = loss.rolling(window=window, min_periods=1).mean()\n", " \n", " rs = avg_gain / avg_loss\n", " rsi = 100 - (100 / (1 + rs))\n", " return rsi\n", "\n", "# Generate signals\n", "def generate_signals(data):\n", " logging.info(\"Generating trading signals\")\n", " data['Signal'] = 0\n", " data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", " data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", " \n", " # Count signals\n", " buy_signals = len(data[data['Signal'] == 1])\n", " sell_signals = len(data[data['Signal'] == -1])\n", " logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", " return data\n", "\n", "# Backtest strategy\n", "def backtest_strategy(data):\n", " logging.info(\"Backtesting strategy\")\n", " data['Returns'] = data['Close'].pct_change()\n", " data['Strategy'] = data['Signal'].shift(1) * data['Returns']\n", " \n", " # Replace NaN values\n", " data['Strategy'].fillna(0, inplace=True)\n", " \n", " # Calculate cumulative returns\n", " data['Cumulative'] = (1 + data['Strategy']).cumprod()\n", " \n", " # Calculate metrics\n", " total_return = data['Cumulative'].iloc[-1] - 1\n", " sharpe = np.sqrt(252) * data['Strategy'].mean() / data['Strategy'].std()\n", " max_dd = (data['Cumulative'] / data['Cumulative'].cummax() - 1).min()\n", " \n", " logging.info(f\"Total Return: {total_return:.4f}\")\n", " logging.info(f\"Sharpe Ratio: {sharpe:.4f}\")\n", " logging.info(f\"Max Drawdown: {max_dd:.4f}\")\n", " \n", " return data\n", "\n", "# Visualize results\n", "def visualize_results(data, ticker):\n", " logging.info(\"Creating visualization\")\n", " try:\n", " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]})\n", " \n", " # Price chart\n", " ax1.plot(data.index, data['Close'], label='Close Price')\n", " \n", " # Add buy/sell signals\n", " buy_points = data[data['Signal'] == 1]\n", " sell_points = data[data['Signal'] == -1]\n", " \n", " # Use safe scatter to avoid \"x and y must be the same size\" error\n", " if not buy_points.empty:\n", " safe_scatter(ax1, buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n", " \n", " if not sell_points.empty:\n", " safe_scatter(ax1, sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n", " \n", " ax1.set_title(f'RSI Strategy for {ticker}')\n", " ax1.set_ylabel('Price')\n", " ax1.legend()\n", " ax1.grid(True)\n", " \n", " # RSI chart\n", " ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n", " ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n", " ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n", " ax2.set_title('RSI Indicator')\n", " ax2.set_ylabel('RSI')\n", " ax2.set_ylim(0, 100)\n", " ax2.grid(True)\n", " \n", " plt.tight_layout()\n", " plt.savefig('rsi_strategy.png')\n", " logging.info(\"Plot saved to rsi_strategy.png\")\n", " except Exception as e:\n", " logging.error(f\"Error in visualization: {e}\")\n", "\n", "if __name__ == \"__main__\":\n", " # Settings\n", " ticker = 'AAPL'\n", " days = 252 # One year of trading days\n", " \n", " # Generate data\n", " data = generate_synthetic_data(ticker, days)\n", " \n", " # Calculate RSI\n", " data['RSI'] = calculate_rsi(data)\n", " \n", " # Generate signals\n", " data = generate_signals(data)\n", " \n", " # Backtest strategy\n", " data = backtest_strategy(data)\n", " \n", " # Visualize results\n", " visualize_results(data, ticker)\n", " \n", " # Print sample of data\n", " logging.info(\"Sample of the processed data:\")\n", " print(data[['Close', 'RSI', 'Signal', 'Strategy', 'Cumulative']].tail())\n", " \n", " logging.info(\"Analysis complete!\")\n", "\"\"\"\n", "\n", "output = run_python(test_fixed_code)\n", "print(\"Output from execution:\")\n", "print(output)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the fix for indentation error in plotting code\n", "\n", "test_code_with_plotting = \"\"\"\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import logging\n", "\n", "# Configure logging\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format='[%(asctime)s] %(levelname)s: %(message)s',\n", " datefmt='%H:%M:%S'\n", ")\n", "\n", "# Simple plotting function\n", "def create_plot():\n", " # Generate some data\n", " x = np.linspace(0, 10, 100)\n", " y = np.sin(x)\n", " \n", " # Create plot\n", " logging.info(\"Creating sine wave plot\")\n", " plt.figure(figsize=(10, 6))\n", " plt.plot(x, y)\n", " plt.title('Sine Wave')\n", " plt.xlabel('X')\n", " plt.ylabel('Y')\n", " plt.grid(True)\n", " plt.savefig('sine_wave.png')\n", " logging.info(\"Plot saved to sine_wave.png\")\n", "\n", "if __name__ == \"__main__\":\n", " create_plot()\n", "\"\"\"\n", "\n", "# Apply safety features to the code\n", "enhanced_code = add_safety_features(test_code_with_plotting)\n", "print(\"Code with safety features applied:\")\n", "print(enhanced_code)\n", "\n", "# Run the enhanced code\n", "output = run_python(enhanced_code)\n", "print(\"\\nExecution output:\")\n", "print(output)\n" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 2 }