Merge pull request #669 from Krabulek/week-4-contributions

Week 4 excercises: added Gemini and Python Code Documentation Assistant
This commit is contained in:
Ed Donner
2025-09-19 19:43:31 -04:00
committed by GitHub
6 changed files with 2374 additions and 0 deletions

View File

@@ -0,0 +1,828 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9",
"metadata": {},
"source": [
"# Python Code Documentation Assistant\n",
"\n",
"The requirement: use a Frontier model to add docstrings and comments to your Python code\n"
]
},
{
"cell_type": "markdown",
"id": "d4634170-c444-4326-9e68-5f87c63fa0e0",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f72dfaf-9f20-4d81-b082-018eda152c9f",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U -q \"google-genai\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import io\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"from google import genai\n",
"from google.genai import types\n",
"import anthropic\n",
"from IPython.display import Markdown, display, update_display\n",
"import gradio as gr\n",
"import subprocess"
]
},
{
"cell_type": "markdown",
"id": "f91e8b32-4c98-4210-a1e1-bfe0b1fddab7",
"metadata": {},
"source": [
"## Environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f672e1c-87e9-4865-b760-370fa605e614",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins with: {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"if anthropic_api_key:\n",
" print(f\"Anthropic API Key exists and begins with: {anthropic_api_key[:7]}\")\n",
"else:\n",
" print(\"Anthropic API Key not set\")\n",
"\n",
"if google_api_key:\n",
" print(f\"Google API Key exists and begins with: {google_api_key[:4]}\")\n",
"else:\n",
" print(\"Google API Key not set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
"metadata": {},
"outputs": [],
"source": [
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"gemini = genai.Client()\n",
"\n",
"OPENAI_MODEL = \"o4-mini\"\n",
"CLAUDE_MODEL = \"claude-3-7-sonnet-latest\"\n",
"GEMINI_MODEL = \"gemini-2.5-flash\""
]
},
{
"cell_type": "markdown",
"id": "88a18c58-40d5-4592-8dd3-d7c7b0d951aa",
"metadata": {},
"source": [
"## Prompts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6896636f-923e-4a2c-9d6c-fac07828a201",
"metadata": {},
"outputs": [],
"source": [
"system_message = \"\"\"\n",
"You are an assistant that documents Python code. \n",
"Your task: \n",
"- Add concise, clear, and informative docstrings to functions, classes, and modules. \n",
"- Add inline comments only where they improve readability or clarify intent. \n",
"- Do not modify the code logic or structure. \n",
"- Respond with Python code only. \n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
"metadata": {},
"outputs": [],
"source": [
"def user_prompt_for(python):\n",
" user_prompt = \"Add docstrings and comments to the following Python code:\\n\"\n",
" user_prompt += python\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(python):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
" ]"
]
},
{
"cell_type": "markdown",
"id": "624e5066-bcf6-490d-a790-608d2bb34184",
"metadata": {},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71e1ba8c-5b05-4726-a9f3-8d8c6257350b",
"metadata": {},
"outputs": [],
"source": [
"def write_output(python, filename_suffix):\n",
" filename = f\"annotated_{filename_suffix}.py\"\n",
" code = python.replace(\"```python\",\"\").replace(\"```\",\"\")\n",
" with open(filename, \"w\") as f:\n",
" f.write(code)\n",
" print(f\"\\nWritten code to {filename}\")\n",
" return filename"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7d2fea8-74c6-4421-8f1e-0e76d5b201b9",
"metadata": {},
"outputs": [],
"source": [
"def annotate_with_gpt(python, task_name): \n",
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" fragment = chunk.choices[0].delta.content or \"\"\n",
" reply += fragment\n",
" print(fragment, end='', flush=True)\n",
" return write_output(reply, f\"{task_name}_gpt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cd84ad8-d55c-4fe0-9eeb-1895c95c4a9d",
"metadata": {},
"outputs": [],
"source": [
"def annotate_with_claude(python, task_name):\n",
" result = claude.messages.stream(\n",
" model=CLAUDE_MODEL,\n",
" max_tokens=2000,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
" )\n",
" reply = \"\"\n",
" with result as stream:\n",
" for text in stream.text_stream:\n",
" reply += text\n",
" print(text, end=\"\", flush=True)\n",
" return write_output(reply, f\"{task_name}_claude\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8a35102-1c95-469b-8855-e85f4c9bdbdf",
"metadata": {},
"outputs": [],
"source": [
"def annotate_with_gemini(python, task_name):\n",
" reply = gemini.models.generate_content(\n",
" model=GEMINI_MODEL,\n",
" contents=user_prompt_for(python),\n",
" config=types.GenerateContentConfig(\n",
" system_instruction=system_message,\n",
" )\n",
" )\n",
"\n",
" print(reply.text)\n",
" return write_output(reply.text, f\"{task_name}_gemini\")"
]
},
{
"cell_type": "markdown",
"id": "028dcfdd-2d52-4e11-a79e-2214a97cb26d",
"metadata": {},
"source": [
"# Run the Annotator"
]
},
{
"cell_type": "markdown",
"id": "7462d9f9-6215-4fb0-9471-1d0141d33205",
"metadata": {},
"source": [
"## Pi example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1cbb778-fa57-43de-b04b-ed523f396c38",
"metadata": {},
"outputs": [],
"source": [
"pi = \"\"\"\n",
"import time\n",
"\n",
"def calculate(iterations, param1, param2):\n",
" result = 1.0\n",
" for i in range(1, iterations+1):\n",
" j = i * param1 - param2\n",
" result -= (1/j)\n",
" j = i * param1 + param2\n",
" result += (1/j)\n",
" return result\n",
"\n",
"start_time = time.time()\n",
"result = calculate(100_000_000, 4, 1) * 4\n",
"end_time = time.time()\n",
"\n",
"print(f\"Result: {result:.12f}\")\n",
"print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "105db6f9-343c-491d-8e44-3a5328b81719",
"metadata": {},
"outputs": [],
"source": [
"gpt_pi = annotate_with_gpt(pi, \"pi))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "415819d0-fc95-4f78-a6ae-5c7d6781c6a7",
"metadata": {},
"outputs": [],
"source": [
"# check if the script works\n",
"\n",
"exec(open(gpt_pi).read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "983a11fe-e24d-4c65-8269-9802c5ef3ae6",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"claude_pi = annotate_with_claude(pi, \"pi\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52f5b710-0dea-4884-8ed7-a94059d88281",
"metadata": {},
"outputs": [],
"source": [
"exec(open(claude_pi).read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01f331f2-caac-48f6-9a03-8a228ee521bc",
"metadata": {},
"outputs": [],
"source": [
"gemini_pi = annotate_with_gemini(pi, \"pi\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23529942-53fa-46ad-a5db-1f3096dd6607",
"metadata": {},
"outputs": [],
"source": [
"exec(open(gemini_pi).read())"
]
},
{
"cell_type": "markdown",
"id": "7d1eaeca-61be-4d0a-a525-dd09f52aaa0f",
"metadata": {},
"source": [
"## Hard example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3b497b3-f569-420e-b92e-fb0f49957ce0",
"metadata": {},
"outputs": [],
"source": [
"python_hard = \"\"\"# Be careful to support large number sizes\n",
"\n",
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
" value = seed\n",
" while True:\n",
" value = (a * value + c) % m\n",
" yield value\n",
" \n",
"def max_subarray_sum(n, seed, min_val, max_val):\n",
" lcg_gen = lcg(seed)\n",
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
" max_sum = float('-inf')\n",
" for i in range(n):\n",
" current_sum = 0\n",
" for j in range(i, n):\n",
" current_sum += random_numbers[j]\n",
" if current_sum > max_sum:\n",
" max_sum = current_sum\n",
" return max_sum\n",
"\n",
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
" total_sum = 0\n",
" lcg_gen = lcg(initial_seed)\n",
" for _ in range(20):\n",
" seed = next(lcg_gen)\n",
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
" return total_sum\n",
"\n",
"# Parameters\n",
"n = 10000 # Number of random numbers\n",
"initial_seed = 42 # Initial seed for the LCG\n",
"min_val = -10 # Minimum value of random numbers\n",
"max_val = 10 # Maximum value of random numbers\n",
"\n",
"# Timing the function\n",
"import time\n",
"start_time = time.time()\n",
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
"end_time = time.time()\n",
"\n",
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dab5e4bc-276c-4555-bd4c-12c699d5e899",
"metadata": {},
"outputs": [],
"source": [
"exec(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8d24ed5-2c15-4f55-80e7-13a3952b3cb8",
"metadata": {},
"outputs": [],
"source": [
"gpt_hard = annotate_with_gpt(python_hard, \"hard\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80a15259-3d51-47b8-953c-6271fbd4b6fb",
"metadata": {},
"outputs": [],
"source": [
"exec(open(gpt_hard).read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9305446-1d0c-4b51-866a-b8c1e299bf5c",
"metadata": {},
"outputs": [],
"source": [
"gemini_hard = annotate_with_gemini(python_hard, \"hard\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad6eecc8-0517-43d8-bd21-5bbdedae7a10",
"metadata": {},
"outputs": [],
"source": [
"exec(open(gemini_hard).read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ee75e72-9ecb-4edd-a74a-4d3a83c1eb79",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"claude_hard = annotate_with_claude(python_hard, \"hard\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47af1516-455f-4d1c-8a1c-2da5a38c0ba5",
"metadata": {},
"outputs": [],
"source": [
"exec(open(claude_hard).read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f60d33c-f6b7-4fc5-bc2b-57957b076e34",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This module implements a Linear Congruential Generator (LCG) and uses it\n",
"to generate random numbers for calculating the maximum subarray sum.\n",
"It includes functions for the LCG, finding the maximum subarray sum, and\n",
"aggregating results over multiple runs.\n",
"\"\"\"\n",
"\n",
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
" \"\"\"\n",
" Implements a Linear Congruential Generator (LCG) to produce a sequence of\n",
" pseudorandom numbers.\n",
"\n",
" The generator uses the formula: X_{n+1} = (a * X_n + c) % m.\n",
"\n",
" Args:\n",
" seed (int): The initial seed value for the generator (X_0).\n",
" a (int, optional): The multiplier. Defaults to 1664525 (common LCG parameter).\n",
" c (int, optional): The increment. Defaults to 1013904223 (common LCG parameter).\n",
" m (int, optional): The modulus. Defaults to 2**32, meaning numbers will be\n",
" between 0 and m-1.\n",
"\n",
" Yields:\n",
" int: The next pseudorandom number in the sequence.\n",
" \"\"\"\n",
" value = seed\n",
" while True:\n",
" # Calculate the next pseudorandom number using the LCG formula.\n",
" value = (a * value + c) % m\n",
" yield value\n",
"\n",
"def max_subarray_sum(n, seed, min_val, max_val):\n",
" \"\"\"\n",
" Calculates the maximum possible sum of a contiguous subarray within a list\n",
" of 'n' pseudorandom numbers.\n",
"\n",
" The random numbers are generated using an LCG based on the provided seed,\n",
" and then mapped to the range [min_val, max_val].\n",
" This implementation uses a brute-force approach with O(n^2) complexity.\n",
"\n",
" Args:\n",
" n (int): The number of random integers to generate for the array.\n",
" seed (int): The seed for the LCG to generate the random numbers.\n",
" min_val (int): The minimum possible value for the generated random numbers.\n",
" max_val (int): The maximum possible value for the generated random numbers.\n",
"\n",
" Returns:\n",
" int: The maximum sum found among all contiguous subarrays.\n",
" \"\"\"\n",
" lcg_gen = lcg(seed)\n",
" # Generate a list of 'n' random numbers within the specified range [min_val, max_val].\n",
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
"\n",
" max_sum = float('-inf') # Initialize max_sum to negative infinity to handle all negative numbers.\n",
"\n",
" # Iterate through all possible starting points of a subarray.\n",
" for i in range(n):\n",
" current_sum = 0\n",
" # Iterate through all possible ending points for the current starting point.\n",
" for j in range(i, n):\n",
" current_sum += random_numbers[j]\n",
" # Update max_sum if the current subarray sum is greater.\n",
" if current_sum > max_sum:\n",
" max_sum = current_sum\n",
" return max_sum\n",
"\n",
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
" \"\"\"\n",
" Calculates the sum of maximum subarray sums over 20 separate runs.\n",
"\n",
" Each run generates a new set of 'n' random numbers for `max_subarray_sum`\n",
" using a new seed derived from the initial LCG sequence.\n",
"\n",
" Args:\n",
" n (int): The number of random integers for each subarray sum calculation.\n",
" initial_seed (int): The initial seed for the LCG that generates seeds\n",
" for individual `max_subarray_sum` runs.\n",
" min_val (int): The minimum possible value for random numbers in each run.\n",
" max_val (int): The maximum possible value for random numbers in each run.\n",
"\n",
" Returns:\n",
" int: The sum of the maximum subarray sums across all 20 runs.\n",
" \"\"\"\n",
" total_sum = 0\n",
" lcg_gen = lcg(initial_seed) # LCG to generate seeds for subsequent runs.\n",
" # Perform 20 independent runs.\n",
" for _ in range(20):\n",
" # Get a new seed for each run from the initial LCG generator.\n",
" seed = next(lcg_gen)\n",
" # Add the maximum subarray sum of the current run to the total sum.\n",
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
" return total_sum\n",
"\n",
"# Parameters for the simulation\n",
"n = 10000 # Number of random numbers to generate for each subarray\n",
"initial_seed = 42 # Initial seed for the LCG that generates seeds for runs\n",
"min_val = -10 # Minimum value for the random numbers\n",
"max_val = 10 # Maximum value for the random numbers\n",
"\n",
"# Import the time module to measure execution time.\n",
"import time\n",
"\n",
"# Record the start time before executing the main function.\n",
"start_time = time.time()\n",
"# Call the function to calculate the total maximum subarray sum over multiple runs.\n",
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
"# Record the end time after the function completes.\n",
"end_time = time.time()\n",
"\n",
"# Print the final aggregated result.\n",
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
"# Print the total execution time, formatted to 6 decimal places.\n",
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))"
]
},
{
"cell_type": "markdown",
"id": "ff02ce09-0544-49a5-944d-a57b25bf9b72",
"metadata": {},
"source": [
"# Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0be9f47d-5213-4700-b0e2-d444c7c738c0",
"metadata": {},
"outputs": [],
"source": [
"def stream_gpt(python): \n",
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" fragment = chunk.choices[0].delta.content or \"\"\n",
" reply += fragment\n",
" yield reply.replace('```python\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8669f56b-8314-4582-a167-78842caea131",
"metadata": {},
"outputs": [],
"source": [
"def stream_claude(python):\n",
" result = claude.messages.stream(\n",
" model=CLAUDE_MODEL,\n",
" max_tokens=2000,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
" )\n",
" reply = \"\"\n",
" with result as stream:\n",
" for text in stream.text_stream:\n",
" reply += text\n",
" yield reply.replace('```python\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d48d44df-c082-4ed1-b3ea-fc2a880591c2",
"metadata": {},
"outputs": [],
"source": [
"def stream_gemini(python):\n",
" stream = gemini.models.generate_content_stream(\n",
" model=GEMINI_MODEL,\n",
" contents=user_prompt_for(python),\n",
" config=types.GenerateContentConfig(\n",
" system_instruction=system_message,\n",
" ),\n",
" )\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" reply += chunk.text\n",
" yield reply.replace('```python\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
"metadata": {},
"outputs": [],
"source": [
"def annotate(python, model):\n",
" if model == \"GPT\":\n",
" result = stream_gpt(python)\n",
" elif model == \"Claude\":\n",
" result = stream_claude(python)\n",
" elif model == \"Gemini\":\n",
" result = stream_gemini(python)\n",
" else:\n",
" raise ValueError(\"Unknown model\")\n",
" for stream_so_far in result:\n",
" yield stream_so_far "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19bf2bff-a822-4009-a539-f003b1651383",
"metadata": {},
"outputs": [],
"source": [
"def execute_python(code):\n",
" try:\n",
" output = io.StringIO()\n",
" sys.stdout = output\n",
" exec(code)\n",
" finally:\n",
" sys.stdout = sys.__stdout__\n",
" return output.getvalue()"
]
},
{
"cell_type": "markdown",
"id": "8391444b-b938-4f92-982f-91439b38d901",
"metadata": {},
"source": [
"# Gradio App"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a2274f1-d03b-42c0-8dcc-4ce159b18442",
"metadata": {},
"outputs": [],
"source": [
"css = \"\"\"\n",
".python {background-color: #306998;}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76167ea9-d0a1-4bc6-8d73-633d3b8c8df6",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"\n",
"# Parameters\n",
"LINES = 25\n",
"LINE_HEIGHT = 20 # px, typical CodeMirror line height\n",
"PADDING = 10 # px, top + bottom padding\n",
"\n",
"CODE_HEIGHT = LINES * LINE_HEIGHT + PADDING\n",
"\n",
"\n",
"with gr.Blocks(\n",
" theme=gr.themes.Soft(),\n",
" css=f\"\"\"\n",
"#code_input .cm-editor, #annotated_code .cm-editor {{\n",
" height: {CODE_HEIGHT}px !important;\n",
" overflow-y: auto !important;\n",
"}}\n",
"\"\"\"\n",
") as demo_v2:\n",
" gr.Markdown(\"## 🐍 Annotate Python Code with Docstrings and Comments\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=1):\n",
" gr.Markdown(\"### Python code:\")\n",
" code_input = gr.Code(\n",
" language=\"python\", \n",
" value=python_hard,\n",
" elem_id=\"code_input\"\n",
" )\n",
" \n",
" with gr.Column(scale=1):\n",
" gr.Markdown(\"### Annotated code:\")\n",
" annotated_output = gr.Code(\n",
" language=\"python\",\n",
" elem_id=\"annotated_code\",\n",
" interactive=False\n",
" )\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=1):\n",
" model_dropdown = gr.Dropdown(\n",
" choices=[\"Gemini\", \"GPT-4\", \"Claude\"],\n",
" value=\"Gemini\",\n",
" label=\"Select model\"\n",
" )\n",
" with gr.Column(scale=1):\n",
" annotate_btn = gr.Button(\"✨ Annotate code\", variant=\"primary\")\n",
" run_btn = gr.Button(\"▶️ Run Python\", variant=\"secondary\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column():\n",
" gr.Markdown(\"### Python result:\")\n",
" result_output = gr.Textbox(\n",
" lines=5, \n",
" label=\"Output\",\n",
" interactive=False\n",
" )\n",
" \n",
" annotate_btn.click(\n",
" annotate,\n",
" inputs=[code_input, model_dropdown],\n",
" outputs=[annotated_output]\n",
" )\n",
" run_btn.click(execute_python, inputs=[annotated_output], outputs=[result_output])\n",
"\n",
" \n",
"demo_v2.launch(inbrowser=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea42883b-fdba-46ed-97be-f42e3cb41f11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,113 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1c8f17b7-dc42-408f-9b21-cdcfd7dbfb78",
"metadata": {},
"source": [
"# AutoTrader Code Generator\n",
"\n",
"Gemini-driven autonomous equities trading bot code generator for simulated market APIs"
]
},
{
"cell_type": "markdown",
"id": "fcffcfd7-000f-4995-82ae-94232fef1654",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43d8c659-08d4-4a65-8c39-e42f6c458ba8",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U -q \"google-genai\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de542411-6b4d-47cd-bf84-d80562b333a5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import io\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from google import genai\n",
"from google.genai import types\n",
"from IPython.display import Markdown, display, update_display\n",
"import gradio as gr\n",
"import subprocess"
]
},
{
"cell_type": "markdown",
"id": "e9b78b19-2d47-4973-adbc-d281d8ac8224",
"metadata": {},
"source": [
"## Google API Key Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9a2e07f-9b07-4afe-8938-ba40c41701ff",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"if google_api_key:\n",
" print(f\"Google API Key exists and begins with: {google_api_key[:4]}\")\n",
"else:\n",
" print(\"Google API Key not set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4da3664d-9d73-41a2-8a71-fe9468f3955f",
"metadata": {},
"outputs": [],
"source": [
"!python ./gemini_trading_code_generator.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f7b7c74-3d77-49d3-ac9f-f3a04743946c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,34 @@
{
"name": "ExampleEquitySim",
"base_url": "https://sim.example.com/api",
"endpoints": {
"get_price": {
"path": "/market/price",
"method": "GET",
"params": ["symbol"]
},
"place_order": {
"path": "/orders",
"method": "POST",
"body": ["symbol", "side", "quantity", "order_type", "price_optional"]
},
"cancel_order": {
"path": "/orders/{order_id}/cancel",
"method": "POST"
},
"get_balance": {
"path": "/account/balance",
"method": "GET"
},
"get_positions": {
"path": "/account/positions",
"method": "GET"
}
},
"auth": {
"type": "api_key_header",
"header_name": "X-API-KEY",
"api_key_placeholder": "<SIM_API_KEY>"
},
"notes": "This simulated API uses JSON and returns ISO timestamps in UTC."
}

View File

@@ -0,0 +1,177 @@
"""
gemini_trading_code_generator.py
Usage:
- Prepare you API Specification JSON file with your simulated API details.
- Run: pip install google-genai.
- Set GOOGLE_API_KEY env var before running.
- Run: python gemini_trading_code_generator.py
- The generated bot will be saved as `generated_trading_bot.py`.
Notes:
- THIS GENERATES CODE FOR A SIMULATED ENVIRONMENT. Read and review generated code before running.
- Keep your API keys safe.
"""
import os
import json
from typing import Dict, Any
from datetime import datetime
# Gemini client import (Google GenAI SDK)
try:
from google import genai
from google.genai import types
except Exception as e:
raise RuntimeError("google-genai not installed. Run: pip install google-genai") from e
# ------------ Gemini / Prompting helpers -------------
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
# We won't fail here — the generator will raise when trying to call the client.
pass
GEMINI_MODEL = "gemini-2.5-flash"
MAX_TOKEN_COUNT = 12000
def build_prompt(api_spec: Dict[str, Any], strategy: str = "sma_crossover") -> str:
"""
Create a clear instruction for Gemini to generate the trading bot code.
Strategy choices:
- sma_crossover (default): simple moving-average crossover strategy
- random: random buy/sell (for testing)
- placeholder for the user to request others
"""
now = datetime.utcnow().isoformat() + "Z"
prompt = f"""
You are a code-writing assistant. Produce a single, self-contained Python script named `generated_trading_bot.py`
that implements a trading bot for a *simulated equities API*. The simulator has the following specification (JSON):
{json.dumps(api_spec, indent=2)}
Requirements for the generated script:
1. The script must be runnable as-is (except for inserting API keys/config). Use only stdlib + `requests` (no other external deps).
2. Implement a simple trading strategy: {strategy}. For `sma_crossover`, implement:
- Fetch historical or recent prices (you may simulate historical by sampling current price in a loop if the API doesn't return history).
- Compute short and long simple moving averages (e.g., 5-period and 20-period).
- When short SMA crosses above long SMA: submit a MARKET or SIMULATED BUY order sized to use a configurable fraction of available cash.
- When short SMA crosses below long SMA: submit a MARKET or SIMULATED SELL order to close position.
3. Use the API endpoints from the spec exactly (build URLs using base_url + path). Respect auth header scheme from spec.
4. Include robust error handling and logging (print statements acceptable).
5. Include a `--dry-run` flag that prints actions instead of placing orders.
6. Include a safe simulation mode: throttle requests, avoid rapid-fire orders, and include a configurable `min_time_between_trades_seconds`.
7. Add inline comments explaining important functions and a short README-like docstring at the top of the generated file describing how to configure and run it.
8. At the end of the generated file, add a __main__ section that demonstrates a short run (e.g., a 60-second loop) in dry-run mode.
9. Do NOT assume any third-party libraries beyond `requests`. Use dataclasses where helpful. Use typing annotations.
10. Always document any assumptions you make in a top-level comment block.
11. Keep the entire output as valid Python code only (no additional text around it).
Generate code now.
Timestamp for reproducibility: {now}
"""
return prompt.strip()
# ------------ Gemini call -------------
def generate_code_with_gemini(prompt: str, model: str = GEMINI_MODEL, max_tokens: int = MAX_TOKEN_COUNT) -> str:
"""
Call the Gemini model to generate the code.
Uses google-genai SDK. Make sure env var GOOGLE_API_KEY or GOOGLE_API_KEY is set.
"""
if not GOOGLE_API_KEY:
raise RuntimeError("No Google API key found. Set GOOGLE_API_KEY environment variable.")
# Create client (per Google Gen AI quickstart)
client = genai.Client(api_key=GOOGLE_API_KEY)
# The SDK surface has varied; using the documented 'models.generate_content' style.
# If your SDK differs, adapt accordingly.
response = client.models.generate_content(
model=model,
contents=prompt,
config=types.GenerateContentConfig(
max_output_tokens=max_tokens,
)
)
text = None
if hasattr(response, "text") and response.text:
text = response.text
else:
# attempt to dig into typical structures
try:
# some SDKs return dict-like object
if isinstance(response, dict):
# Try common keys
for k in ("text", "content", "output", "candidates"):
if k in response and response[k]:
text = json.dumps(response[k]) if not isinstance(response[k], str) else response[k]
break
else:
# object with attributes
if hasattr(response, "output") and response.output:
# navigate first candidate -> text
out = response.output
if isinstance(out, (list, tuple)) and len(out) > 0:
first = out[0]
if isinstance(first, dict) and "content" in first:
text = first["content"][0].get("text")
elif hasattr(first, "content"):
text = first.content[0].text
except Exception:
pass
if not text:
raise RuntimeError("Could not extract generated text from Gemini response. Inspect `response` object: " + repr(response))
return text
# ------------ Save & basic verification -------------
def basic_sanity_check(code_text: str) -> bool:
"""Do a quick check that the output looks like Python file and contains required sections."""
checks = [
"import requests" in code_text or "import urllib" in code_text,
"def " in code_text,
"if __name__" in code_text,
"place_order" in code_text or "order" in code_text
]
return all(checks)
def save_generated_file(code_text: str, filename: str = "generated_trading_bot.py") -> str:
code_text = code_text.replace("```python","").replace("```","")
with open(filename, "w", encoding="utf-8") as f:
f.write(code_text)
return os.path.abspath(filename)
# ------------ Main CLI -------------
def main():
import argparse
parser = argparse.ArgumentParser(description="Generate trading bot code using Gemini (Google GenAI).")
parser.add_argument("--api-spec", type=str, default="api_spec.json", help="Path to JSON file with API spec.")
parser.add_argument("--out", type=str, default="generated_trading_bot.py", help="Output filename.")
parser.add_argument("--model", type=str, default=GEMINI_MODEL, help="Gemini model to use.")
parser.add_argument("--max-tokens", type=int, default=MAX_TOKEN_COUNT, help="Max tokens for generation.")
parser.add_argument("--strategy", type=str, default="sma_crossover", help="Trading strategy to request.")
args = parser.parse_args()
with open(args.api_spec, "r", encoding="utf-8") as f:
api_spec = json.load(f)
prompt = build_prompt(api_spec, strategy=args.strategy)
print("Calling Gemini to generate code... (this will use your GOOGLE_API_KEY)")
generated = generate_code_with_gemini(prompt, model=args.model, max_tokens=args.max_tokens)
print("Performing sanity checks on the generated code...")
if not basic_sanity_check(generated):
print("Warning: basic sanity checks failed. Still saving the file for inspection.")
path = save_generated_file(generated, filename=args.out)
print(f"Generated code saved to: {path}")
print("Important: Review the generated code carefully before running against any system (even a simulator).")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,532 @@
import requests
import json
import time
import os
import collections
import argparse
import math
from dataclasses import dataclass
from typing import Dict, Any, Deque, Optional, List
# --- Assumptions ---
# 1. Historical Price Data Simulation: The simulated API's `/market/price` endpoint
# only provides the current price. To implement SMA crossover, which requires
# historical data, this bot simulates history by repeatedly calling `get_price`
# over time and storing the results. It assumes that calling `get_price` at regular
# intervals (e.g., every 5 seconds) effectively provides a time-series of prices.
# 2. API Response Formats:
# - `get_price`: Assumed to return `{"symbol": "SYM", "price": 123.45}`.
# - `get_balance`: Assumed to return `{"cash_balance": 10000.00, ...}`.
# - `get_positions`: Assumed to return a list of dictionaries, e.
# `[{"symbol": "SYM", "quantity": 10}, ...]`. If no position, an empty list or
# a list without the symbol.
# - `place_order`: Assumed to return `{"order_id": "...", "status": "accepted"}`.
# 3. Order Type: For `place_order`, `order_type` is assumed to be "MARKET" for simplicity,
# as no other types are specified and "price_optional" implies it's for limit orders.
# For MARKET orders, `price_optional` will not be sent.
# 4. Error Handling: Basic network and API-level error checking is implemented.
# More complex retry logic or backoff strategies are not included to keep the example concise.
# 5. Time Zones: The API notes specify ISO timestamps in UTC. For internal logic,
# `time.time()` (epoch seconds in UTC) is used for time comparisons, which is
# sufficient for throttling and trade timing.
# --- Configuration ---
# You can override these defaults using command-line arguments.
DEFAULT_API_KEY = os.environ.get("SIM_API_KEY", "<YOUR_SIM_API_KEY>") # Set SIM_API_KEY env var or replace
DEFAULT_BASE_URL = "https://sim.example.com/api"
DEFAULT_SYMBOL = "AAPL" # Example stock symbol
# Trading Strategy Parameters
DEFAULT_SHORT_SMA_PERIOD = 5 # Number of price points for short SMA
DEFAULT_LONG_SMA_PERIOD = 20 # Number of price points for long SMA
DEFAULT_BUY_CASH_FRACTION = 0.95 # Fraction of available cash to use for a BUY order
# Bot Operation Parameters
DEFAULT_PRICE_FETCH_INTERVAL_SECONDS = 5 # How often to fetch a new price point for SMA calculation
DEFAULT_MAIN_LOOP_INTERVAL_SECONDS = 10 # How often the bot evaluates the strategy
DEFAULT_MIN_TIME_BETWEEN_TRADES_SECONDS = 60 # Minimum time (seconds) between placing orders
DEFAULT_INITIAL_HISTORY_COLLECTION_COUNT = DEFAULT_LONG_SMA_PERIOD + 5 # Ensure enough data for long SMA
@dataclass
class TradingBotConfig:
api_key: str
base_url: str
symbol: str
short_sma_period: int
long_sma_period: int
buy_cash_fraction: float
price_fetch_interval_seconds: int
main_loop_interval_seconds: int
min_time_between_trades_seconds: int
initial_history_collection_count: int
dry_run: bool
class SimulatedAPIClient:
"""
Client for interacting with the ExampleEquitySim API.
Handles request building, authentication, and basic error parsing.
"""
def __init__(self, base_url: str, api_key: str):
self.base_url = base_url
self.headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
self.session = requests.Session() # Use a session for connection pooling
def _log(self, message: str) -> None:
"""Simple logging utility."""
print(f"[API Client] {message}")
def _make_request(
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
json_data: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""
Generic helper to make API requests.
"""
url = f"{self.base_url}{path}"
try:
response = self.session.request(
method, url, headers=self.headers, params=params, json=json_data, timeout=10
)
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
return response.json()
except requests.exceptions.HTTPError as e:
self._log(f"HTTP error for {method} {url}: {e.response.status_code} - {e.response.text}")
except requests.exceptions.ConnectionError as e:
self._log(f"Connection error for {method} {url}: {e}")
except requests.exceptions.Timeout as e:
self._log(f"Timeout error for {method} {url}: {e}")
except requests.exceptions.RequestException as e:
self._log(f"An unexpected request error occurred for {method} {url}: {e}")
except json.JSONDecodeError:
self._log(f"Failed to decode JSON from response for {method} {url}: {response.text}")
return None
def get_price(self, symbol: str) -> Optional[float]:
"""
Fetches the current market price for a given symbol.
Returns the price as a float, or None on error.
"""
path = "/market/price"
params = {"symbol": symbol}
response = self._make_request("GET", path, params=params)
if response and "price" in response:
return float(response["price"])
self._log(f"Could not get price for {symbol}.")
return None
def place_order(
self,
symbol: str,
side: str, # "BUY" or "SELL"
quantity: float,
order_type: str = "MARKET",
price_optional: Optional[float] = None # For LIMIT orders, not used for MARKET
) -> Optional[Dict[str, Any]]:
"""
Places a trading order.
"""
path = "/orders"
payload = {
"symbol": symbol,
"side": side,
"quantity": quantity,
"order_type": order_type,
}
if order_type != "MARKET" and price_optional is not None:
payload["price_optional"] = price_optional
self._log(f"Placing {side} order: {quantity} {symbol} ({order_type})...")
response = self._make_request("POST", path, json_data=payload)
if response and response.get("status") == "accepted":
self._log(f"Order placed successfully: {response.get('order_id')}")
return response
self._log(f"Failed to place {side} order for {quantity} {symbol}. Response: {response}")
return None
def cancel_order(self, order_id: str) -> Optional[Dict[str, Any]]:
"""
Cancels an existing order.
"""
path = f"/orders/{order_id}/cancel"
self._log(f"Cancelling order {order_id}...")
response = self._make_request("POST", path)
if response and response.get("status") == "cancelled":
self._log(f"Order {order_id} cancelled.")
return response
self._log(f"Failed to cancel order {order_id}. Response: {response}")
return None
def get_balance(self) -> Optional[float]:
"""
Fetches the current cash balance.
Returns cash balance as float, or None on error.
"""
path = "/account/balance"
response = self._make_request("GET", path)
if response and "cash_balance" in response:
return float(response["cash_balance"])
self._log("Could not get account balance.")
return None
def get_positions(self) -> Optional[List[Dict[str, Any]]]:
"""
Fetches all current open positions.
Returns a list of position dictionaries, or None on error.
"""
path = "/account/positions"
response = self._make_request("GET", path)
if response is not None:
# Assuming the API returns a list, even if empty
if isinstance(response, list):
return response
else:
self._log(f"Unexpected response format for get_positions: {response}")
return []
self._log("Could not get account positions.")
return None
class TradingBot:
"""
Implements the SMA Crossover trading strategy for a simulated equities API.
"""
def __init__(self, config: TradingBotConfig, api_client: SimulatedAPIClient):
self.config = config
self.api_client = api_client
# Deque for efficient rolling window of prices
self.price_history: Deque[float] = collections.deque(
maxlen=self.config.long_sma_period
)
self.last_trade_timestamp: float = 0.0
self.current_position_quantity: float = 0.0
self.previous_short_sma: Optional[float] = None
self.previous_long_sma: Optional[float] = None
self._log(f"Trading bot initialized for symbol: {self.config.symbol}")
self._log(f"Short SMA: {self.config.short_sma_period} periods, Long SMA: {self.config.long_sma_period} periods")
if self.config.dry_run:
self._log("!!! DRY RUN MODE ACTIVE - NO REAL ORDERS WILL BE PLACED !!!")
def _log(self, message: str) -> None:
"""Simple logging utility for the bot."""
print(f"[Bot] {message}")
def _fetch_and_store_price(self) -> Optional[float]:
"""
Fetches the current price from the API and adds it to the price history.
Returns the fetched price or None if failed.
"""
price = self.api_client.get_price(self.config.symbol)
if price is not None:
self.price_history.append(price)
self._log(f"Fetched price for {self.config.symbol}: {price}. History size: {len(self.price_history)}")
return price
self._log(f"Failed to fetch current price for {self.config.symbol}.")
return None
def _calculate_sma(self, period: int) -> Optional[float]:
"""
Calculates the Simple Moving Average (SMA) for a given period
using the stored price history.
"""
if len(self.price_history) < period:
return None
# Get the last 'period' prices from the deque
# Python's deque doesn't have direct slicing like list[-period:]
# So we convert to list for slicing or iterate last 'n' elements
recent_prices = list(self.price_history)[-period:]
return sum(recent_prices) / period
def _update_current_position(self) -> None:
"""
Fetches the current position for the trading symbol from the API
and updates the bot's internal state.
"""
positions = self.api_client.get_positions()
self.current_position_quantity = 0.0
if positions:
for pos in positions:
if pos.get("symbol") == self.config.symbol:
self.current_position_quantity = float(pos.get("quantity", 0))
break
self._log(f"Current position in {self.config.symbol}: {self.current_position_quantity}")
def _can_trade(self) -> bool:
"""
Checks if enough time has passed since the last trade to place a new one.
"""
time_since_last_trade = time.time() - self.last_trade_timestamp
if time_since_last_trade < self.config.min_time_between_trades_seconds:
self._log(f"Throttling: Waiting {math.ceil(self.config.min_time_between_trades_seconds - time_since_last_trade)}s before next trade.")
return False
return True
def collect_initial_history(self) -> None:
"""
Collects an initial set of price data before starting the trading strategy.
This is crucial for calculating SMAs from the start.
"""
self._log(f"Collecting initial price history ({self.config.initial_history_collection_count} points required)...")
for i in range(self.config.initial_history_collection_count):
if self._fetch_and_store_price() is None:
self._log("Failed to collect initial price. Retrying...")
# Wait before fetching next price to simulate time passing
time.sleep(self.config.price_fetch_interval_seconds)
self._log(f"Collected {i+1}/{self.config.initial_history_collection_count} prices.")
self._log("Initial price history collection complete.")
def run_strategy_iteration(self) -> None:
"""
Executes one iteration of the SMA crossover strategy.
"""
self._log("--- Running strategy iteration ---")
# 1. Fetch current position and balance
self._update_current_position()
cash_balance = self.api_client.get_balance()
if cash_balance is None:
self._log("Could not get cash balance. Skipping iteration.")
return
# 2. Fetch new price and update history
if self._fetch_and_store_price() is None:
return # Skip iteration if price fetch fails
# 3. Ensure enough data for SMAs
if len(self.price_history) < self.config.long_sma_period:
self._log(f"Not enough price history for SMAs (need {self.config.long_sma_period}, have {len(self.price_history)}). Waiting for more data.")
return
# 4. Calculate SMAs
short_sma = self._calculate_sma(self.config.short_sma_period)
long_sma = self._calculate_sma(self.config.long_sma_period)
if short_sma is None or long_sma is None:
self._log("Could not calculate SMAs. Skipping iteration.")
return
self._log(f"Current SMAs: Short={short_sma:.2f}, Long={long_sma:.2f}")
# If this is the first time we calculated SMAs, just store them and exit
if self.previous_short_sma is None or self.previous_long_sma is None:
self._log("First SMA calculation. Storing values for next iteration comparison.")
self.previous_short_sma = short_sma
self.previous_long_sma = long_sma
return
# 5. Check for crossover signals
# Buy Signal: Short SMA crosses above Long SMA
if (self.previous_short_sma < self.previous_long_sma) and (short_sma >= long_sma):
self._log("!!! BUY SIGNAL DETECTED: Short SMA crossed above Long SMA !!!")
if self.current_position_quantity > 0:
self._log(f"Already hold a position of {self.current_position_quantity} {self.config.symbol}. No new buy order.")
elif not self._can_trade():
pass # Message already logged by _can_trade()
else:
buy_amount_dollars = cash_balance * self.config.buy_cash_fraction
# Use the most recent price for calculating quantity
current_price = self.price_history[-1]
if current_price > 0:
quantity_to_buy = math.floor(buy_amount_dollars / current_price)
if quantity_to_buy > 0:
self._log(f"Attempting to BUY {quantity_to_buy} shares of {self.config.symbol} at approx ${current_price:.2f} using ${buy_amount_dollars:.2f} of cash.")
if not self.config.dry_run:
order_response = self.api_client.place_order(self.config.symbol, "BUY", quantity_to_buy)
if order_response:
self.last_trade_timestamp = time.time()
self._update_current_position() # Refresh position after order
else:
self._log(f"DRY RUN: Would have placed BUY order for {quantity_to_buy} {self.config.symbol}.")
self.last_trade_timestamp = time.time() # Still simulate trade delay
else:
self._log("Calculated quantity to buy is zero.")
else:
self._log("Current price is zero, cannot calculate buy quantity.")
# Sell Signal: Short SMA crosses below Long SMA
elif (self.previous_short_sma > self.previous_long_sma) and (short_sma <= long_sma):
self._log("!!! SELL SIGNAL DETECTED: Short SMA crossed below Long SMA !!!")
if self.current_position_quantity == 0:
self._log("No open position to sell. No new sell order.")
elif not self._can_trade():
pass # Message already logged by _can_trade()
else:
quantity_to_sell = self.current_position_quantity
self._log(f"Attempting to SELL {quantity_to_sell} shares of {self.config.symbol}.")
if not self.config.dry_run:
order_response = self.api_client.place_order(self.config.symbol, "SELL", quantity_to_sell)
if order_response:
self.last_trade_timestamp = time.time()
self._update_current_position() # Refresh position after order
else:
self._log(f"DRY RUN: Would have placed SELL order for {quantity_to_sell} {self.config.symbol}.")
self.last_trade_timestamp = time.time() # Still simulate trade delay
else:
self._log("No crossover signal detected.")
# 6. Update previous SMA values for the next iteration
self.previous_short_sma = short_sma
self.previous_long_sma = long_sma
self._log("--- Iteration complete ---")
def main():
"""
Main function to parse arguments, configure the bot, and run the trading loop.
"""
parser = argparse.ArgumentParser(
description="SMA Crossover Trading Bot for Simulated Equities API."
)
parser.add_argument(
"--api-key",
type=str,
default=DEFAULT_API_KEY,
help=f"Your API key for the simulator. Default: '{DEFAULT_API_KEY}' (or SIM_API_KEY env var)"
)
parser.add_argument(
"--base-url",
type=str,
default=DEFAULT_BASE_URL,
help=f"Base URL of the simulated API. Default: {DEFAULT_BASE_URL}"
)
parser.add_argument(
"--symbol",
type=str,
default=DEFAULT_SYMBOL,
help=f"Trading symbol (e.g., AAPL). Default: {DEFAULT_SYMBOL}"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, the bot will log trade actions instead of placing real orders."
)
parser.add_argument(
"--short-sma-period",
type=int,
default=DEFAULT_SHORT_SMA_PERIOD,
help=f"Number of periods for the short SMA. Default: {DEFAULT_SHORT_SMA_PERIOD}"
)
parser.add_argument(
"--long-sma-period",
type=int,
default=DEFAULT_LONG_SMA_PERIOD,
help=f"Number of periods for the long SMA. Default: {DEFAULT_LONG_SMA_PERIOD}"
)
parser.add_argument(
"--buy-cash-fraction",
type=float,
default=DEFAULT_BUY_CASH_FRACTION,
help=f"Fraction of available cash to use for a BUY order (e.g., 0.95). Default: {DEFAULT_BUY_CASH_FRACTION}"
)
parser.add_argument(
"--price-fetch-interval",
type=int,
default=DEFAULT_PRICE_FETCH_INTERVAL_SECONDS,
help=f"Interval in seconds to fetch new price data for SMA calculation. Default: {DEFAULT_PRICE_FETCH_INTERVAL_SECONDS}"
)
parser.add_argument(
"--main-loop-interval",
type=int,
default=DEFAULT_MAIN_LOOP_INTERVAL_SECONDS,
help=f"Interval in seconds between strategy evaluations. Default: {DEFAULT_MAIN_LOOP_INTERVAL_SECONDS}"
)
parser.add_argument(
"--min-trade-interval",
type=int,
default=DEFAULT_MIN_TIME_BETWEEN_TRADES_SECONDS,
help=f"Minimum time in seconds between placing actual orders. Default: {DEFAULT_MIN_TIME_BETWEEN_TRADES_SECONDS}"
)
parser.add_argument(
"--initial-history-count",
type=int,
default=DEFAULT_INITIAL_HISTORY_COLLECTION_COUNT,
help=f"Number of initial price points to collect before starting strategy. Default: {DEFAULT_INITIAL_HISTORY_COLLECTION_COUNT}"
)
parser.add_argument(
"--run-duration",
type=int,
default=300, # Default to 5 minutes for demonstration
help="Total duration in seconds to run the bot loop. (0 for indefinite run)."
)
args = parser.parse_args()
if args.api_key == "<YOUR_SIM_API_KEY>":
print("WARNING: API Key is not set. Please replace <YOUR_SIM_API_KEY> in the script or set SIM_API_KEY environment variable, or pass with --api-key.")
print("Exiting...")
return
config = TradingBotConfig(
api_key=args.api_key,
base_url=args.base_url,
symbol=args.symbol,
short_sma_period=args.short_sma_period,
long_sma_period=args.long_sma_period,
buy_cash_fraction=args.buy_cash_fraction,
price_fetch_interval_seconds=args.price_fetch_interval,
main_loop_interval_seconds=args.main_loop_interval,
min_time_between_trades_seconds=args.min_trade_interval,
initial_history_collection_count=args.initial_history_count,
dry_run=args.dry_run,
)
api_client = SimulatedAPIClient(config.base_url, config.api_key)
trading_bot = TradingBot(config, api_client)
# Ensure enough history for SMA calculations
if config.initial_history_collection_count < config.long_sma_period:
trading_bot._log(f"WARNING: Initial history collection count ({config.initial_history_collection_count}) is less than long SMA period ({config.long_sma_period}). Adjusting to {config.long_sma_period + 5}.")
config.initial_history_collection_count = config.long_sma_period + 5
# Collect initial price data
trading_bot.collect_initial_history()
# Main trading loop
start_time = time.time()
iteration = 0
trading_bot._log(f"Starting main trading loop for {args.run_duration} seconds (0 for indefinite)...")
try:
while True:
iteration += 1
trading_bot._log(f"\n--- Main Loop Iteration {iteration} ---")
trading_bot.run_strategy_iteration()
if args.run_duration > 0 and (time.time() - start_time) >= args.run_duration:
trading_bot._log(f"Run duration of {args.run_duration} seconds completed. Exiting.")
break
trading_bot._log(f"Sleeping for {config.main_loop_interval_seconds} seconds...")
time.sleep(config.main_loop_interval_seconds)
except KeyboardInterrupt:
trading_bot._log("Bot stopped manually by user (KeyboardInterrupt).")
except Exception as e:
trading_bot._log(f"An unexpected error occurred in the main loop: {e}")
finally:
trading_bot._log("Trading bot shutting down.")
if __name__ == "__main__":
# --- Demonstration Run ---
# To run this example:
# 1. Save this script as `generated_trading_bot.py`.
# 2. Install requests: `pip install requests`
# 3. Replace `<YOUR_SIM_API_KEY>` with an actual API key or set the SIM_API_KEY environment variable.
# 4. Run from your terminal:
# `python generated_trading_bot.py --dry-run --run-duration 60 --symbol MSFT`
# This will simulate a 60-second run for MSFT in dry-run mode,
# printing potential trades without actually executing them.
# For a longer run, change --run-duration (e.g., 3600 for 1 hour).
# Remove --dry-run to enable live trading (use with caution!).
main()

View File

@@ -0,0 +1,690 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9",
"metadata": {},
"source": [
"# Code Generator\n",
"\n",
"The requirement: use a Frontier model to generate high performance C++ code from Python code\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f72dfaf-9f20-4d81-b082-018eda152c9f",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U -q \"google-genai\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import io\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"from google import genai\n",
"from google.genai import types\n",
"import anthropic\n",
"from IPython.display import Markdown, display, update_display\n",
"import gradio as gr\n",
"import subprocess"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f672e1c-87e9-4865-b760-370fa605e614",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"if anthropic_api_key:\n",
" print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
"else:\n",
" print(\"Anthropic API Key not set\")\n",
"\n",
"if google_api_key:\n",
" print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
"else:\n",
" print(\"Google API Key not set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
"metadata": {},
"outputs": [],
"source": [
"# initialize\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"gemini = genai.Client()\n",
"\n",
"OPENAI_MODEL = \"o4-mini\"\n",
"CLAUDE_MODEL = \"claude-3-7-sonnet-latest\"\n",
"GEMINI_MODEL = \"gemini-2.5-flash\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6896636f-923e-4a2c-9d6c-fac07828a201",
"metadata": {},
"outputs": [],
"source": [
"system_message = \"You are an assistant that reimplements Python code in high performance C++ for an M1 Mac. \"\n",
"system_message += \"Respond only with C++ code; use comments sparingly and do not provide any explanation other than occasional comments. \"\n",
"system_message += \"The C++ response needs to produce an identical output in the fastest possible time.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
"metadata": {},
"outputs": [],
"source": [
"def user_prompt_for(python):\n",
" user_prompt = \"Rewrite this Python code in C++ with the fastest possible implementation that produces identical output in the least time. \"\n",
" user_prompt += \"Respond only with C++ code; do not explain your work other than a few comments. \"\n",
" user_prompt += \"Pay attention to number types to ensure no int overflows. Remember to #include all necessary C++ packages such as iomanip.\\n\\n\"\n",
" user_prompt += python\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(python):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71e1ba8c-5b05-4726-a9f3-8d8c6257350b",
"metadata": {},
"outputs": [],
"source": [
"# write to a file called optimized.cpp\n",
"\n",
"def write_output(cpp):\n",
" code = cpp.replace(\"```cpp\",\"\").replace(\"```\",\"\")\n",
" with open(\"optimized.cpp\", \"w\") as f:\n",
" f.write(code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7d2fea8-74c6-4421-8f1e-0e76d5b201b9",
"metadata": {},
"outputs": [],
"source": [
"def optimize_gpt(python): \n",
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" fragment = chunk.choices[0].delta.content or \"\"\n",
" reply += fragment\n",
" print(fragment, end='', flush=True)\n",
" write_output(reply)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cd84ad8-d55c-4fe0-9eeb-1895c95c4a9d",
"metadata": {},
"outputs": [],
"source": [
"def optimize_claude(python):\n",
" result = claude.messages.stream(\n",
" model=CLAUDE_MODEL,\n",
" max_tokens=2000,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
" )\n",
" reply = \"\"\n",
" with result as stream:\n",
" for text in stream.text_stream:\n",
" reply += text\n",
" print(text, end=\"\", flush=True)\n",
" write_output(reply)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8a35102-1c95-469b-8855-e85f4c9bdbdf",
"metadata": {},
"outputs": [],
"source": [
"def optimize_gemini(python):\n",
" reply = gemini.models.generate_content(\n",
" model=GEMINI_MODEL,\n",
" contents=user_prompt_for(python),\n",
" config=types.GenerateContentConfig(\n",
" system_instruction=system_message,\n",
" )\n",
" )\n",
"\n",
" print(reply.text)\n",
" write_output(reply.text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1cbb778-fa57-43de-b04b-ed523f396c38",
"metadata": {},
"outputs": [],
"source": [
"pi = \"\"\"\n",
"import time\n",
"\n",
"def calculate(iterations, param1, param2):\n",
" result = 1.0\n",
" for i in range(1, iterations+1):\n",
" j = i * param1 - param2\n",
" result -= (1/j)\n",
" j = i * param1 + param2\n",
" result += (1/j)\n",
" return result\n",
"\n",
"start_time = time.time()\n",
"result = calculate(100_000_000, 4, 1) * 4\n",
"end_time = time.time()\n",
"\n",
"print(f\"Result: {result:.12f}\")\n",
"print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fe1cd4b-d2c5-4303-afed-2115a3fef200",
"metadata": {},
"outputs": [],
"source": [
"exec(pi)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "105db6f9-343c-491d-8e44-3a5328b81719",
"metadata": {},
"outputs": [],
"source": [
"optimize_gpt(pi)"
]
},
{
"cell_type": "markdown",
"id": "bf8f8018-f64d-425c-a0e1-d7862aa9592d",
"metadata": {},
"source": [
"# Compiling C++ and executing\n",
"\n",
"This next cell contains the command to compile a C++ file on my M1 Mac. \n",
"It compiles the file `optimized.cpp` into an executable called `optimized` \n",
"Then it runs the program called `optimized`\n",
"\n",
"In the next lab (day4), a student has contributed a full solution that compiles to efficient code on Mac, PC and Linux!\n",
"\n",
"You can wait for this, or you can google (or ask ChatGPT!) for how to do this on your platform, then replace the lines below.\n",
"If you're not comfortable with this step, you can skip it for sure - I'll show you exactly how it performs on my Mac.\n",
"\n",
"\n",
"OR alternatively: student Sandeep K.G. points out that you can run Python and C++ code online to test it out that way. Thank you Sandeep! \n",
"> Not an exact comparison but you can still get the idea of performance difference.\n",
"> For example here: https://www.programiz.com/cpp-programming/online-compiler/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4194e40c-04ab-4940-9d64-b4ad37c5bb40",
"metadata": {},
"outputs": [],
"source": [
"# Compile C++ and run the executable\n",
"\n",
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
"!./optimized"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "983a11fe-e24d-4c65-8269-9802c5ef3ae6",
"metadata": {},
"outputs": [],
"source": [
"optimize_claude(pi)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5a766f9-3d23-4bb4-a1d4-88ec44b61ddf",
"metadata": {},
"outputs": [],
"source": [
"# Repeat for Claude - again, use the right approach for your platform\n",
"\n",
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
"!./optimized"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01f331f2-caac-48f6-9a03-8a228ee521bc",
"metadata": {},
"outputs": [],
"source": [
"optimize_gemini(pi)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ef707a4-930e-4b8b-9443-e7e4fd309c2a",
"metadata": {},
"outputs": [],
"source": [
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
"!./optimized"
]
},
{
"cell_type": "markdown",
"id": "7d1eaeca-61be-4d0a-a525-dd09f52aaa0f",
"metadata": {},
"source": [
"# Python Hard Version"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3b497b3-f569-420e-b92e-fb0f49957ce0",
"metadata": {},
"outputs": [],
"source": [
"python_hard = \"\"\"# Be careful to support large number sizes\n",
"\n",
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
" value = seed\n",
" while True:\n",
" value = (a * value + c) % m\n",
" yield value\n",
" \n",
"def max_subarray_sum(n, seed, min_val, max_val):\n",
" lcg_gen = lcg(seed)\n",
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
" max_sum = float('-inf')\n",
" for i in range(n):\n",
" current_sum = 0\n",
" for j in range(i, n):\n",
" current_sum += random_numbers[j]\n",
" if current_sum > max_sum:\n",
" max_sum = current_sum\n",
" return max_sum\n",
"\n",
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
" total_sum = 0\n",
" lcg_gen = lcg(initial_seed)\n",
" for _ in range(20):\n",
" seed = next(lcg_gen)\n",
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
" return total_sum\n",
"\n",
"# Parameters\n",
"n = 10000 # Number of random numbers\n",
"initial_seed = 42 # Initial seed for the LCG\n",
"min_val = -10 # Minimum value of random numbers\n",
"max_val = 10 # Maximum value of random numbers\n",
"\n",
"# Timing the function\n",
"import time\n",
"start_time = time.time()\n",
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
"end_time = time.time()\n",
"\n",
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dab5e4bc-276c-4555-bd4c-12c699d5e899",
"metadata": {},
"outputs": [],
"source": [
"exec(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8d24ed5-2c15-4f55-80e7-13a3952b3cb8",
"metadata": {},
"outputs": [],
"source": [
"optimize_gpt(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0b3d073-88a2-40b2-831c-6f0c345c256f",
"metadata": {},
"outputs": [],
"source": [
"# Replace this with the right C++ compile + execute command for your platform\n",
"\n",
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
"!./optimized"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9305446-1d0c-4b51-866a-b8c1e299bf5c",
"metadata": {},
"outputs": [],
"source": [
"optimize_gemini(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c181036-8193-4fdd-aef3-fc513b218d43",
"metadata": {},
"outputs": [],
"source": [
"# Replace this with the right C++ compile + execute command for your platform\n",
"\n",
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
"!./optimized"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ee75e72-9ecb-4edd-a74a-4d3a83c1eb79",
"metadata": {},
"outputs": [],
"source": [
"optimize_claude(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a4ab43c-7df2-4770-bd05-6bbc198a8c45",
"metadata": {},
"outputs": [],
"source": [
"# Replace this with the right C++ compile + execute command for your platform\n",
"\n",
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
"!./optimized"
]
},
{
"cell_type": "markdown",
"id": "ff02ce09-0544-49a5-944d-a57b25bf9b72",
"metadata": {},
"source": [
"# Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0be9f47d-5213-4700-b0e2-d444c7c738c0",
"metadata": {},
"outputs": [],
"source": [
"def stream_gpt(python): \n",
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" fragment = chunk.choices[0].delta.content or \"\"\n",
" reply += fragment\n",
" yield reply.replace('```cpp\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8669f56b-8314-4582-a167-78842caea131",
"metadata": {},
"outputs": [],
"source": [
"def stream_claude(python):\n",
" result = claude.messages.stream(\n",
" model=CLAUDE_MODEL,\n",
" max_tokens=2000,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
" )\n",
" reply = \"\"\n",
" with result as stream:\n",
" for text in stream.text_stream:\n",
" reply += text\n",
" yield reply.replace('```cpp\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d48d44df-c082-4ed1-b3ea-fc2a880591c2",
"metadata": {},
"outputs": [],
"source": [
"def stream_gemini(python):\n",
" stream = gemini.models.generate_content_stream(\n",
" model=GEMINI_MODEL,\n",
" contents=user_prompt_for(python),\n",
" config=types.GenerateContentConfig(\n",
" system_instruction=system_message,\n",
" ),\n",
" )\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" reply += chunk.text\n",
" yield reply.replace('```cpp\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
"metadata": {},
"outputs": [],
"source": [
"def optimize(python, model):\n",
" if model==\"GPT\":\n",
" result = stream_gpt(python)\n",
" elif model==\"Claude\":\n",
" result = stream_claude(python)\n",
" elif model==\"Gemini\":\n",
" result = stream_gemini(python)\n",
" else:\n",
" raise ValueError(\"Unknown model\")\n",
" for stream_so_far in result:\n",
" yield stream_so_far "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1ddb38e-6b0a-4c37-baa4-ace0b7de887a",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks() as ui:\n",
" with gr.Row():\n",
" python = gr.Textbox(label=\"Python code:\", lines=10, value=python_hard)\n",
" cpp = gr.Textbox(label=\"C++ code:\", lines=10)\n",
" with gr.Row():\n",
" model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
" convert = gr.Button(\"Convert code\")\n",
"\n",
" convert.click(optimize, inputs=[python, model], outputs=[cpp])\n",
"\n",
"ui.launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19bf2bff-a822-4009-a539-f003b1651383",
"metadata": {},
"outputs": [],
"source": [
"def execute_python(code):\n",
" try:\n",
" output = io.StringIO()\n",
" sys.stdout = output\n",
" exec(code)\n",
" finally:\n",
" sys.stdout = sys.__stdout__\n",
" return output.getvalue()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77f3ab5d-fcfb-4d3f-8728-9cacbf833ea6",
"metadata": {},
"outputs": [],
"source": [
"# M1 Mac version to compile and execute optimized C++ code:\n",
"\n",
"def execute_cpp(code):\n",
" write_output(code)\n",
" try:\n",
" compile_cmd = [\"clang++\", \"-Ofast\", \"-std=c++17\", \"-march=armv8.5-a\", \"-mtune=apple-m1\", \"-mcpu=apple-m1\", \"-o\", \"optimized\", \"optimized.cpp\"]\n",
" compile_result = subprocess.run(compile_cmd, check=True, text=True, capture_output=True)\n",
" run_cmd = [\"./optimized\"]\n",
" run_result = subprocess.run(run_cmd, check=True, text=True, capture_output=True)\n",
" return run_result.stdout\n",
" except subprocess.CalledProcessError as e:\n",
" return f\"An error occurred:\\n{e.stderr}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a2274f1-d03b-42c0-8dcc-4ce159b18442",
"metadata": {},
"outputs": [],
"source": [
"css = \"\"\"\n",
".python {background-color: #306998;}\n",
".cpp {background-color: #050;}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1303932-160c-424b-97a8-d28c816721b2",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks(css=css) as ui:\n",
" gr.Markdown(\"## Convert code from Python to C++\")\n",
" with gr.Row():\n",
" python = gr.Textbox(label=\"Python code:\", value=python_hard, lines=20)\n",
" cpp = gr.Textbox(label=\"C++ code:\", lines=20)\n",
" with gr.Row():\n",
" model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
" convert = gr.Button(\"Convert code\")\n",
" with gr.Row():\n",
" python_run = gr.Button(\"Run Python\")\n",
" cpp_run = gr.Button(\"Run C++\")\n",
" with gr.Row():\n",
" python_out = gr.TextArea(label=\"Python result:\", elem_classes=[\"python\"])\n",
" cpp_out = gr.TextArea(label=\"C++ result:\", elem_classes=[\"cpp\"])\n",
"\n",
" convert.click(optimize, inputs=[python, model], outputs=[cpp])\n",
" python_run.click(execute_python, inputs=[python], outputs=[python_out])\n",
" cpp_run.click(execute_cpp, inputs=[cpp], outputs=[cpp_out])\n",
"\n",
"ui.launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea42883b-fdba-46ed-97be-f42e3cb41f11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}