1834 lines
78 KiB
Plaintext
1834 lines
78 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Trading Code Generator\n",
|
|
"\n",
|
|
"This notebook creates a code generator that produces trading code to buy and sell equities in a simulated environment based on free APIs. It uses Gradio for the UI, similar to the approach in day5.ipynb.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import io\n",
|
|
"import sys\n",
|
|
"import time\n",
|
|
"import random\n",
|
|
"import numpy as np\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from openai import OpenAI\n",
|
|
"import gradio as gr\n",
|
|
"from IPython.display import display\n",
|
|
"from huggingface_hub import InferenceClient\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"OpenAI API Key exists and begins sk-proj-\n",
|
|
"Hugging Face Token exists and begins hf_fNncb\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"load_dotenv(override=True)\n",
|
|
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
|
"hf_token = os.getenv('HF_TOKEN')\n",
|
|
"\n",
|
|
"if openai_api_key:\n",
|
|
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
|
|
"else:\n",
|
|
" print(\"OpenAI API Key not set\")\n",
|
|
" \n",
|
|
"if hf_token:\n",
|
|
" print(f\"Hugging Face Token exists and begins {hf_token[:8]}\")\n",
|
|
"else:\n",
|
|
" print(\"Hugging Face Token not set\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"openai_client = OpenAI()\n",
|
|
"hf_client = InferenceClient(token=hf_token)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"models = [\"gpt-4o\", \"gpt-3.5-turbo\", \"meta-llama/Llama-2-70b-chat-hf\"]\n",
|
|
"\n",
|
|
"def generate_with_openai(model, messages):\n",
|
|
" response = openai_client.chat.completions.create(\n",
|
|
" model=model, \n",
|
|
" messages=messages\n",
|
|
" )\n",
|
|
" return response.choices[0].message.content\n",
|
|
"\n",
|
|
"def generate_with_hf(model, messages):\n",
|
|
" prompt = \"\"\n",
|
|
" for msg in messages:\n",
|
|
" role = msg[\"role\"]\n",
|
|
" content = msg[\"content\"]\n",
|
|
" if role == \"system\":\n",
|
|
" prompt += f\"<s>[INST] {content} [/INST]</s>\\n\"\n",
|
|
" elif role == \"user\":\n",
|
|
" prompt += f\"<s>[INST] {content} [/INST]</s>\\n\"\n",
|
|
" else:\n",
|
|
" prompt += f\"{content}\\n\"\n",
|
|
" \n",
|
|
" response = hf_client.text_generation(\n",
|
|
" prompt,\n",
|
|
" model=model,\n",
|
|
" max_new_tokens=1024,\n",
|
|
" temperature=0.7,\n",
|
|
" repetition_penalty=1.2\n",
|
|
" )\n",
|
|
" return response\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"CSS = \"\"\"\n",
|
|
":root {\n",
|
|
" --py-color: #209dd7;\n",
|
|
" --trading-color: #27ae60;\n",
|
|
" --accent: #753991;\n",
|
|
" --card: #161a22;\n",
|
|
" --text: #e9eef5;\n",
|
|
"}\n",
|
|
"\n",
|
|
"/* Full-width layout */\n",
|
|
".gradio-container {\n",
|
|
" max-width: 100% !important;\n",
|
|
" padding: 0 40px !important;\n",
|
|
"}\n",
|
|
"\n",
|
|
"/* Code card styling */\n",
|
|
".card {\n",
|
|
" background: var(--card);\n",
|
|
" border: 1px solid rgba(255,255,255,.08);\n",
|
|
" border-radius: 14px;\n",
|
|
" padding: 10px;\n",
|
|
"}\n",
|
|
"\n",
|
|
"/* Make code block scrollable but fixed height */\n",
|
|
"#code-block {\n",
|
|
" max-height: 400px !important;\n",
|
|
" overflow-y: auto !important;\n",
|
|
"}\n",
|
|
"\n",
|
|
"#code-block .cm-editor {\n",
|
|
" height: 400px !important;\n",
|
|
"}\n",
|
|
"\n",
|
|
"/* Buttons */\n",
|
|
".generate-btn button {\n",
|
|
" background: var(--accent) !important;\n",
|
|
" border-color: rgba(255,255,255,.12) !important;\n",
|
|
" color: white !important;\n",
|
|
" font-weight: 700;\n",
|
|
"}\n",
|
|
".run-btn button {\n",
|
|
" background: #202631 !important;\n",
|
|
" color: var(--text) !important;\n",
|
|
" border-color: rgba(255,255,255,.12) !important;\n",
|
|
"}\n",
|
|
".run-btn.py button:hover { box-shadow: 0 0 0 2px var(--py-color) inset; }\n",
|
|
".run-btn.trading button:hover { box-shadow: 0 0 0 2px var(--trading-color) inset; }\n",
|
|
".generate-btn button:hover { box-shadow: 0 0 0 2px var(--accent) inset; }\n",
|
|
"\n",
|
|
"/* Outputs with color tint */\n",
|
|
".py-out textarea {\n",
|
|
" background: linear-gradient(180deg, rgba(32,157,215,.18), rgba(32,157,215,.10));\n",
|
|
" border: 1px solid rgba(32,157,215,.35) !important;\n",
|
|
" color: rgba(32,157,215,1) !important;\n",
|
|
" font-weight: 600;\n",
|
|
"}\n",
|
|
".trading-out textarea {\n",
|
|
" background: linear-gradient(180deg, rgba(39,174,96,.18), rgba(39,174,96,.10));\n",
|
|
" border: 1px solid rgba(39,174,96,.35) !important;\n",
|
|
" color: rgba(39,174,96,1) !important;\n",
|
|
" font-weight: 600;\n",
|
|
"}\n",
|
|
"\n",
|
|
"/* Align controls neatly */\n",
|
|
".controls .wrap {\n",
|
|
" gap: 10px;\n",
|
|
" justify-content: center;\n",
|
|
" align-items: center;\n",
|
|
"}\n",
|
|
"\"\"\"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"system_prompt = \"\"\"\n",
|
|
"You are an expert algorithmic trading code generator. Generate clean, bug-free Python code for trading strategies.\n",
|
|
"\n",
|
|
"Generate code that:\n",
|
|
"1. Uses synthetic data generation only - no API calls\n",
|
|
"2. Implements the specified trading strategy\n",
|
|
"3. Uses proper error handling\n",
|
|
"4. Visualizes strategy performance with buy/sell signals\n",
|
|
"5. Calculates performance metrics\n",
|
|
"6. Handles edge cases properly\n",
|
|
"\n",
|
|
"REQUIREMENTS:\n",
|
|
"1. Include if __name__ == \"__main__\": block that executes immediately\n",
|
|
"2. Define all variables before use\n",
|
|
"3. Pass parameters between functions, avoid global variables\n",
|
|
"4. NO explanatory text outside of code\n",
|
|
"5. NO markdown blocks or language indicators\n",
|
|
"6. Code must execute without user input\n",
|
|
"7. Use str() for pandas objects in f-strings\n",
|
|
"8. Use .copy() for DataFrame views that will be modified\n",
|
|
"9. Include min_periods in rolling calculations\n",
|
|
"10. Check array lengths before scatter plots\n",
|
|
"11. Configure logging properly\n",
|
|
"12. Include helper functions for formatting and plotting\n",
|
|
"\n",
|
|
"Respond ONLY with Python code. No explanations or markdown.\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"def user_prompt_for(description):\n",
|
|
" return f\"\"\"\n",
|
|
"Generate Python code for a trading strategy:\n",
|
|
"\n",
|
|
"{description}\n",
|
|
"\n",
|
|
"Requirements:\n",
|
|
"1. Use synthetic data generation only\n",
|
|
"2. Implement the strategy exactly as described\n",
|
|
"3. Include backtesting functionality\n",
|
|
"4. Visualize results with matplotlib\n",
|
|
"5. Calculate performance metrics\n",
|
|
"6. Handle all edge cases\n",
|
|
"7. No comments needed\n",
|
|
"\n",
|
|
"Make the code complete and runnable as-is with all necessary imports.\n",
|
|
"\"\"\"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 87,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def messages_for(description):\n",
|
|
" return [\n",
|
|
" {\"role\": \"system\", \"content\": system_prompt},\n",
|
|
" {\"role\": \"user\", \"content\": user_prompt_for(description)}\n",
|
|
" ]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def validate_code(code):\n",
|
|
" issues = []\n",
|
|
" if \"import yfinance\" not in code and \"from yfinance\" not in code:\n",
|
|
" issues.append(\"Missing yfinance import\")\n",
|
|
" if \"import matplotlib\" not in code and \"from matplotlib\" not in code:\n",
|
|
" issues.append(\"Missing matplotlib import\")\n",
|
|
" if \"__name__ == \\\"__main__\\\"\" not in code and \"__name__ == '__main__'\" not in code:\n",
|
|
" issues.append(\"Missing if __name__ == '__main__' block\")\n",
|
|
" if \"f\\\"\" in code or \"f'\" in code:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" for i, line in enumerate(lines):\n",
|
|
" if ('f\"' in line or \"f'\" in line) and ('data[' in line or '.iloc' in line or '.loc' in line):\n",
|
|
" issues.append(f\"Potentially unsafe f-string formatting with pandas objects on line {i+1}\")\n",
|
|
" if \"try:\" in code and \"except\" not in code:\n",
|
|
" issues.append(\"Try block without except clause\")\n",
|
|
" if \"rolling\" in code and \"min_periods\" not in code:\n",
|
|
" issues.append(\"Rolling window without min_periods parameter (may produce NaN values)\")\n",
|
|
" if \".loc\" in code and \"iloc\" not in code and \"copy()\" not in code:\n",
|
|
" issues.append(\"Potential pandas SettingWithCopyWarning - consider using .copy() before modifications\")\n",
|
|
" lines = code.split('\\n')\n",
|
|
" defined_vars = set()\n",
|
|
" for line in lines:\n",
|
|
" if line.strip().startswith('#') or not line.strip():\n",
|
|
" continue\n",
|
|
" if '=' in line and not line.strip().startswith('if') and not line.strip().startswith('elif') and not line.strip().startswith('while'):\n",
|
|
" var_name = line.split('=')[0].strip()\n",
|
|
" if var_name:\n",
|
|
" defined_vars.add(var_name)\n",
|
|
" if issues:\n",
|
|
" return False, issues\n",
|
|
" return True, []\n",
|
|
"\n",
|
|
"def generate_trading_code(model, description, force_gpt4=False):\n",
|
|
" messages = messages_for(description)\n",
|
|
" if force_gpt4:\n",
|
|
" try:\n",
|
|
" reply = generate_with_openai(\"gpt-4o\", messages)\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error using GPT-4o: {e}. Falling back to selected model.\")\n",
|
|
" if \"gpt\" in model.lower():\n",
|
|
" reply = generate_with_openai(model, messages)\n",
|
|
" else:\n",
|
|
" reply = generate_with_hf(model, messages)\n",
|
|
" else:\n",
|
|
" if \"gpt\" in model.lower():\n",
|
|
" reply = generate_with_openai(model, messages)\n",
|
|
" else:\n",
|
|
" reply = generate_with_hf(model, messages)\n",
|
|
" reply = reply.replace('```python','').replace('```','')\n",
|
|
" is_valid, issues = validate_code(reply)\n",
|
|
" max_attempts = 3\n",
|
|
" attempt = 0\n",
|
|
" fix_model = \"gpt-4o\" if force_gpt4 else model\n",
|
|
" while not is_valid and attempt < max_attempts and (\"gpt\" in model.lower() or force_gpt4):\n",
|
|
" attempt += 1\n",
|
|
" fix_messages = messages.copy()\n",
|
|
" fix_messages.append({\"role\": \"assistant\", \"content\": reply})\n",
|
|
" fix_request = f\"\"\"The code has the following issues that need to be fixed:\n",
|
|
"{chr(10).join([f\"- {issue}\" for issue in issues])}\n",
|
|
"\n",
|
|
"Please provide a completely corrected version that addresses these issues. Make sure to:\n",
|
|
"\n",
|
|
"1. Avoid using f-strings with pandas Series or DataFrame objects directly\n",
|
|
"2. Always handle NaN values in calculations with proper checks\n",
|
|
"3. Use proper error handling with try/except blocks around all API calls and calculations\n",
|
|
"4. Include min_periods parameter in rolling window calculations\n",
|
|
"5. Use .copy() when creating views of DataFrames that will be modified\n",
|
|
"6. Make sure all variables are properly defined before use\n",
|
|
"7. Add yfinance timeout settings: yf.set_timeout(30)\n",
|
|
"8. Add proper logging for all steps\n",
|
|
"9. Use synthetic data generation as a fallback if API calls fail\n",
|
|
"10. Include proper if __name__ == \"__main__\" block\n",
|
|
"\n",
|
|
"Return ONLY the corrected code with no explanation or markdown formatting.\n",
|
|
"\"\"\"\n",
|
|
" fix_messages.append({\"role\": \"user\", \"content\": fix_request})\n",
|
|
" try:\n",
|
|
" if force_gpt4:\n",
|
|
" fixed_reply = generate_with_openai(\"gpt-4o\", fix_messages)\n",
|
|
" else:\n",
|
|
" if \"gpt\" in model.lower():\n",
|
|
" fixed_reply = generate_with_openai(model, fix_messages)\n",
|
|
" else:\n",
|
|
" fixed_reply = generate_with_hf(model, fix_messages)\n",
|
|
" fixed_reply = fixed_reply.replace('```python','').replace('```','')\n",
|
|
" is_fixed_valid, fixed_issues = validate_code(fixed_reply)\n",
|
|
" if is_fixed_valid or len(fixed_issues) < len(issues):\n",
|
|
" reply = fixed_reply\n",
|
|
" is_valid = is_fixed_valid\n",
|
|
" issues = fixed_issues\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error during fix attempt {attempt}: {e}\")\n",
|
|
" reply = add_safety_features(reply)\n",
|
|
" return reply\n",
|
|
"\n",
|
|
"def add_safety_features(code):\n",
|
|
" if \"pandas\" in code:\n",
|
|
" safety_imports = \"\"\"\n",
|
|
"import pandas as pd\n",
|
|
"pd.set_option('display.float_format', '{:.5f}'.format)\n",
|
|
"\n",
|
|
"def safe_format(obj):\n",
|
|
" if isinstance(obj, (pd.Series, pd.DataFrame)):\n",
|
|
" return str(obj)\n",
|
|
" return obj\n",
|
|
"\"\"\"\n",
|
|
" import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n",
|
|
" if import_lines:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" lines.insert(import_lines[-1] + 1, safety_imports)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" code = code.replace(\"yf.set_timeout(30)\", \"\")\n",
|
|
" code = code.replace(\"yf.pdr_override()\", \"\")\n",
|
|
" lines = code.split('\\n')\n",
|
|
" for i, line in enumerate(lines):\n",
|
|
" if 'f\"' in line or \"f'\" in line:\n",
|
|
" if any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n",
|
|
" for term in ['.mean()', '.sum()', '.std()', '.min()', '.max()']:\n",
|
|
" if term in line:\n",
|
|
" lines[i] = line.replace(f\"{term}\", f\"{term})\")\n",
|
|
" lines[i] = lines[i].replace(\"f\\\"\", \"f\\\"{safe_format(\")\n",
|
|
" lines[i] = lines[i].replace(\"f'\", \"f'{safe_format(\")\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" if \"plt.scatter\" in code or \".scatter\" in code:\n",
|
|
" scatter_safety = \"\"\"\n",
|
|
"def safe_scatter(ax, x, y, *args, **kwargs):\n",
|
|
" if len(x) != len(y):\n",
|
|
" min_len = min(len(x), len(y))\n",
|
|
" x = x[:min_len]\n",
|
|
" y = y[:min_len]\n",
|
|
" if len(x) == 0 or len(y) == 0:\n",
|
|
" return None\n",
|
|
" return ax.scatter(x, y, *args, **kwargs)\n",
|
|
"\"\"\"\n",
|
|
" func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n",
|
|
" if func_lines:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" lines.insert(func_lines[0], scatter_safety)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" code = code.replace(\"plt.scatter(\", \"safe_scatter(plt.gca(), \")\n",
|
|
" code = code.replace(\".scatter(\", \"safe_scatter(\")\n",
|
|
" if \"yfinance\" in code and \"generate_synthetic_data\" not in code:\n",
|
|
" synthetic_data_func = \"\"\"\n",
|
|
"def generate_synthetic_data(ticker='AAPL', start_date=None, end_date=None, days=252, seed=42):\n",
|
|
" import numpy as np\n",
|
|
" import pandas as pd\n",
|
|
" from datetime import datetime, timedelta\n",
|
|
" if start_date is None:\n",
|
|
" end_date = datetime.now()\n",
|
|
" start_date = end_date - timedelta(days=days)\n",
|
|
" elif end_date is None:\n",
|
|
" if isinstance(start_date, str):\n",
|
|
" start_date = pd.to_datetime(start_date)\n",
|
|
" end_date = datetime.now()\n",
|
|
" np.random.seed(seed)\n",
|
|
" if isinstance(start_date, str):\n",
|
|
" start = pd.to_datetime(start_date)\n",
|
|
" else:\n",
|
|
" start = start_date\n",
|
|
" if isinstance(end_date, str):\n",
|
|
" end = pd.to_datetime(end_date)\n",
|
|
" else:\n",
|
|
" end = end_date\n",
|
|
" days = (end - start).days + 1\n",
|
|
" price = 100\n",
|
|
" prices = [price]\n",
|
|
" for _ in range(days):\n",
|
|
" change = np.random.normal(0, 0.01)\n",
|
|
" price *= (1 + change)\n",
|
|
" prices.append(price)\n",
|
|
" dates = pd.date_range(start=start, end=end, periods=len(prices))\n",
|
|
" df = pd.DataFrame({\n",
|
|
" 'Open': prices[:-1],\n",
|
|
" 'High': [p * 1.01 for p in prices[:-1]],\n",
|
|
" 'Low': [p * 0.99 for p in prices[:-1]],\n",
|
|
" 'Close': prices[1:],\n",
|
|
" 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n",
|
|
" }, index=dates[:-1])\n",
|
|
" return df\n",
|
|
"\"\"\"\n",
|
|
" func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n",
|
|
" if func_lines:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" lines.insert(func_lines[0], synthetic_data_func)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" if \"logging\" in code and \"basicConfig\" not in code:\n",
|
|
" logging_config = \"\"\"\n",
|
|
"import logging\n",
|
|
"logging.basicConfig(\n",
|
|
" level=logging.INFO,\n",
|
|
" format='[%(asctime)s] %(levelname)s: %(message)s',\n",
|
|
" datefmt='%H:%M:%S'\n",
|
|
")\n",
|
|
"\"\"\"\n",
|
|
" import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n",
|
|
" if import_lines:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" lines.insert(import_lines[-1] + 1, logging_config)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" if \"yfinance\" in code and \"try:\" not in code:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" for i, line in enumerate(lines):\n",
|
|
" if \"yf.download\" in line and \"try:\" not in lines[max(0, i-5):i]:\n",
|
|
" indent = len(line) - len(line.lstrip())\n",
|
|
" indent_str = \" \" * indent\n",
|
|
" lines[i] = f\"{indent_str}try:\\n{indent_str} {line}\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error fetching data: {{e}}\\\")\\n{indent_str} # Use synthetic data as fallback\\n{indent_str} data = generate_synthetic_data(ticker, start_date, end_date)\"\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" break\n",
|
|
" if \"synthetic data\" in code.lower() and \"yf.download\" in code:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" for i, line in enumerate(lines):\n",
|
|
" if \"yf.download\" in line:\n",
|
|
" indent = len(line) - len(line.lstrip())\n",
|
|
" indent_str = \" \" * indent\n",
|
|
" comment = f\"{indent_str}# Using synthetic data instead of API call\\n\"\n",
|
|
" synthetic = f\"{indent_str}data = generate_synthetic_data(ticker, start_date, end_date)\\n\"\n",
|
|
" lines[i] = f\"{indent_str}# {line.strip()} # Commented out to avoid API issues\"\n",
|
|
" lines.insert(i+1, comment + synthetic)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" break\n",
|
|
" if \"plt.figure\" in code:\n",
|
|
" lines = code.split('\\n')\n",
|
|
" for i, line in enumerate(lines):\n",
|
|
" if \"plt.figure\" in line and \"try:\" not in lines[max(0, i-5):i]:\n",
|
|
" indent = len(line) - len(line.lstrip())\n",
|
|
" indent_str = \" \" * indent\n",
|
|
" try_line = f\"{indent_str}try:\\n{indent_str} \"\n",
|
|
" except_line = f\"\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error in plotting: {{e}}\\\")\"\n",
|
|
" j = i\n",
|
|
" while j < len(lines) and (j == i or lines[j].startswith(indent_str)):\n",
|
|
" j += 1\n",
|
|
" for k in range(i, j):\n",
|
|
" if lines[k].strip():\n",
|
|
" lines[k] = indent_str + \" \" + lines[k].lstrip()\n",
|
|
" lines.insert(i, try_line.rstrip())\n",
|
|
" lines.insert(j+1, except_line)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" break\n",
|
|
" lines = code.split('\\n')\n",
|
|
" for i, line in enumerate(lines):\n",
|
|
" if \"print(\" in line and any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n",
|
|
" lines[i] = line.replace(\"print(\", \"print(safe_format(\")\n",
|
|
" if \"))\" not in lines[i] and \"),\" in lines[i]:\n",
|
|
" lines[i] = lines[i].replace(\"),\", \")),\", 1)\n",
|
|
" elif \"))\" not in lines[i] and \")\" in lines[i]:\n",
|
|
" lines[i] = lines[i].replace(\")\", \"))\", 1)\n",
|
|
" code = '\\n'.join(lines)\n",
|
|
" return code\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 114,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def run_python(code):\n",
|
|
" # Create a completely separate namespace for execution\n",
|
|
" namespace = {\n",
|
|
" '__name__': '__main__',\n",
|
|
" '__builtins__': __builtins__\n",
|
|
" }\n",
|
|
" \n",
|
|
" # Modify the code to use a non-interactive matplotlib backend\n",
|
|
" # and fix pandas formatting issues\n",
|
|
" modified_code = \"\"\"\n",
|
|
"import matplotlib\n",
|
|
"matplotlib.use('Agg') # Use non-interactive backend\n",
|
|
"\n",
|
|
"# Import yfinance without setting timeout (not available in all versions)\n",
|
|
"import yfinance as yf\n",
|
|
"\n",
|
|
"# Configure logging to show in the output\n",
|
|
"import logging\n",
|
|
"logging.basicConfig(\n",
|
|
" level=logging.INFO,\n",
|
|
" format='[%(asctime)s] %(levelname)s: %(message)s',\n",
|
|
" datefmt='%H:%M:%S'\n",
|
|
")\n",
|
|
"\n",
|
|
"# Fix pandas formatting issues\n",
|
|
"import pandas as pd\n",
|
|
"pd.set_option('display.float_format', '{:.5f}'.format)\n",
|
|
"\n",
|
|
"# Override print to ensure it flushes immediately\n",
|
|
"import builtins\n",
|
|
"original_print = builtins.print\n",
|
|
"def custom_print(*args, **kwargs):\n",
|
|
" result = original_print(*args, **kwargs)\n",
|
|
" import sys\n",
|
|
" sys.stdout.flush()\n",
|
|
" return result\n",
|
|
"builtins.print = custom_print\n",
|
|
"\n",
|
|
"# Helper function to safely format pandas objects\n",
|
|
"def safe_format(obj):\n",
|
|
" if isinstance(obj, (pd.Series, pd.DataFrame)):\n",
|
|
" return str(obj)\n",
|
|
" else:\n",
|
|
" return obj\n",
|
|
"\"\"\"\n",
|
|
" \n",
|
|
" # Add the user's code\n",
|
|
" modified_code += \"\\n\" + code\n",
|
|
" \n",
|
|
" # Capture all output\n",
|
|
" output_buffer = io.StringIO()\n",
|
|
" \n",
|
|
" # Save original stdout and redirect to our buffer\n",
|
|
" original_stdout = sys.stdout\n",
|
|
" sys.stdout = output_buffer\n",
|
|
" \n",
|
|
" # Add timestamp for execution start\n",
|
|
" print(f\"[{time.strftime('%H:%M:%S')}] Executing code...\")\n",
|
|
" \n",
|
|
" try:\n",
|
|
" # Execute the modified code\n",
|
|
" exec(modified_code, namespace)\n",
|
|
" print(f\"\\n[{time.strftime('%H:%M:%S')}] Execution completed successfully.\")\n",
|
|
" \n",
|
|
" except ModuleNotFoundError as e:\n",
|
|
" missing_module = str(e).split(\"'\")[1]\n",
|
|
" print(f\"\\nError: Missing module '{missing_module}'. Click 'Install Dependencies' to install it.\")\n",
|
|
" namespace[\"__missing_module__\"] = missing_module\n",
|
|
" \n",
|
|
" except Exception as e:\n",
|
|
" print(f\"\\n[{time.strftime('%H:%M:%S')}] Error during execution: {str(e)}\")\n",
|
|
" import traceback\n",
|
|
" print(traceback.format_exc())\n",
|
|
" \n",
|
|
" finally:\n",
|
|
" # Restore original stdout\n",
|
|
" sys.stdout = original_stdout\n",
|
|
" \n",
|
|
" # Return the captured output\n",
|
|
" return output_buffer.getvalue()\n",
|
|
"\n",
|
|
"def install_dependencies(code):\n",
|
|
" import re\n",
|
|
" import subprocess\n",
|
|
" \n",
|
|
" import_pattern = r'(?:from|import)\\s+([a-zA-Z0-9_]+)(?:\\s+(?:import|as))?'\n",
|
|
" imports = re.findall(import_pattern, code)\n",
|
|
" \n",
|
|
" std_libs = ['os', 'sys', 'io', 'time', 'datetime', 'random', 'math', 're', 'json', \n",
|
|
" 'collections', 'itertools', 'functools', 'operator', 'pathlib', 'typing']\n",
|
|
" \n",
|
|
" modules_to_install = [module for module in imports if module not in std_libs]\n",
|
|
" \n",
|
|
" if not modules_to_install:\n",
|
|
" return \"No external dependencies found to install.\"\n",
|
|
" \n",
|
|
" results = []\n",
|
|
" for module in modules_to_install:\n",
|
|
" try:\n",
|
|
" result = subprocess.run(\n",
|
|
" [sys.executable, \"-m\", \"pip\", \"install\", module],\n",
|
|
" capture_output=True,\n",
|
|
" text=True,\n",
|
|
" check=False\n",
|
|
" )\n",
|
|
" \n",
|
|
" if result.returncode == 0:\n",
|
|
" results.append(f\"Successfully installed {module}\")\n",
|
|
" else:\n",
|
|
" results.append(f\"Failed to install {module}: {result.stderr}\")\n",
|
|
" except Exception as e:\n",
|
|
" results.append(f\"Error installing {module}: {str(e)}\")\n",
|
|
" \n",
|
|
" return \"\\n\".join(results)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 109,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"trading_strategies = [\n",
|
|
" {\n",
|
|
" \"name\": \"Moving Average Crossover\",\n",
|
|
" \"description\": \"Moving Average Crossover strategy for S&P 500 stocks. Buy when the 20-day moving average crosses above the 50-day moving average, and sell when it crosses below.\",\n",
|
|
" \"buy_signal\": \"20-day MA crosses above 50-day MA\",\n",
|
|
" \"sell_signal\": \"20-day MA crosses below 50-day MA\",\n",
|
|
" \"timeframe\": \"Daily\",\n",
|
|
" \"risk_level\": \"Medium\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"name\": \"RSI Mean Reversion\",\n",
|
|
" \"description\": \"Mean reversion strategy that buys stocks when RSI falls below 30 (oversold) and sells when RSI rises above 70 (overbought).\",\n",
|
|
" \"buy_signal\": \"RSI below 30 (oversold)\",\n",
|
|
" \"sell_signal\": \"RSI above 70 (overbought)\",\n",
|
|
" \"timeframe\": \"Daily\",\n",
|
|
" \"risk_level\": \"Medium\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"name\": \"Momentum Strategy\",\n",
|
|
" \"description\": \"Momentum strategy that buys the top 5 performing stocks from the Dow Jones Industrial Average over the past month and rebalances monthly.\",\n",
|
|
" \"buy_signal\": \"Stock in top 5 performers over past month\",\n",
|
|
" \"sell_signal\": \"Stock no longer in top 5 performers at rebalance\",\n",
|
|
" \"timeframe\": \"Monthly\",\n",
|
|
" \"risk_level\": \"High\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"name\": \"Pairs Trading\",\n",
|
|
" \"description\": \"Pairs trading strategy that identifies correlated stock pairs and trades on the divergence and convergence of their price relationship.\",\n",
|
|
" \"buy_signal\": \"Pairs ratio deviates 2+ standard deviations below mean\",\n",
|
|
" \"sell_signal\": \"Pairs ratio returns to mean or exceeds mean\",\n",
|
|
" \"timeframe\": \"Daily\",\n",
|
|
" \"risk_level\": \"Medium-High\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"name\": \"Bollinger Band Breakout\",\n",
|
|
" \"description\": \"Volatility breakout strategy that buys when a stock breaks out of its upper Bollinger Band and sells when it reverts to the mean.\",\n",
|
|
" \"buy_signal\": \"Price breaks above upper Bollinger Band (2 std dev)\",\n",
|
|
" \"sell_signal\": \"Price reverts to middle Bollinger Band (SMA)\",\n",
|
|
" \"timeframe\": \"Daily\",\n",
|
|
" \"risk_level\": \"High\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"name\": \"MACD Crossover\",\n",
|
|
" \"description\": \"MACD crossover strategy that buys when the MACD line crosses above the signal line and sells when it crosses below.\",\n",
|
|
" \"buy_signal\": \"MACD line crosses above signal line\",\n",
|
|
" \"sell_signal\": \"MACD line crosses below signal line\",\n",
|
|
" \"timeframe\": \"Daily\",\n",
|
|
" \"risk_level\": \"Medium\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"name\": \"Golden Cross\",\n",
|
|
" \"description\": \"Golden Cross strategy that buys when the 50-day moving average crosses above the 200-day moving average and sells on the Death Cross (opposite).\",\n",
|
|
" \"buy_signal\": \"50-day MA crosses above 200-day MA\",\n",
|
|
" \"sell_signal\": \"50-day MA crosses below 200-day MA\",\n",
|
|
" \"timeframe\": \"Daily\",\n",
|
|
" \"risk_level\": \"Low\"\n",
|
|
" }\n",
|
|
"]\n",
|
|
"\n",
|
|
"sample_strategies = [strategy[\"description\"] for strategy in trading_strategies]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 110,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"default_description = \"\"\"\n",
|
|
"Create a moving average crossover strategy with the following specifications:\n",
|
|
"- Use yfinance to download historical data for a list of stocks (AAPL, MSFT, AMZN, GOOGL, META)\n",
|
|
"- Calculate 20-day and 50-day moving averages\n",
|
|
"- Generate buy signals when the 20-day MA crosses above the 50-day MA\n",
|
|
"- Generate sell signals when the 20-day MA crosses below the 50-day MA\n",
|
|
"- Implement a simple backtesting framework to evaluate the strategy\n",
|
|
"- Calculate performance metrics: total return, annualized return, Sharpe ratio, max drawdown\n",
|
|
"- Visualize the equity curve, buy/sell signals, and moving averages\n",
|
|
"\"\"\"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2025-10-22 18:30:21,233 - INFO - HTTP Request: GET http://127.0.0.1:7875/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:30:21,238 - INFO - HTTP Request: HEAD http://127.0.0.1:7875/ \"HTTP/1.1 200 OK\"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"* Running on local URL: http://127.0.0.1:7875\n",
|
|
"* To create a public link, set `share=True` in `launch()`.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div><iframe src=\"http://127.0.0.1:7875/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": []
|
|
},
|
|
"execution_count": 115,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2025-10-22 18:30:24,092 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:31:03,437 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:31:15,743 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:31:28,425 - ERROR - Error fetching data: periods must be a number, got 2025-10-22 18:31:28.425210\n",
|
|
"2025-10-22 18:31:28,429 - INFO - Synthetic data generated for tickers: AAPL\n",
|
|
"2025-10-22 18:31:28,432 - INFO - Moving averages calculated with windows 20 and 50\n",
|
|
"2025-10-22 18:31:28,434 - INFO - Signals generated based on moving average crossover\n",
|
|
"2025-10-22 18:31:28,438 - INFO - Performance calculated\n",
|
|
"2025-10-22 18:31:28,438 - INFO - Total Return: -0.010752455100331848, Sharpe Ratio: 0.18162435507214664, Max Drawdown: -0.19919271751608258\n",
|
|
"2025-10-22 18:32:28,496 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:32:38,626 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:32:47,779 - ERROR - Error fetching data from yfinance: name 'start_date' is not defined. Using synthetic data.\n",
|
|
"2025-10-22 18:33:23,647 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
|
|
"2025-10-22 18:33:37,829 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"with gr.Blocks(css=CSS, theme=gr.themes.Monochrome(), title=\"Trading Code Generator\") as ui:\n",
"    # Page title\n",
"    with gr.Row():\n",
"        gr.HTML(\"<h1 style='text-align: center; margin-bottom: 0.5rem;'>Trading Strategy Code Generator</h1>\")\n",
"\n",
"    with gr.Row():\n",
"        # Left column: strategy and model pickers, editable description, action buttons\n",
"        with gr.Column(scale=1):\n",
"            strategy_dropdown = gr.Dropdown(\n",
"                label=\"Select Trading Strategy\",\n",
"                choices=[entry[\"name\"] for entry in trading_strategies],\n",
"                value=trading_strategies[0][\"name\"],\n",
"            )\n",
"            with gr.Accordion(\"Strategy Details\", open=False):\n",
"                strategy_info = gr.JSON(value=trading_strategies[0])\n",
"            model = gr.Dropdown(label=\"Select Model\", choices=models, value=models[0])\n",
"            description = gr.TextArea(\n",
"                label=\"Strategy Description (Edit to customize)\",\n",
"                value=trading_strategies[0][\"description\"],\n",
"                lines=4,\n",
"            )\n",
"            with gr.Row():\n",
"                generate = gr.Button(\"Generate Code\", variant=\"primary\", size=\"sm\")\n",
"                run = gr.Button(\"Run Code\", size=\"sm\")\n",
"                install_deps = gr.Button(\"Install Dependencies\", size=\"sm\")\n",
"\n",
"        # Right column: generated code and the output of running it\n",
"        with gr.Column(scale=2):\n",
"            trading_code = gr.Code(\n",
"                label=\"Generated Trading Code\",\n",
"                value=\"\",\n",
"                language=\"python\",\n",
"                lines=20,\n",
"                elem_id=\"code-block\",\n",
"                show_label=True,\n",
"            )\n",
"            output = gr.TextArea(label=\"Execution Output\", lines=8, elem_classes=[\"trading-out\"])\n",
"\n",
"    def update_strategy_info(strategy_name):\n",
"        \"\"\"Return (details, description) for the named strategy, falling back to the first one.\"\"\"\n",
"        for candidate in trading_strategies:\n",
"            if candidate[\"name\"] == strategy_name:\n",
"                return candidate, candidate[\"description\"]\n",
"        fallback = trading_strategies[0]\n",
"        return fallback, fallback[\"description\"]\n",
"\n",
"    strategy_dropdown.change(\n",
"        fn=update_strategy_info,\n",
"        inputs=strategy_dropdown,\n",
"        outputs=[strategy_info, description],\n",
"    )\n",
"\n",
"    def generate_with_validation(model, description):\n",
"        \"\"\"Generate code for the description and return it with a validation summary.\"\"\"\n",
"        # force_gpt4=True is deliberate (\"always use GPT-4o for better code quality\"), so the\n",
"        # model chosen in the dropdown may be overridden inside generate_trading_code\n",
"        code = generate_trading_code(model, description, force_gpt4=True)\n",
"        is_valid, issues = validate_code(code)\n",
"        if is_valid:\n",
"            return code, \"Code validation passed ✓\"\n",
"        warning_lines = \"\\n\".join(f\"- {issue}\" for issue in issues)\n",
"        return code, \"Code validation warnings:\\n\" + warning_lines\n",
"\n",
"    # Button wiring: Generate fills the code and output panes; Run / Install write to output\n",
"    generate.click(fn=generate_with_validation, inputs=[model, description], outputs=[trading_code, output])\n",
"    run.click(fn=run_python, inputs=[trading_code], outputs=[output])\n",
"    install_deps.click(fn=install_dependencies, inputs=[trading_code], outputs=[output])\n",
"\n",
"ui.launch(inbrowser=True)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Testing the Trading Code Generator\n",
|
|
"\n",
|
|
"Let's test the trading code generator with a specific strategy description and model.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Sample strategy request used to exercise generate_trading_code end to end\n",
"test_model = \"gpt-3.5-turbo\"\n",
"\n",
"test_description = \"\"\"\n",
"Create a simple RSI-based mean reversion strategy:\n",
"- Use AAPL stock data for the past 2 years\n",
"- Calculate the 14-day RSI indicator\n",
"- Buy when RSI falls below 30 (oversold)\n",
"- Sell when RSI rises above 70 (overbought)\n",
"- Include visualization of entry/exit points\n",
"- Calculate performance metrics\n",
"\"\"\"\n",
"\n",
"generated_code = generate_trading_code(test_model, test_description)\n",
"print(\"Generated trading code:\", generated_code, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Execute the generated strategy, reporting any failure instead of raising\n",
"try:\n",
"    output = run_python(generated_code)\n",
"except Exception as exc:\n",
"    print(f\"Error running the generated code: {exc}\")\n",
"else:\n",
"    print(output)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Fixed version of the code. Compared with the generated draft it flattens yfinance's\n",
"# MultiIndex columns, avoids chained assignment / fillna(inplace=True) on a column\n",
"# (ignored under pandas Copy-on-Write), suppresses signals while the long moving average\n",
"# warms up, and guards the Sharpe ratio against zero volatility.\n",
"fixed_code = \"\"\"\n",
"import yfinance as yf\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from datetime import datetime, timedelta\n",
"import logging\n",
"\n",
"# Set up logging\n",
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
"\n",
"def flatten_columns(data):\n",
"    \\\"\\\"\\\"\n",
"    Collapse (Price, Ticker) MultiIndex columns returned by yf.download into plain\n",
"    price-field names so that data['Close'] is a Series; other frames pass through\n",
"    \\\"\\\"\\\"\n",
"    if isinstance(data.columns, pd.MultiIndex):\n",
"        data = data.copy()\n",
"        data.columns = data.columns.get_level_values(0)\n",
"    return data\n",
"\n",
"def calculate_moving_averages(data, short_window=20, long_window=50):\n",
"    \\\"\\\"\\\"\n",
"    Calculate short and long term moving averages\n",
"    \\\"\\\"\\\"\n",
"    data['Short_MA'] = data['Close'].rolling(window=short_window, min_periods=1).mean()\n",
"    data['Long_MA'] = data['Close'].rolling(window=long_window, min_periods=1).mean()\n",
"    return data\n",
"\n",
"def generate_signals(data, short_window=20, long_window=50):\n",
"    \\\"\\\"\\\"\n",
"    Generate buy/sell signals based on moving average crossover strategy.\n",
"    Signal is 0 for the first max(short_window, long_window) rows, then 1 while the\n",
"    short MA is above the long MA and -1 otherwise\n",
"    \\\"\\\"\\\"\n",
"    # Build the column in a single assignment: chained assignment such as\n",
"    # data['Signal'][n:] = ... does not update data under Copy-on-Write\n",
"    signal = np.where(data['Short_MA'] > data['Long_MA'], 1, -1)\n",
"    signal[:max(short_window, long_window)] = 0\n",
"    data['Signal'] = signal\n",
"    # Trade on the previous bar's signal (no look-ahead); flat on the first bar\n",
"    data['Position'] = data['Signal'].shift(1).fillna(0)\n",
"    return data\n",
"\n",
"def backtest_strategy(data):\n",
"    \\\"\\\"\\\"\n",
"    Backtest the trading strategy and calculate performance metrics\n",
"    \\\"\\\"\\\"\n",
"    data['Returns'] = data['Close'].pct_change()\n",
"    # The first return is NaN (no prior close); count it as a flat day\n",
"    data['Strategy_Returns'] = (data['Returns'] * data['Position']).fillna(0)\n",
"\n",
"    cumulative_returns = (1 + data['Strategy_Returns']).cumprod()\n",
"\n",
"    # Calculate metrics\n",
"    total_return = cumulative_returns.iloc[-1] - 1\n",
"    volatility = data['Strategy_Returns'].std()\n",
"    # A strategy that never trades has zero (or NaN) volatility: report 0 instead of inf/NaN\n",
"    sharpe_ratio = np.sqrt(252) * data['Strategy_Returns'].mean() / volatility if volatility > 0 else 0.0\n",
"    max_drawdown = ((cumulative_returns / cumulative_returns.cummax()) - 1).min()\n",
"\n",
"    metrics = {\n",
"        'Total Return': total_return,\n",
"        'Sharpe Ratio': sharpe_ratio,\n",
"        'Max Drawdown': max_drawdown\n",
"    }\n",
"\n",
"    return cumulative_returns, metrics\n",
"\n",
"def plot_results(data, cumulative_returns, ticker, short_window=20, long_window=50):\n",
"    \\\"\\\"\\\"\n",
"    Plot the performance of the trading strategy\n",
"    \\\"\\\"\\\"\n",
"    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={'height_ratios': [2, 1]})\n",
"\n",
"    # Price and MA plot\n",
"    ax1.plot(data.index, data['Close'], label='Close Price')\n",
"    ax1.plot(data.index, data['Short_MA'], label=f'{short_window}-day MA', alpha=0.7)\n",
"    ax1.plot(data.index, data['Long_MA'], label=f'{long_window}-day MA', alpha=0.7)\n",
"\n",
"    # Add buy/sell signals (rows where the signal steps up / down)\n",
"    buy_signals = data[data['Signal'] > data['Signal'].shift(1)]\n",
"    sell_signals = data[data['Signal'] < data['Signal'].shift(1)]\n",
"\n",
"    ax1.scatter(buy_signals.index, buy_signals['Close'], marker='^', color='green', s=100, label='Buy Signal')\n",
"    ax1.scatter(sell_signals.index, sell_signals['Close'], marker='v', color='red', s=100, label='Sell Signal')\n",
"\n",
"    ax1.set_title(f'Moving Average Crossover Strategy on {ticker}')\n",
"    ax1.set_ylabel('Price ($)')\n",
"    ax1.legend(loc='best')\n",
"    ax1.grid(True)\n",
"\n",
"    # Returns plot\n",
"    ax2.plot(cumulative_returns.index, cumulative_returns, label='Cumulative Strategy Returns', color='blue')\n",
"    ax2.set_title('Cumulative Returns')\n",
"    ax2.set_xlabel('Date')\n",
"    ax2.set_ylabel('Returns')\n",
"    ax2.legend(loc='best')\n",
"    ax2.grid(True)\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"if __name__ == \\\"__main__\\\":\n",
"    # User inputs\n",
"    ticker = 'SPY'  # Example: S&P 500 ETF\n",
"    start_date = (datetime.now() - timedelta(days=365*2)).strftime('%Y-%m-%d')\n",
"    end_date = datetime.now().strftime('%Y-%m-%d')\n",
"\n",
"    # Strategy parameters\n",
"    short_window = 20\n",
"    long_window = 50\n",
"\n",
"    # Fetch data\n",
"    try:\n",
"        logging.info(f\\\"Fetching data for {ticker} from {start_date} to {end_date}...\\\")\n",
"        stock_data = flatten_columns(yf.download(ticker, start=start_date, end=end_date, progress=False))\n",
"        if stock_data.empty:\n",
"            raise ValueError(f\\\"No data returned for {ticker}\\\")\n",
"        logging.info(f\\\"Data fetched successfully. Got {len(stock_data)} data points.\\\")\n",
"    except Exception as e:\n",
"        logging.error(f\\\"Failed to fetch data: {e}\\\")\n",
"        raise SystemExit(e)\n",
"\n",
"    try:\n",
"        # Preprocess and generate signals\n",
"        stock_data = calculate_moving_averages(stock_data, short_window, long_window)\n",
"        stock_data = generate_signals(stock_data, short_window, long_window)\n",
"\n",
"        # Backtest the strategy\n",
"        cumulative_returns, metrics = backtest_strategy(stock_data)\n",
"\n",
"        # Display metrics\n",
"        for key, value in metrics.items():\n",
"            logging.info(f\\\"{key}: {value:.4f}\\\")\n",
"\n",
"        # Plot results\n",
"        plot_results(stock_data, cumulative_returns, ticker, short_window, long_window)\n",
"    except Exception as e:\n",
"        logging.error(f\\\"Error while executing strategy: {e}\\\")\n",
"\"\"\"\n",
"\n",
"# Display the fixed code\n",
"print(\"Fixed code:\")\n",
"print(fixed_code)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Execute the hand-corrected strategy through run_python and echo its captured output\n",
"output = run_python(fixed_code)\n",
"print(output)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Let's also update our system_prompt to ensure the generated code works properly.\n",
"# The pitfalls section targets failures seen in earlier runs (an undefined start_date,\n",
"# a datetime passed as pd.date_range periods) plus common pandas/yfinance traps.\n",
"\n",
"system_prompt = \"\"\"\n",
"You are an expert algorithmic trading code generator. Your task is to generate Python code for trading strategies based on user requirements.\n",
"The code should be well-structured, efficient, and ready to run in a simulated environment.\n",
"\n",
"The generated code should:\n",
"1. Use the yfinance library for fetching stock data\n",
"2. Implement the specified trading strategy\n",
"3. Include proper error handling and logging\n",
"4. Include visualization of the strategy performance with clear buy/sell signals\n",
"5. Calculate and display relevant metrics (returns, Sharpe ratio, drawdown, etc.)\n",
"6. Handle NaN values and edge cases properly\n",
"7. Include informative print statements or logging to show progress\n",
"\n",
"IMPORTANT: Make sure all variables are properly defined before use, especially in functions.\n",
"Always pass necessary parameters between functions rather than relying on global variables.\n",
"\n",
"Avoid these common pitfalls:\n",
"- yf.download may return MultiIndex columns even for a single ticker; flatten them (e.g. data.columns = data.columns.get_level_values(0)) before using data['Close'].\n",
"- Do not use chained assignment such as data['col'][mask] = value or data['col'].fillna(0, inplace=True); assign the result back (data['col'] = ...) or use data.loc[...].\n",
"- When calling pd.date_range, pass periods as an integer count, never a date.\n",
"- Save figures with plt.savefig(...) and close them with plt.close() rather than relying on plt.show(), since the code may run without a display.\n",
"\n",
"Respond only with Python code. Do not provide any explanation other than occasional comments in the code.\n",
"\"\"\"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Ask the model for a Bollinger Bands strategy to check the updated system prompt\n",
"test_model = \"gpt-3.5-turbo\"\n",
"\n",
"test_description_2 = \"\"\"\n",
"Create a simple Bollinger Bands strategy:\n",
"- Use AAPL stock data for the past 1 year\n",
"- Calculate Bollinger Bands with 20-day SMA and 2 standard deviations\n",
"- Buy when price touches the lower band\n",
"- Sell when price touches the upper band\n",
"- Include visualization of entry/exit points\n",
"- Calculate performance metrics\n",
"\"\"\"\n",
"\n",
"generated_code_2 = generate_trading_code(test_model, test_description_2)\n",
"print(\"Generated trading code with updated prompt:\", generated_code_2, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Let's test the run function with live logging\n",
"\n",
"test_code = \"\"\"\n",
"import yfinance as yf\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from datetime import datetime, timedelta\n",
"\n",
"# Define ticker and date range\n",
"ticker = 'AAPL'\n",
"end_date = datetime.now()\n",
"start_date = end_date - timedelta(days=365)\n",
"\n",
"# Download data\n",
"print(f\"Downloading data for {ticker}...\")\n",
"data = yf.download(ticker, start=start_date, end=end_date)\n",
"# yf.download can return (Price, Ticker) MultiIndex columns even for one ticker;\n",
"# flatten them so data['Close'] is a Series and the column assignments below work\n",
"if isinstance(data.columns, pd.MultiIndex):\n",
"    data.columns = data.columns.get_level_values(0)\n",
"print(f\"Downloaded {len(data)} rows of data\")\n",
"\n",
"# Calculate RSI\n",
"print(\"Calculating RSI...\")\n",
"delta = data['Close'].diff()\n",
"gain = delta.where(delta > 0, 0)\n",
"loss = -delta.where(delta < 0, 0)\n",
"avg_gain = gain.rolling(window=14).mean()\n",
"avg_loss = loss.rolling(window=14).mean()\n",
"rs = avg_gain / avg_loss\n",
"data['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
"# Generate signals\n",
"print(\"Generating trading signals...\")\n",
"data['Signal'] = 0\n",
"data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n",
"data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n",
"\n",
"# Count signals\n",
"buy_signals = len(data[data['Signal'] == 1])\n",
"sell_signals = len(data[data['Signal'] == -1])\n",
"print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n",
"\n",
"# Print sample of the data\n",
"print(\"\\\\nSample of the processed data:\")\n",
"print(data[['Close', 'RSI', 'Signal']].tail())\n",
"\n",
"print(\"\\\\nAnalysis complete!\")\n",
"\"\"\"\n",
"\n",
"output = run_python(test_code)\n",
"print(output)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Let's test the improved run function again\n",
"\n",
"test_code_2 = \"\"\"\n",
"import yfinance as yf\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from datetime import datetime, timedelta\n",
"\n",
"# Define ticker and date range\n",
"ticker = 'AAPL'\n",
"end_date = datetime.now()\n",
"start_date = end_date - timedelta(days=365)\n",
"\n",
"# Download data\n",
"print(f\"Downloading data for {ticker}...\")\n",
"data = yf.download(ticker, start=start_date, end=end_date, progress=False)\n",
"# yf.download can return (Price, Ticker) MultiIndex columns even for one ticker;\n",
"# flatten them so data['Close'] is a Series and the column assignments below work\n",
"if isinstance(data.columns, pd.MultiIndex):\n",
"    data.columns = data.columns.get_level_values(0)\n",
"print(f\"Downloaded {len(data)} rows of data\")\n",
"\n",
"# Calculate RSI\n",
"print(\"Calculating RSI...\")\n",
"delta = data['Close'].diff()\n",
"gain = delta.where(delta > 0, 0)\n",
"loss = -delta.where(delta < 0, 0)\n",
"avg_gain = gain.rolling(window=14).mean()\n",
"avg_loss = loss.rolling(window=14).mean()\n",
"rs = avg_gain / avg_loss\n",
"data['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
"# Generate signals\n",
"print(\"Generating trading signals...\")\n",
"data['Signal'] = 0\n",
"data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n",
"data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n",
"\n",
"# Count signals\n",
"buy_signals = len(data[data['Signal'] == 1])\n",
"sell_signals = len(data[data['Signal'] == -1])\n",
"print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n",
"\n",
"# Print sample of the data\n",
"print(\"\\\\nSample of the processed data:\")\n",
"print(data[['Close', 'RSI', 'Signal']].tail())\n",
"\n",
"print(\"\\\\nAnalysis complete!\")\n",
"\"\"\"\n",
"\n",
"output = run_python(test_code_2)\n",
"print(output)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Smoke-test run_python's output capture with a tiny numpy script\n",
"\n",
"simple_test = \"\"\"\n",
"print(\"Hello from the trading code generator!\")\n",
"print(\"Testing output capture...\")\n",
"\n",
"# Simulate some data processing\n",
"import numpy as np\n",
"data = np.random.rand(5, 3)\n",
"print(\"Generated random data:\")\n",
"print(data)\n",
"\n",
"# Show a calculation\n",
"result = np.mean(data, axis=0)\n",
"print(\"Mean of each column:\")\n",
"print(result)\n",
"\n",
"print(\"Test complete!\")\n",
"\"\"\"\n",
"\n",
"output = run_python(simple_test)\n",
"print(\"Output from execution:\", output, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Combined RSI + Bollinger Bands request - the kind of strategy that has tended to cause formatting issues\n",
"test_model = \"gpt-3.5-turbo\"\n",
"\n",
"test_description_3 = \"\"\"\n",
"Create a simple trading strategy that:\n",
"- Uses AAPL stock data for the past year\n",
"- Calculates both RSI and Bollinger Bands\n",
"- Buys when price is below lower Bollinger Band AND RSI is below 30\n",
"- Sells when price is above upper Bollinger Band OR RSI is above 70\n",
"- Includes proper error handling for all calculations\n",
"- Visualizes the entry/exit points and performance\n",
"\"\"\"\n",
"\n",
"generated_code_3 = generate_trading_code(test_model, test_description_3)\n",
"print(\"Generated trading code with enhanced validation:\", generated_code_3, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Execute the code generated above and echo what it printed\n",
"output = run_python(generated_code_3)\n",
"print(\"Execution output:\", output, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Check that run_python captures prints and copes with matplotlib (the plot is saved, not shown)\n",
"\n",
"simple_test_2 = \"\"\"\n",
"print(\"This is a test of the output capture system\")\n",
"print(\"Line 1 of output\")\n",
"print(\"Line 2 of output\")\n",
"\n",
"# Import and use numpy\n",
"import numpy as np\n",
"data = np.random.rand(3, 3)\n",
"print(\"Random matrix:\")\n",
"print(data)\n",
"\n",
"# Create a simple plot\n",
"import matplotlib.pyplot as plt\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot([1, 2, 3, 4], [10, 20, 25, 30], 'ro-')\n",
"plt.title('Simple Plot')\n",
"plt.xlabel('X axis')\n",
"plt.ylabel('Y axis')\n",
"plt.grid(True)\n",
"plt.savefig('simple_plot.png') # Save instead of showing\n",
"print(\"Plot saved to simple_plot.png\")\n",
"\n",
"print(\"Test complete!\")\n",
"\"\"\"\n",
"\n",
"output = run_python(simple_test_2)\n",
"print(\"Output from execution:\", output, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Test with a simpler example that won't time out\n",
"\n",
"simple_test_3 = \"\"\"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"\n",
"# Set up logging\n",
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
"\n",
"# Generate synthetic stock data (seeded, so every run sees the same prices)\n",
"def generate_stock_data(days=252, volatility=0.01):\n",
"    logging.info(f\"Generating {days} days of synthetic stock data\")\n",
"    np.random.seed(42)\n",
"    price = 100\n",
"    prices = [price]\n",
"\n",
"    for _ in range(days - 1):\n",
"        change = np.random.normal(0, volatility)\n",
"        price *= (1 + change)\n",
"        prices.append(price)\n",
"\n",
"    dates = pd.date_range(end=pd.Timestamp.today(), periods=days)\n",
"    df = pd.DataFrame({\n",
"        'Close': prices,\n",
"        'Open': [p * (1 - volatility/2) for p in prices],\n",
"        'High': [p * (1 + volatility) for p in prices],\n",
"        'Low': [p * (1 - volatility) for p in prices],\n",
"        'Volume': [np.random.randint(100000, 10000000) for _ in range(days)]\n",
"    }, index=dates)\n",
"\n",
"    logging.info(f\"Generated data with shape {df.shape}\")\n",
"    return df\n",
"\n",
"# Calculate RSI (simple rolling mean of gains and losses)\n",
"def calculate_rsi(data, window=14):\n",
"    logging.info(f\"Calculating RSI with {window}-day window\")\n",
"    delta = data['Close'].diff()\n",
"    # min_periods=window leaves the warm-up rows as NaN; with min_periods=1 the first\n",
"    # bars average only one or two moves, giving RSI values of 0/100 and spurious signals\n",
"    gain = delta.where(delta > 0, 0).rolling(window=window, min_periods=window).mean()\n",
"    loss = -delta.where(delta < 0, 0).rolling(window=window, min_periods=window).mean()\n",
"\n",
"    rs = gain / loss\n",
"    rsi = 100 - (100 / (1 + rs))\n",
"    return rsi\n",
"\n",
"# Main function\n",
"if __name__ == \"__main__\":\n",
"    # Generate data\n",
"    data = generate_stock_data()\n",
"\n",
"    # Calculate RSI\n",
"    data['RSI'] = calculate_rsi(data)\n",
"\n",
"    # Generate signals (NaN RSI during warm-up compares False, so those rows stay 0)\n",
"    logging.info(\"Generating trading signals\")\n",
"    data['Signal'] = 0\n",
"    data.loc[data['RSI'] < 30, 'Signal'] = 1  # Buy signal\n",
"    data.loc[data['RSI'] > 70, 'Signal'] = -1  # Sell signal\n",
"\n",
"    # Count signals\n",
"    buy_signals = len(data[data['Signal'] == 1])\n",
"    sell_signals = len(data[data['Signal'] == -1])\n",
"    logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n",
"\n",
"    # Plot the results\n",
"    logging.info(\"Creating visualization\")\n",
"    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [3, 1]})\n",
"\n",
"    # Price chart\n",
"    ax1.plot(data.index, data['Close'], label='Close Price')\n",
"\n",
"    # Add buy/sell signals\n",
"    buy_points = data[data['Signal'] == 1]\n",
"    sell_points = data[data['Signal'] == -1]\n",
"\n",
"    ax1.scatter(buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n",
"    ax1.scatter(sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n",
"\n",
"    ax1.set_title('Stock Price with RSI Signals')\n",
"    ax1.set_ylabel('Price')\n",
"    ax1.legend()\n",
"    ax1.grid(True)\n",
"\n",
"    # RSI chart\n",
"    ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n",
"    ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n",
"    ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n",
"    ax2.set_title('RSI Indicator')\n",
"    ax2.set_ylabel('RSI')\n",
"    ax2.set_ylim(0, 100)\n",
"    ax2.grid(True)\n",
"\n",
"    plt.tight_layout()\n",
"    plt.savefig('rsi_strategy.png')\n",
"    plt.close(fig)  # release the figure; the saved PNG is the only output needed\n",
"    logging.info(\"Plot saved to rsi_strategy.png\")\n",
"\n",
"    # Print sample of data\n",
"    logging.info(\"Sample of the processed data:\")\n",
"    print(data[['Close', 'RSI', 'Signal']].tail())\n",
"\n",
"    logging.info(\"Analysis complete!\")\n",
"\"\"\"\n",
"\n",
"output = run_python(simple_test_3)\n",
"print(\"Output from execution:\")\n",
"print(output)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Generate a MACD + RSI strategy with GPT-4o and run the static validator over it\n",
"\n",
"test_description_4 = \"\"\"\n",
"Create a trading strategy that:\n",
"- Uses both MACD and RSI indicators\n",
"- Buys when MACD crosses above signal line AND RSI is below 40\n",
"- Sells when MACD crosses below signal line OR RSI is above 70\n",
"- Includes proper visualization with buy/sell signals\n",
"- Uses synthetic data if API calls fail\n",
"- Calculates performance metrics including Sharpe ratio and max drawdown\n",
"\"\"\"\n",
"\n",
"print(\"Generating trading code with GPT-4o...\")\n",
"generated_code_4 = generate_trading_code(\"gpt-4o\", test_description_4, force_gpt4=True)\n",
"print(\"Code generation complete. Validating...\")\n",
"is_valid, issues = validate_code(generated_code_4)\n",
"\n",
"if not issues:\n",
"    print(\"Code validation passed ✓\")\n",
"else:\n",
"    print(f\"Validation found {len(issues)} issues:\")\n",
"    for issue in issues:\n",
"        print(f\"- {issue}\")\n",
"\n",
"print(\"\\nGenerated code snippet (first 20 lines):\", \"\\n\".join(generated_code_4.split(\"\\n\")[:20]), sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Execute the GPT-4o strategy and echo what it printed\n",
"output = run_python(generated_code_4)\n",
"print(\"Execution output:\", output, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Regenerate with a synthetic-data request now that the yfinance timeout setting is fixed\n",
"\n",
"test_description_5 = \"\"\"\n",
"Create a simple trading strategy that:\n",
"- Uses synthetic data generation to avoid API timeouts\n",
"- Implements a simple moving average crossover (5-day and 20-day)\n",
"- Includes proper visualization with buy/sell signals\n",
"- Calculates basic performance metrics\n",
"\"\"\"\n",
"\n",
"print(\"Generating trading code with proper yfinance timeout settings...\")\n",
"generated_code_5 = generate_trading_code(\"gpt-4o\", test_description_5, force_gpt4=True)\n",
"print(\"Code generation complete.\")\n",
"\n",
"# Execute it straight away and show the captured output\n",
"output = run_python(generated_code_5)\n",
"print(\"Execution output:\", output, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Synthetic-data RSI strategy, aimed at checking that scatter plots of the signals don't fail\n",
"\n",
"test_description_6 = \"\"\"\n",
"Create a simple trading strategy that:\n",
"- Uses synthetic data generation only (no API calls)\n",
"- Implements a simple RSI-based strategy (buy when RSI < 30, sell when RSI > 70)\n",
"- Includes visualization with buy/sell signals using scatter plots\n",
"- Calculates basic performance metrics\n",
"- Uses proper error handling for all operations\n",
"\"\"\"\n",
"\n",
"print(\"Generating trading code with scatter plot safety...\")\n",
"generated_code_6 = generate_trading_code(\"gpt-4o\", test_description_6, force_gpt4=True)\n",
"print(\"Code generation complete.\")\n",
"\n",
"# Execute it straight away and show the captured output\n",
"output = run_python(generated_code_6)\n",
"print(\"Execution output:\", output, sep=\"\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Test with a fixed example that properly handles pandas formatting and scatter plots\n",
|
|
"\n",
|
|
"test_fixed_code = \"\"\"\n",
|
|
"import yfinance as yf\n",
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import logging\n",
|
|
"from datetime import datetime, timedelta\n",
|
|
"\n",
|
|
"# Configure logging\n",
|
|
"logging.basicConfig(\n",
|
|
" level=logging.INFO,\n",
|
|
" format='[%(asctime)s] %(levelname)s: %(message)s',\n",
|
|
" datefmt='%H:%M:%S'\n",
|
|
")\n",
|
|
"\n",
|
|
"# Helper function for safe formatting of pandas objects\n",
|
|
"def safe_format(obj):\n",
|
|
" if isinstance(obj, (pd.Series, pd.DataFrame)):\n",
|
|
" return str(obj)\n",
|
|
" return obj\n",
|
|
"\n",
|
|
"# Helper function to safely create scatter plots\n",
|
|
"def safe_scatter(ax, x, y, *args, **kwargs):\n",
|
|
" # Ensure x and y are the same length\n",
|
|
" if len(x) != len(y):\n",
|
|
" logging.warning(f\"Scatter plot inputs have different lengths: x={len(x)}, y={len(y)}\")\n",
|
|
" # Find the minimum length\n",
|
|
" min_len = min(len(x), len(y))\n",
|
|
" x = x[:min_len]\n",
|
|
" y = y[:min_len]\n",
|
|
" \n",
|
|
" # Check for empty arrays\n",
|
|
" if len(x) == 0 or len(y) == 0:\n",
|
|
" logging.warning(\"Empty arrays passed to scatter plot, skipping\")\n",
|
|
" return None\n",
|
|
" \n",
|
|
" return ax.scatter(x, y, *args, **kwargs)\n",
|
|
"\n",
|
|
"# Generate synthetic data\n",
|
|
"def generate_synthetic_data(ticker='AAPL', days=252, seed=42):\n",
|
|
" logging.info(f\"Generating synthetic data for {ticker}\")\n",
|
|
" np.random.seed(seed)\n",
|
|
" \n",
|
|
" # Generate price data\n",
|
|
" price = 100 # Starting price\n",
|
|
" prices = [price]\n",
|
|
" \n",
|
|
" for _ in range(days):\n",
|
|
" change = np.random.normal(0, 0.01) # 1% volatility\n",
|
|
" price *= (1 + change)\n",
|
|
" prices.append(price)\n",
|
|
" \n",
|
|
" # Create date range\n",
|
|
" end_date = datetime.now()\n",
|
|
" start_date = end_date - timedelta(days=days)\n",
|
|
" dates = pd.date_range(start=start_date, end=end_date, periods=len(prices))\n",
|
|
" \n",
|
|
" # Create DataFrame\n",
|
|
" df = pd.DataFrame({\n",
|
|
" 'Open': prices[:-1],\n",
|
|
" 'High': [p * 1.01 for p in prices[:-1]],\n",
|
|
" 'Low': [p * 0.99 for p in prices[:-1]],\n",
|
|
" 'Close': prices[1:],\n",
|
|
" 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n",
|
|
" }, index=dates[:-1])\n",
|
|
" \n",
|
|
" logging.info(f\"Generated {len(df)} days of data for {ticker}\")\n",
|
|
" return df\n",
|
|
"\n",
|
|
"# Calculate RSI\n",
|
|
"def calculate_rsi(data, window=14):\n",
|
|
" logging.info(f\"Calculating RSI with {window}-day window\")\n",
|
|
" delta = data['Close'].diff()\n",
|
|
" gain = delta.where(delta > 0, 0)\n",
|
|
" loss = -delta.where(delta < 0, 0)\n",
|
|
" \n",
|
|
" avg_gain = gain.rolling(window=window, min_periods=1).mean()\n",
|
|
" avg_loss = loss.rolling(window=window, min_periods=1).mean()\n",
|
|
" \n",
|
|
" rs = avg_gain / avg_loss\n",
|
|
" rsi = 100 - (100 / (1 + rs))\n",
|
|
" return rsi\n",
|
|
"\n",
|
|
"# Generate signals\n",
|
|
"def generate_signals(data):\n",
|
|
" logging.info(\"Generating trading signals\")\n",
|
|
" data['Signal'] = 0\n",
|
|
" data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n",
|
|
" data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n",
|
|
" \n",
|
|
" # Count signals\n",
|
|
" buy_signals = len(data[data['Signal'] == 1])\n",
|
|
" sell_signals = len(data[data['Signal'] == -1])\n",
|
|
" logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n",
|
|
" return data\n",
|
|
"\n",
|
|
"# Backtest strategy\n",
|
|
"def backtest_strategy(data):\n",
|
|
" logging.info(\"Backtesting strategy\")\n",
|
|
" data['Returns'] = data['Close'].pct_change()\n",
|
|
" data['Strategy'] = data['Signal'].shift(1) * data['Returns']\n",
|
|
" \n",
|
|
" # Replace NaN values\n",
|
|
" data['Strategy'].fillna(0, inplace=True)\n",
|
|
" \n",
|
|
" # Calculate cumulative returns\n",
|
|
" data['Cumulative'] = (1 + data['Strategy']).cumprod()\n",
|
|
" \n",
|
|
" # Calculate metrics\n",
|
|
" total_return = data['Cumulative'].iloc[-1] - 1\n",
|
|
" sharpe = np.sqrt(252) * data['Strategy'].mean() / data['Strategy'].std()\n",
|
|
" max_dd = (data['Cumulative'] / data['Cumulative'].cummax() - 1).min()\n",
|
|
" \n",
|
|
" logging.info(f\"Total Return: {total_return:.4f}\")\n",
|
|
" logging.info(f\"Sharpe Ratio: {sharpe:.4f}\")\n",
|
|
" logging.info(f\"Max Drawdown: {max_dd:.4f}\")\n",
|
|
" \n",
|
|
" return data\n",
|
|
"\n",
|
|
"# Visualize results\n",
|
|
"def visualize_results(data, ticker):\n",
|
|
" logging.info(\"Creating visualization\")\n",
|
|
" try:\n",
|
|
" fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]})\n",
|
|
" \n",
|
|
" # Price chart\n",
|
|
" ax1.plot(data.index, data['Close'], label='Close Price')\n",
|
|
" \n",
|
|
" # Add buy/sell signals\n",
|
|
" buy_points = data[data['Signal'] == 1]\n",
|
|
" sell_points = data[data['Signal'] == -1]\n",
|
|
" \n",
|
|
" # Use safe scatter to avoid \"x and y must be the same size\" error\n",
|
|
" if not buy_points.empty:\n",
|
|
" safe_scatter(ax1, buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n",
|
|
" \n",
|
|
" if not sell_points.empty:\n",
|
|
" safe_scatter(ax1, sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n",
|
|
" \n",
|
|
" ax1.set_title(f'RSI Strategy for {ticker}')\n",
|
|
" ax1.set_ylabel('Price')\n",
|
|
" ax1.legend()\n",
|
|
" ax1.grid(True)\n",
|
|
" \n",
|
|
" # RSI chart\n",
|
|
" ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n",
|
|
" ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n",
|
|
" ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n",
|
|
" ax2.set_title('RSI Indicator')\n",
|
|
" ax2.set_ylabel('RSI')\n",
|
|
" ax2.set_ylim(0, 100)\n",
|
|
" ax2.grid(True)\n",
|
|
" \n",
|
|
" plt.tight_layout()\n",
|
|
" plt.savefig('rsi_strategy.png')\n",
|
|
" logging.info(\"Plot saved to rsi_strategy.png\")\n",
|
|
" except Exception as e:\n",
|
|
" logging.error(f\"Error in visualization: {e}\")\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
|
|
" # Settings\n",
|
|
" ticker = 'AAPL'\n",
|
|
" days = 252 # One year of trading days\n",
|
|
" \n",
|
|
" # Generate data\n",
|
|
" data = generate_synthetic_data(ticker, days)\n",
|
|
" \n",
|
|
" # Calculate RSI\n",
|
|
" data['RSI'] = calculate_rsi(data)\n",
|
|
" \n",
|
|
" # Generate signals\n",
|
|
" data = generate_signals(data)\n",
|
|
" \n",
|
|
" # Backtest strategy\n",
|
|
" data = backtest_strategy(data)\n",
|
|
" \n",
|
|
" # Visualize results\n",
|
|
" visualize_results(data, ticker)\n",
|
|
" \n",
|
|
" # Print sample of data\n",
|
|
" logging.info(\"Sample of the processed data:\")\n",
|
|
" print(data[['Close', 'RSI', 'Signal', 'Strategy', 'Cumulative']].tail())\n",
|
|
" \n",
|
|
" logging.info(\"Analysis complete!\")\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"output = run_python(test_fixed_code)\n",
|
|
"print(\"Output from execution:\")\n",
|
|
"print(output)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Test the fix for indentation error in plotting code\n",
|
|
"\n",
|
|
"test_code_with_plotting = \"\"\"\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import logging\n",
|
|
"\n",
|
|
"# Configure logging\n",
|
|
"logging.basicConfig(\n",
|
|
" level=logging.INFO,\n",
|
|
" format='[%(asctime)s] %(levelname)s: %(message)s',\n",
|
|
" datefmt='%H:%M:%S'\n",
|
|
")\n",
|
|
"\n",
|
|
"# Simple plotting function\n",
|
|
"def create_plot():\n",
|
|
" # Generate some data\n",
|
|
" x = np.linspace(0, 10, 100)\n",
|
|
" y = np.sin(x)\n",
|
|
" \n",
|
|
" # Create plot\n",
|
|
" logging.info(\"Creating sine wave plot\")\n",
|
|
" plt.figure(figsize=(10, 6))\n",
|
|
" plt.plot(x, y)\n",
|
|
" plt.title('Sine Wave')\n",
|
|
" plt.xlabel('X')\n",
|
|
" plt.ylabel('Y')\n",
|
|
" plt.grid(True)\n",
|
|
" plt.savefig('sine_wave.png')\n",
|
|
" logging.info(\"Plot saved to sine_wave.png\")\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
|
|
" create_plot()\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"# Apply safety features to the code\n",
|
|
"enhanced_code = add_safety_features(test_code_with_plotting)\n",
|
|
"print(\"Code with safety features applied:\")\n",
|
|
"print(enhanced_code)\n",
|
|
"\n",
|
|
"# Run the enhanced code\n",
|
|
"output = run_python(enhanced_code)\n",
|
|
"print(\"\\nExecution output:\")\n",
|
|
"print(output)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "base",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.13.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|