Merge pull request #669 from Krabulek/week-4-contributions
Week 4 exercises: added Gemini and Python Code Documentation Assistant
This commit is contained in:
@@ -0,0 +1,828 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Python Code Documentation Assistant\n",
|
||||
"\n",
|
||||
"The requirement: use a Frontier model to add docstrings and comments to your Python code\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d4634170-c444-4326-9e68-5f87c63fa0e0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f72dfaf-9f20-4d81-b082-018eda152c9f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U -q \"google-genai\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import io\n",
|
||||
"import sys\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from openai import OpenAI\n",
|
||||
"from google import genai\n",
|
||||
"from google.genai import types\n",
|
||||
"import anthropic\n",
|
||||
"from IPython.display import Markdown, display, update_display\n",
|
||||
"import gradio as gr\n",
|
||||
"import subprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f91e8b32-4c98-4210-a1e1-bfe0b1fddab7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4f672e1c-87e9-4865-b760-370fa605e614",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"load_dotenv(override=True)\n",
|
||||
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
||||
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
|
||||
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
|
||||
"\n",
|
||||
"if openai_api_key:\n",
|
||||
" print(f\"OpenAI API Key exists and begins with: {openai_api_key[:8]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"OpenAI API Key not set\")\n",
|
||||
" \n",
|
||||
"if anthropic_api_key:\n",
|
||||
" print(f\"Anthropic API Key exists and begins with: {anthropic_api_key[:7]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Anthropic API Key not set\")\n",
|
||||
"\n",
|
||||
"if google_api_key:\n",
|
||||
" print(f\"Google API Key exists and begins with: {google_api_key[:4]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Google API Key not set\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"openai = OpenAI()\n",
|
||||
"claude = anthropic.Anthropic()\n",
|
||||
"gemini = genai.Client()\n",
|
||||
"\n",
|
||||
"OPENAI_MODEL = \"o4-mini\"\n",
|
||||
"CLAUDE_MODEL = \"claude-3-7-sonnet-latest\"\n",
|
||||
"GEMINI_MODEL = \"gemini-2.5-flash\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "88a18c58-40d5-4592-8dd3-d7c7b0d951aa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prompts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6896636f-923e-4a2c-9d6c-fac07828a201",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"system_message = \"\"\"\n",
|
||||
"You are an assistant that documents Python code. \n",
|
||||
"Your task: \n",
|
||||
"- Add concise, clear, and informative docstrings to functions, classes, and modules. \n",
|
||||
"- Add inline comments only where they improve readability or clarify intent. \n",
|
||||
"- Do not modify the code logic or structure. \n",
|
||||
"- Respond with Python code only. \n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def user_prompt_for(python):\n",
|
||||
" user_prompt = \"Add docstrings and comments to the following Python code:\\n\"\n",
|
||||
" user_prompt += python\n",
|
||||
" return user_prompt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def messages_for(python):\n",
|
||||
" return [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
|
||||
" ]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "624e5066-bcf6-490d-a790-608d2bb34184",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Helper functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "71e1ba8c-5b05-4726-a9f3-8d8c6257350b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def write_output(python, filename_suffix):\n",
|
||||
" filename = f\"annotated_{filename_suffix}.py\"\n",
|
||||
" code = python.replace(\"```python\",\"\").replace(\"```\",\"\")\n",
|
||||
" with open(filename, \"w\") as f:\n",
|
||||
" f.write(code)\n",
|
||||
    "        print(f\"\\nWritten code to {filename}\")\n",
|
||||
" return filename"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7d2fea8-74c6-4421-8f1e-0e76d5b201b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def annotate_with_gpt(python, task_name): \n",
|
||||
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
|
||||
" reply = \"\"\n",
|
||||
" for chunk in stream:\n",
|
||||
" fragment = chunk.choices[0].delta.content or \"\"\n",
|
||||
" reply += fragment\n",
|
||||
" print(fragment, end='', flush=True)\n",
|
||||
" return write_output(reply, f\"{task_name}_gpt\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7cd84ad8-d55c-4fe0-9eeb-1895c95c4a9d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def annotate_with_claude(python, task_name):\n",
|
||||
" result = claude.messages.stream(\n",
|
||||
" model=CLAUDE_MODEL,\n",
|
||||
" max_tokens=2000,\n",
|
||||
" system=system_message,\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
|
||||
" )\n",
|
||||
" reply = \"\"\n",
|
||||
" with result as stream:\n",
|
||||
" for text in stream.text_stream:\n",
|
||||
" reply += text\n",
|
||||
" print(text, end=\"\", flush=True)\n",
|
||||
" return write_output(reply, f\"{task_name}_claude\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8a35102-1c95-469b-8855-e85f4c9bdbdf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def annotate_with_gemini(python, task_name):\n",
|
||||
" reply = gemini.models.generate_content(\n",
|
||||
" model=GEMINI_MODEL,\n",
|
||||
" contents=user_prompt_for(python),\n",
|
||||
" config=types.GenerateContentConfig(\n",
|
||||
" system_instruction=system_message,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(reply.text)\n",
|
||||
" return write_output(reply.text, f\"{task_name}_gemini\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "028dcfdd-2d52-4e11-a79e-2214a97cb26d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run the Annotator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7462d9f9-6215-4fb0-9471-1d0141d33205",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Pi example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a1cbb778-fa57-43de-b04b-ed523f396c38",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pi = \"\"\"\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"def calculate(iterations, param1, param2):\n",
|
||||
" result = 1.0\n",
|
||||
" for i in range(1, iterations+1):\n",
|
||||
" j = i * param1 - param2\n",
|
||||
" result -= (1/j)\n",
|
||||
" j = i * param1 + param2\n",
|
||||
" result += (1/j)\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
"start_time = time.time()\n",
|
||||
"result = calculate(100_000_000, 4, 1) * 4\n",
|
||||
"end_time = time.time()\n",
|
||||
"\n",
|
||||
"print(f\"Result: {result:.12f}\")\n",
|
||||
"print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "105db6f9-343c-491d-8e44-3a5328b81719",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
    "gpt_pi = annotate_with_gpt(pi, \"pi\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "415819d0-fc95-4f78-a6ae-5c7d6781c6a7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if the script works\n",
|
||||
"\n",
|
||||
"exec(open(gpt_pi).read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "983a11fe-e24d-4c65-8269-9802c5ef3ae6",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"claude_pi = annotate_with_claude(pi, \"pi\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "52f5b710-0dea-4884-8ed7-a94059d88281",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(open(claude_pi).read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "01f331f2-caac-48f6-9a03-8a228ee521bc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gemini_pi = annotate_with_gemini(pi, \"pi\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "23529942-53fa-46ad-a5db-1f3096dd6607",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(open(gemini_pi).read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7d1eaeca-61be-4d0a-a525-dd09f52aaa0f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Hard example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c3b497b3-f569-420e-b92e-fb0f49957ce0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"python_hard = \"\"\"# Be careful to support large number sizes\n",
|
||||
"\n",
|
||||
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
|
||||
" value = seed\n",
|
||||
" while True:\n",
|
||||
" value = (a * value + c) % m\n",
|
||||
" yield value\n",
|
||||
" \n",
|
||||
"def max_subarray_sum(n, seed, min_val, max_val):\n",
|
||||
" lcg_gen = lcg(seed)\n",
|
||||
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
|
||||
" max_sum = float('-inf')\n",
|
||||
" for i in range(n):\n",
|
||||
" current_sum = 0\n",
|
||||
" for j in range(i, n):\n",
|
||||
" current_sum += random_numbers[j]\n",
|
||||
" if current_sum > max_sum:\n",
|
||||
" max_sum = current_sum\n",
|
||||
" return max_sum\n",
|
||||
"\n",
|
||||
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
|
||||
" total_sum = 0\n",
|
||||
" lcg_gen = lcg(initial_seed)\n",
|
||||
" for _ in range(20):\n",
|
||||
" seed = next(lcg_gen)\n",
|
||||
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
|
||||
" return total_sum\n",
|
||||
"\n",
|
||||
"# Parameters\n",
|
||||
"n = 10000 # Number of random numbers\n",
|
||||
"initial_seed = 42 # Initial seed for the LCG\n",
|
||||
"min_val = -10 # Minimum value of random numbers\n",
|
||||
"max_val = 10 # Maximum value of random numbers\n",
|
||||
"\n",
|
||||
"# Timing the function\n",
|
||||
"import time\n",
|
||||
"start_time = time.time()\n",
|
||||
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
|
||||
"end_time = time.time()\n",
|
||||
"\n",
|
||||
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
|
||||
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dab5e4bc-276c-4555-bd4c-12c699d5e899",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(python_hard)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8d24ed5-2c15-4f55-80e7-13a3952b3cb8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gpt_hard = annotate_with_gpt(python_hard, \"hard\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80a15259-3d51-47b8-953c-6271fbd4b6fb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(open(gpt_hard).read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e9305446-1d0c-4b51-866a-b8c1e299bf5c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gemini_hard = annotate_with_gemini(python_hard, \"hard\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ad6eecc8-0517-43d8-bd21-5bbdedae7a10",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(open(gemini_hard).read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2ee75e72-9ecb-4edd-a74a-4d3a83c1eb79",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"claude_hard = annotate_with_claude(python_hard, \"hard\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47af1516-455f-4d1c-8a1c-2da5a38c0ba5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(open(claude_hard).read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7f60d33c-f6b7-4fc5-bc2b-57957b076e34",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"\n",
|
||||
"This module implements a Linear Congruential Generator (LCG) and uses it\n",
|
||||
"to generate random numbers for calculating the maximum subarray sum.\n",
|
||||
"It includes functions for the LCG, finding the maximum subarray sum, and\n",
|
||||
"aggregating results over multiple runs.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
|
||||
" \"\"\"\n",
|
||||
" Implements a Linear Congruential Generator (LCG) to produce a sequence of\n",
|
||||
" pseudorandom numbers.\n",
|
||||
"\n",
|
||||
" The generator uses the formula: X_{n+1} = (a * X_n + c) % m.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" seed (int): The initial seed value for the generator (X_0).\n",
|
||||
" a (int, optional): The multiplier. Defaults to 1664525 (common LCG parameter).\n",
|
||||
" c (int, optional): The increment. Defaults to 1013904223 (common LCG parameter).\n",
|
||||
" m (int, optional): The modulus. Defaults to 2**32, meaning numbers will be\n",
|
||||
" between 0 and m-1.\n",
|
||||
"\n",
|
||||
" Yields:\n",
|
||||
" int: The next pseudorandom number in the sequence.\n",
|
||||
" \"\"\"\n",
|
||||
" value = seed\n",
|
||||
" while True:\n",
|
||||
" # Calculate the next pseudorandom number using the LCG formula.\n",
|
||||
" value = (a * value + c) % m\n",
|
||||
" yield value\n",
|
||||
"\n",
|
||||
"def max_subarray_sum(n, seed, min_val, max_val):\n",
|
||||
" \"\"\"\n",
|
||||
" Calculates the maximum possible sum of a contiguous subarray within a list\n",
|
||||
" of 'n' pseudorandom numbers.\n",
|
||||
"\n",
|
||||
" The random numbers are generated using an LCG based on the provided seed,\n",
|
||||
" and then mapped to the range [min_val, max_val].\n",
|
||||
" This implementation uses a brute-force approach with O(n^2) complexity.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" n (int): The number of random integers to generate for the array.\n",
|
||||
" seed (int): The seed for the LCG to generate the random numbers.\n",
|
||||
" min_val (int): The minimum possible value for the generated random numbers.\n",
|
||||
" max_val (int): The maximum possible value for the generated random numbers.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" int: The maximum sum found among all contiguous subarrays.\n",
|
||||
" \"\"\"\n",
|
||||
" lcg_gen = lcg(seed)\n",
|
||||
" # Generate a list of 'n' random numbers within the specified range [min_val, max_val].\n",
|
||||
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
|
||||
"\n",
|
||||
" max_sum = float('-inf') # Initialize max_sum to negative infinity to handle all negative numbers.\n",
|
||||
"\n",
|
||||
" # Iterate through all possible starting points of a subarray.\n",
|
||||
" for i in range(n):\n",
|
||||
" current_sum = 0\n",
|
||||
" # Iterate through all possible ending points for the current starting point.\n",
|
||||
" for j in range(i, n):\n",
|
||||
" current_sum += random_numbers[j]\n",
|
||||
" # Update max_sum if the current subarray sum is greater.\n",
|
||||
" if current_sum > max_sum:\n",
|
||||
" max_sum = current_sum\n",
|
||||
" return max_sum\n",
|
||||
"\n",
|
||||
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
|
||||
" \"\"\"\n",
|
||||
" Calculates the sum of maximum subarray sums over 20 separate runs.\n",
|
||||
"\n",
|
||||
" Each run generates a new set of 'n' random numbers for `max_subarray_sum`\n",
|
||||
" using a new seed derived from the initial LCG sequence.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" n (int): The number of random integers for each subarray sum calculation.\n",
|
||||
" initial_seed (int): The initial seed for the LCG that generates seeds\n",
|
||||
" for individual `max_subarray_sum` runs.\n",
|
||||
" min_val (int): The minimum possible value for random numbers in each run.\n",
|
||||
" max_val (int): The maximum possible value for random numbers in each run.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" int: The sum of the maximum subarray sums across all 20 runs.\n",
|
||||
" \"\"\"\n",
|
||||
" total_sum = 0\n",
|
||||
" lcg_gen = lcg(initial_seed) # LCG to generate seeds for subsequent runs.\n",
|
||||
" # Perform 20 independent runs.\n",
|
||||
" for _ in range(20):\n",
|
||||
" # Get a new seed for each run from the initial LCG generator.\n",
|
||||
" seed = next(lcg_gen)\n",
|
||||
" # Add the maximum subarray sum of the current run to the total sum.\n",
|
||||
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
|
||||
" return total_sum\n",
|
||||
"\n",
|
||||
"# Parameters for the simulation\n",
|
||||
"n = 10000 # Number of random numbers to generate for each subarray\n",
|
||||
"initial_seed = 42 # Initial seed for the LCG that generates seeds for runs\n",
|
||||
"min_val = -10 # Minimum value for the random numbers\n",
|
||||
"max_val = 10 # Maximum value for the random numbers\n",
|
||||
"\n",
|
||||
"# Import the time module to measure execution time.\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# Record the start time before executing the main function.\n",
|
||||
"start_time = time.time()\n",
|
||||
"# Call the function to calculate the total maximum subarray sum over multiple runs.\n",
|
||||
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
|
||||
"# Record the end time after the function completes.\n",
|
||||
"end_time = time.time()\n",
|
||||
"\n",
|
||||
"# Print the final aggregated result.\n",
|
||||
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
|
||||
"# Print the total execution time, formatted to 6 decimal places.\n",
|
||||
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ff02ce09-0544-49a5-944d-a57b25bf9b72",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Streaming"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0be9f47d-5213-4700-b0e2-d444c7c738c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_gpt(python): \n",
|
||||
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
|
||||
" reply = \"\"\n",
|
||||
" for chunk in stream:\n",
|
||||
" fragment = chunk.choices[0].delta.content or \"\"\n",
|
||||
" reply += fragment\n",
|
||||
" yield reply.replace('```python\\n','').replace('```','')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8669f56b-8314-4582-a167-78842caea131",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_claude(python):\n",
|
||||
" result = claude.messages.stream(\n",
|
||||
" model=CLAUDE_MODEL,\n",
|
||||
" max_tokens=2000,\n",
|
||||
" system=system_message,\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
|
||||
" )\n",
|
||||
" reply = \"\"\n",
|
||||
" with result as stream:\n",
|
||||
" for text in stream.text_stream:\n",
|
||||
" reply += text\n",
|
||||
" yield reply.replace('```python\\n','').replace('```','')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d48d44df-c082-4ed1-b3ea-fc2a880591c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_gemini(python):\n",
|
||||
" stream = gemini.models.generate_content_stream(\n",
|
||||
" model=GEMINI_MODEL,\n",
|
||||
" contents=user_prompt_for(python),\n",
|
||||
" config=types.GenerateContentConfig(\n",
|
||||
" system_instruction=system_message,\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
" reply = \"\"\n",
|
||||
" for chunk in stream:\n",
|
||||
" reply += chunk.text\n",
|
||||
" yield reply.replace('```python\\n','').replace('```','')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def annotate(python, model):\n",
|
||||
" if model == \"GPT\":\n",
|
||||
" result = stream_gpt(python)\n",
|
||||
" elif model == \"Claude\":\n",
|
||||
" result = stream_claude(python)\n",
|
||||
" elif model == \"Gemini\":\n",
|
||||
" result = stream_gemini(python)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Unknown model\")\n",
|
||||
" for stream_so_far in result:\n",
|
||||
" yield stream_so_far "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "19bf2bff-a822-4009-a539-f003b1651383",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def execute_python(code):\n",
|
||||
" try:\n",
|
||||
" output = io.StringIO()\n",
|
||||
" sys.stdout = output\n",
|
||||
" exec(code)\n",
|
||||
" finally:\n",
|
||||
" sys.stdout = sys.__stdout__\n",
|
||||
" return output.getvalue()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8391444b-b938-4f92-982f-91439b38d901",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gradio App"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9a2274f1-d03b-42c0-8dcc-4ce159b18442",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"css = \"\"\"\n",
|
||||
".python {background-color: #306998;}\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76167ea9-d0a1-4bc6-8d73-633d3b8c8df6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gradio as gr\n",
|
||||
"\n",
|
||||
"# Parameters\n",
|
||||
"LINES = 25\n",
|
||||
"LINE_HEIGHT = 20 # px, typical CodeMirror line height\n",
|
||||
"PADDING = 10 # px, top + bottom padding\n",
|
||||
"\n",
|
||||
"CODE_HEIGHT = LINES * LINE_HEIGHT + PADDING\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"with gr.Blocks(\n",
|
||||
" theme=gr.themes.Soft(),\n",
|
||||
" css=f\"\"\"\n",
|
||||
"#code_input .cm-editor, #annotated_code .cm-editor {{\n",
|
||||
" height: {CODE_HEIGHT}px !important;\n",
|
||||
" overflow-y: auto !important;\n",
|
||||
"}}\n",
|
||||
"\"\"\"\n",
|
||||
") as demo_v2:\n",
|
||||
" gr.Markdown(\"## 🐍 Annotate Python Code with Docstrings and Comments\")\n",
|
||||
"\n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" gr.Markdown(\"### Python code:\")\n",
|
||||
" code_input = gr.Code(\n",
|
||||
" language=\"python\", \n",
|
||||
" value=python_hard,\n",
|
||||
" elem_id=\"code_input\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" gr.Markdown(\"### Annotated code:\")\n",
|
||||
" annotated_output = gr.Code(\n",
|
||||
" language=\"python\",\n",
|
||||
" elem_id=\"annotated_code\",\n",
|
||||
" interactive=False\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" model_dropdown = gr.Dropdown(\n",
|
||||
    "            choices=[\"Gemini\", \"GPT\", \"Claude\"],\n",
|
||||
" value=\"Gemini\",\n",
|
||||
" label=\"Select model\"\n",
|
||||
" )\n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" annotate_btn = gr.Button(\"✨ Annotate code\", variant=\"primary\")\n",
|
||||
" run_btn = gr.Button(\"▶️ Run Python\", variant=\"secondary\")\n",
|
||||
"\n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column():\n",
|
||||
" gr.Markdown(\"### Python result:\")\n",
|
||||
" result_output = gr.Textbox(\n",
|
||||
" lines=5, \n",
|
||||
" label=\"Output\",\n",
|
||||
" interactive=False\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" annotate_btn.click(\n",
|
||||
" annotate,\n",
|
||||
" inputs=[code_input, model_dropdown],\n",
|
||||
" outputs=[annotated_output]\n",
|
||||
" )\n",
|
||||
" run_btn.click(execute_python, inputs=[annotated_output], outputs=[result_output])\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"demo_v2.launch(inbrowser=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ea42883b-fdba-46ed-97be-f42e3cb41f11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c8f17b7-dc42-408f-9b21-cdcfd7dbfb78",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AutoTrader Code Generator\n",
|
||||
"\n",
|
||||
"Gemini-driven autonomous equities trading bot code generator for simulated market APIs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fcffcfd7-000f-4995-82ae-94232fef1654",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "43d8c659-08d4-4a65-8c39-e42f6c458ba8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U -q \"google-genai\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de542411-6b4d-47cd-bf84-d80562b333a5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import io\n",
|
||||
"import sys\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from google import genai\n",
|
||||
"from google.genai import types\n",
|
||||
"from IPython.display import Markdown, display, update_display\n",
|
||||
"import gradio as gr\n",
|
||||
"import subprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e9b78b19-2d47-4973-adbc-d281d8ac8224",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Google API Key Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d9a2e07f-9b07-4afe-8938-ba40c41701ff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"load_dotenv(override=True)\n",
|
||||
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
|
||||
"\n",
|
||||
"if google_api_key:\n",
|
||||
" print(f\"Google API Key exists and begins with: {google_api_key[:4]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Google API Key not set\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4da3664d-9d73-41a2-8a71-fe9468f3955f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!python ./gemini_trading_code_generator.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f7b7c74-3d77-49d3-ac9f-f3a04743946c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"name": "ExampleEquitySim",
|
||||
"base_url": "https://sim.example.com/api",
|
||||
"endpoints": {
|
||||
"get_price": {
|
||||
"path": "/market/price",
|
||||
"method": "GET",
|
||||
"params": ["symbol"]
|
||||
},
|
||||
"place_order": {
|
||||
"path": "/orders",
|
||||
"method": "POST",
|
||||
"body": ["symbol", "side", "quantity", "order_type", "price_optional"]
|
||||
},
|
||||
"cancel_order": {
|
||||
"path": "/orders/{order_id}/cancel",
|
||||
"method": "POST"
|
||||
},
|
||||
"get_balance": {
|
||||
"path": "/account/balance",
|
||||
"method": "GET"
|
||||
},
|
||||
"get_positions": {
|
||||
"path": "/account/positions",
|
||||
"method": "GET"
|
||||
}
|
||||
},
|
||||
"auth": {
|
||||
"type": "api_key_header",
|
||||
"header_name": "X-API-KEY",
|
||||
"api_key_placeholder": "<SIM_API_KEY>"
|
||||
},
|
||||
"notes": "This simulated API uses JSON and returns ISO timestamps in UTC."
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
gemini_trading_code_generator.py
|
||||
|
||||
Usage:
|
||||
- Prepare your API Specification JSON file with your simulated API details.
|
||||
- Run: pip install google-genai.
|
||||
- Set GOOGLE_API_KEY env var before running.
|
||||
- Run: python gemini_trading_code_generator.py
|
||||
- The generated bot will be saved as `generated_trading_bot.py`.
|
||||
|
||||
Notes:
|
||||
- THIS GENERATES CODE FOR A SIMULATED ENVIRONMENT. Read and review generated code before running.
|
||||
- Keep your API keys safe.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# Gemini client import (Google GenAI SDK)
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except Exception as e:
|
||||
raise RuntimeError("google-genai not installed. Run: pip install google-genai") from e
|
||||
|
||||
|
||||
# ------------ Gemini / Prompting helpers -------------
# The key is read once at import time; generate_code_with_gemini() re-checks it
# and raises a RuntimeError if it is missing, so import never fails here.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # We won't fail here — the generator will raise when trying to call the client.
    pass

GEMINI_MODEL = "gemini-2.5-flash"  # default model; overridable via --model
MAX_TOKEN_COUNT = 12000  # default generation budget; overridable via --max-tokens
|
||||
|
||||
|
||||
def build_prompt(api_spec: Dict[str, Any], strategy: str = "sma_crossover") -> str:
    """
    Create a clear instruction for Gemini to generate the trading bot code.

    Args:
        api_spec: Parsed JSON specification of the simulated trading API
            (endpoints, auth scheme, base URL).
        strategy: Strategy the generated bot should implement.
            - sma_crossover (default): simple moving-average crossover strategy
            - random: random buy/sell (for testing)
            - placeholder for the user to request others

    Returns:
        The prompt text, stripped of leading/trailing whitespace.
    """
    # datetime.utcnow() is deprecated since Python 3.12; build an aware UTC
    # timestamp instead and normalise the "+00:00" offset to the conventional
    # trailing "Z" the original format produced.
    from datetime import timezone

    now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    prompt = f"""
You are a code-writing assistant. Produce a single, self-contained Python script named `generated_trading_bot.py`
that implements a trading bot for a *simulated equities API*. The simulator has the following specification (JSON):
{json.dumps(api_spec, indent=2)}

Requirements for the generated script:
1. The script must be runnable as-is (except for inserting API keys/config). Use only stdlib + `requests` (no other external deps).
2. Implement a simple trading strategy: {strategy}. For `sma_crossover`, implement:
   - Fetch historical or recent prices (you may simulate historical by sampling current price in a loop if the API doesn't return history).
   - Compute short and long simple moving averages (e.g., 5-period and 20-period).
   - When short SMA crosses above long SMA: submit a MARKET or SIMULATED BUY order sized to use a configurable fraction of available cash.
   - When short SMA crosses below long SMA: submit a MARKET or SIMULATED SELL order to close position.
3. Use the API endpoints from the spec exactly (build URLs using base_url + path). Respect auth header scheme from spec.
4. Include robust error handling and logging (print statements acceptable).
5. Include a `--dry-run` flag that prints actions instead of placing orders.
6. Include a safe simulation mode: throttle requests, avoid rapid-fire orders, and include a configurable `min_time_between_trades_seconds`.
7. Add inline comments explaining important functions and a short README-like docstring at the top of the generated file describing how to configure and run it.
8. At the end of the generated file, add a __main__ section that demonstrates a short run (e.g., a 60-second loop) in dry-run mode.
9. Do NOT assume any third-party libraries beyond `requests`. Use dataclasses where helpful. Use typing annotations.
10. Always document any assumptions you make in a top-level comment block.
11. Keep the entire output as valid Python code only (no additional text around it).

Generate code now.
Timestamp for reproducibility: {now}
"""
    return prompt.strip()
|
||||
|
||||
# ------------ Gemini call -------------
|
||||
def generate_code_with_gemini(prompt: str, model: str = GEMINI_MODEL, max_tokens: int = MAX_TOKEN_COUNT) -> str:
    """
    Call the Gemini model to generate the code.

    Uses the google-genai SDK. Make sure the GOOGLE_API_KEY env var is set.

    Args:
        prompt: Full instruction text produced by build_prompt().
        model: Gemini model identifier.
        max_tokens: Cap on generated output tokens.

    Returns:
        The generated text extracted from the SDK response.

    Raises:
        RuntimeError: If no API key is configured, or if no text could be
            extracted from the response object.
    """
    if not GOOGLE_API_KEY:
        raise RuntimeError("No Google API key found. Set GOOGLE_API_KEY environment variable.")

    # Create client (per Google Gen AI quickstart)
    client = genai.Client(api_key=GOOGLE_API_KEY)

    # The SDK surface has varied; using the documented 'models.generate_content' style.
    # If your SDK differs, adapt accordingly.
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(
            max_output_tokens=max_tokens,
        )
    )

    # Best-effort extraction: prefer the documented `.text` convenience property,
    # then fall back to probing dict-like and attribute-style response shapes
    # that older/other SDK versions have used.
    text = None
    if hasattr(response, "text") and response.text:
        text = response.text
    else:
        # attempt to dig into typical structures
        try:
            # some SDKs return dict-like object
            if isinstance(response, dict):
                # Try common keys; non-string values are serialised so the
                # caller still gets something inspectable.
                for k in ("text", "content", "output", "candidates"):
                    if k in response and response[k]:
                        text = json.dumps(response[k]) if not isinstance(response[k], str) else response[k]
                        break
            else:
                # object with attributes
                if hasattr(response, "output") and response.output:
                    # navigate first candidate -> text
                    out = response.output
                    if isinstance(out, (list, tuple)) and len(out) > 0:
                        first = out[0]
                        if isinstance(first, dict) and "content" in first:
                            text = first["content"][0].get("text")
                        elif hasattr(first, "content"):
                            text = first.content[0].text
        except Exception:
            # Deliberate best-effort: any probing failure falls through to the
            # explicit RuntimeError below, which includes the raw response.
            pass

    if not text:
        raise RuntimeError("Could not extract generated text from Gemini response. Inspect `response` object: " + repr(response))

    return text
|
||||
|
||||
# ------------ Save & basic verification -------------
|
||||
def basic_sanity_check(code_text: str) -> bool:
    """Quickly verify that generated output looks like a runnable Python script.

    Heuristic only: checks for an HTTP-library import, at least one function
    definition, a __main__ guard, and some order-related code. It does not
    parse or execute the text.
    """
    has_http_import = "import requests" in code_text or "import urllib" in code_text
    has_function_def = "def " in code_text
    has_main_guard = "if __name__" in code_text
    mentions_orders = "place_order" in code_text or "order" in code_text
    return has_http_import and has_function_def and has_main_guard and mentions_orders
|
||||
|
||||
def save_generated_file(code_text: str, filename: str = "generated_trading_bot.py") -> str:
    """
    Strip surrounding Markdown code fences from `code_text` and write it to `filename`.

    Only the leading/trailing fence lines are removed. The previous
    implementation removed every occurrence of "```" anywhere in the text,
    which would corrupt generated code that legitimately contains triple
    backticks (e.g. inside string literals or docstrings).

    Args:
        code_text: Raw model output, possibly wrapped in ```python ... ``` fences.
        filename: Destination path for the cleaned Python source.

    Returns:
        Absolute path of the written file.
    """
    lines = code_text.strip().splitlines()
    # Drop an opening fence line such as "```" or "```python".
    if lines and lines[0].lstrip().startswith("```"):
        lines = lines[1:]
    # Drop a closing fence line.
    if lines and lines[-1].strip() == "```":
        lines = lines[:-1]
    cleaned = "\n".join(lines) + "\n"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(cleaned)
    return os.path.abspath(filename)
|
||||
|
||||
# ------------ Main CLI -------------
|
||||
def main():
    """CLI entry point: read the API spec, ask Gemini for a bot, and save it."""
    import argparse

    parser = argparse.ArgumentParser(description="Generate trading bot code using Gemini (Google GenAI).")
    # Register options data-driven so the flag list is easy to scan and extend.
    cli_options = (
        ("--api-spec", dict(type=str, default="api_spec.json", help="Path to JSON file with API spec.")),
        ("--out", dict(type=str, default="generated_trading_bot.py", help="Output filename.")),
        ("--model", dict(type=str, default=GEMINI_MODEL, help="Gemini model to use.")),
        ("--max-tokens", dict(type=int, default=MAX_TOKEN_COUNT, help="Max tokens for generation.")),
        ("--strategy", dict(type=str, default="sma_crossover", help="Trading strategy to request.")),
    )
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)
    args = parser.parse_args()

    with open(args.api_spec, "r", encoding="utf-8") as spec_file:
        spec = json.load(spec_file)

    print("Calling Gemini to generate code... (this will use your GOOGLE_API_KEY)")
    generated_code = generate_code_with_gemini(
        build_prompt(spec, strategy=args.strategy),
        model=args.model,
        max_tokens=args.max_tokens,
    )

    print("Performing sanity checks on the generated code...")
    if not basic_sanity_check(generated_code):
        print("Warning: basic sanity checks failed. Still saving the file for inspection.")

    saved_path = save_generated_file(generated_code, filename=args.out)
    print(f"Generated code saved to: {saved_path}")
    print("Important: Review the generated code carefully before running against any system (even a simulator).")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,532 @@
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import collections
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Deque, Optional, List
|
||||
|
||||
# --- Assumptions ---
|
||||
# 1. Historical Price Data Simulation: The simulated API's `/market/price` endpoint
|
||||
# only provides the current price. To implement SMA crossover, which requires
|
||||
# historical data, this bot simulates history by repeatedly calling `get_price`
|
||||
# over time and storing the results. It assumes that calling `get_price` at regular
|
||||
# intervals (e.g., every 5 seconds) effectively provides a time-series of prices.
|
||||
# 2. API Response Formats:
|
||||
# - `get_price`: Assumed to return `{"symbol": "SYM", "price": 123.45}`.
|
||||
# - `get_balance`: Assumed to return `{"cash_balance": 10000.00, ...}`.
|
||||
#   - `get_positions`: Assumed to return a list of dictionaries, e.g.
|
||||
# `[{"symbol": "SYM", "quantity": 10}, ...]`. If no position, an empty list or
|
||||
# a list without the symbol.
|
||||
# - `place_order`: Assumed to return `{"order_id": "...", "status": "accepted"}`.
|
||||
# 3. Order Type: For `place_order`, `order_type` is assumed to be "MARKET" for simplicity,
|
||||
# as no other types are specified and "price_optional" implies it's for limit orders.
|
||||
# For MARKET orders, `price_optional` will not be sent.
|
||||
# 4. Error Handling: Basic network and API-level error checking is implemented.
|
||||
# More complex retry logic or backoff strategies are not included to keep the example concise.
|
||||
# 5. Time Zones: The API notes specify ISO timestamps in UTC. For internal logic,
|
||||
# `time.time()` (epoch seconds in UTC) is used for time comparisons, which is
|
||||
# sufficient for throttling and trade timing.
|
||||
|
||||
# --- Configuration ---
# You can override these defaults using command-line arguments (see main()).
DEFAULT_API_KEY = os.environ.get("SIM_API_KEY", "<YOUR_SIM_API_KEY>")  # Set SIM_API_KEY env var or replace the placeholder
DEFAULT_BASE_URL = "https://sim.example.com/api"  # Simulator root; endpoint paths are appended to this
DEFAULT_SYMBOL = "AAPL"  # Example stock symbol

# Trading Strategy Parameters
DEFAULT_SHORT_SMA_PERIOD = 5  # Number of price points for short SMA
DEFAULT_LONG_SMA_PERIOD = 20  # Number of price points for long SMA (also bounds the price-history deque)
DEFAULT_BUY_CASH_FRACTION = 0.95  # Fraction of available cash to use for a BUY order

# Bot Operation Parameters
DEFAULT_PRICE_FETCH_INTERVAL_SECONDS = 5  # How often to fetch a new price point for SMA calculation
DEFAULT_MAIN_LOOP_INTERVAL_SECONDS = 10  # How often the bot evaluates the strategy
DEFAULT_MIN_TIME_BETWEEN_TRADES_SECONDS = 60  # Minimum time (seconds) between placing orders
DEFAULT_INITIAL_HISTORY_COLLECTION_COUNT = DEFAULT_LONG_SMA_PERIOD + 5  # Ensure enough data for long SMA before trading starts
|
||||
|
||||
|
||||
@dataclass
class TradingBotConfig:
    """Runtime configuration for the trading bot; populated from CLI args in main()."""

    api_key: str  # sent as the X-API-KEY header by SimulatedAPIClient
    base_url: str  # simulator root, e.g. "https://sim.example.com/api"
    symbol: str  # ticker symbol the bot trades
    short_sma_period: int  # window (in fetched price points) for the short SMA
    long_sma_period: int  # window for the long SMA; also caps the price-history deque
    buy_cash_fraction: float  # fraction of cash balance committed on a BUY signal
    price_fetch_interval_seconds: int  # delay between price samples during history collection
    main_loop_interval_seconds: int  # delay between strategy evaluations
    min_time_between_trades_seconds: int  # throttle between placed orders
    initial_history_collection_count: int  # price points gathered before trading starts
    dry_run: bool  # when True, trade actions are logged but never sent to the API
|
||||
|
||||
|
||||
class SimulatedAPIClient:
    """
    Client for interacting with the ExampleEquitySim API.
    Handles request building, authentication, and basic error parsing.

    All methods return None on any transport or API-level failure; callers
    are expected to check for None rather than catch exceptions.
    """

    def __init__(self, base_url: str, api_key: str):
        # base_url has endpoint paths appended verbatim, so it should not end
        # with a trailing slash (the spec paths start with "/").
        self.base_url = base_url
        self.headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
        self.session = requests.Session()  # Use a session for connection pooling

    def _log(self, message: str) -> None:
        """Simple logging utility."""
        print(f"[API Client] {message}")

    def _make_request(
        self,
        method: str,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        json_data: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Generic helper to make API requests.

        Sends `method` to base_url + path with the auth headers, a 10-second
        timeout, optional query params, and an optional JSON body. Returns the
        decoded JSON on success; logs and returns None on any failure.
        """
        url = f"{self.base_url}{path}"
        try:
            response = self.session.request(
                method, url, headers=self.headers, params=params, json=json_data, timeout=10
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
            return response.json()
        except requests.exceptions.HTTPError as e:
            self._log(f"HTTP error for {method} {url}: {e.response.status_code} - {e.response.text}")
        except requests.exceptions.ConnectionError as e:
            self._log(f"Connection error for {method} {url}: {e}")
        except requests.exceptions.Timeout as e:
            self._log(f"Timeout error for {method} {url}: {e}")
        except requests.exceptions.RequestException as e:
            self._log(f"An unexpected request error occurred for {method} {url}: {e}")
        except json.JSONDecodeError:
            # Only reachable from response.json(), so `response` is bound here.
            # NOTE(review): older requests versions raise ValueError instead of
            # a json.JSONDecodeError subclass — confirm against pinned version.
            self._log(f"Failed to decode JSON from response for {method} {url}: {response.text}")
        return None

    def get_price(self, symbol: str) -> Optional[float]:
        """
        Fetches the current market price for a given symbol.
        Returns the price as a float, or None on error.

        Assumes the response contains a top-level "price" field (see the
        Assumptions block at the top of this file).
        """
        path = "/market/price"
        params = {"symbol": symbol}
        response = self._make_request("GET", path, params=params)
        if response and "price" in response:
            return float(response["price"])
        self._log(f"Could not get price for {symbol}.")
        return None

    def place_order(
        self,
        symbol: str,
        side: str,  # "BUY" or "SELL"
        quantity: float,
        order_type: str = "MARKET",
        price_optional: Optional[float] = None  # For LIMIT orders, not used for MARKET
    ) -> Optional[Dict[str, Any]]:
        """
        Places a trading order.

        Returns the order response dict when the simulator reports
        status == "accepted"; otherwise logs the failure and returns None.
        `price_optional` is only included in the payload for non-MARKET orders.
        """
        path = "/orders"
        payload = {
            "symbol": symbol,
            "side": side,
            "quantity": quantity,
            "order_type": order_type,
        }
        if order_type != "MARKET" and price_optional is not None:
            payload["price_optional"] = price_optional

        self._log(f"Placing {side} order: {quantity} {symbol} ({order_type})...")
        response = self._make_request("POST", path, json_data=payload)
        if response and response.get("status") == "accepted":
            self._log(f"Order placed successfully: {response.get('order_id')}")
            return response
        self._log(f"Failed to place {side} order for {quantity} {symbol}. Response: {response}")
        return None

    def cancel_order(self, order_id: str) -> Optional[Dict[str, Any]]:
        """
        Cancels an existing order.

        Returns the response dict when the simulator reports
        status == "cancelled"; otherwise None.
        """
        path = f"/orders/{order_id}/cancel"
        self._log(f"Cancelling order {order_id}...")
        response = self._make_request("POST", path)
        if response and response.get("status") == "cancelled":
            self._log(f"Order {order_id} cancelled.")
            return response
        self._log(f"Failed to cancel order {order_id}. Response: {response}")
        return None

    def get_balance(self) -> Optional[float]:
        """
        Fetches the current cash balance.
        Returns cash balance as float, or None on error.

        Assumes the response contains a top-level "cash_balance" field.
        """
        path = "/account/balance"
        response = self._make_request("GET", path)
        if response and "cash_balance" in response:
            return float(response["cash_balance"])
        self._log("Could not get account balance.")
        return None

    def get_positions(self) -> Optional[List[Dict[str, Any]]]:
        """
        Fetches all current open positions.
        Returns a list of position dictionaries, or None on error.

        An unexpected (non-list) but successful response is logged and
        treated as "no positions" (empty list) rather than an error.
        """
        path = "/account/positions"
        response = self._make_request("GET", path)
        if response is not None:
            # Assuming the API returns a list, even if empty
            if isinstance(response, list):
                return response
            else:
                self._log(f"Unexpected response format for get_positions: {response}")
                return []
        self._log("Could not get account positions.")
        return None
|
||||
|
||||
|
||||
class TradingBot:
    """
    Implements the SMA Crossover trading strategy for a simulated equities API.

    State kept across iterations: a bounded rolling price history, the
    previous short/long SMA values (for crossover detection), the current
    position size, and the timestamp of the last placed trade (for throttling).
    """

    def __init__(self, config: TradingBotConfig, api_client: SimulatedAPIClient):
        self.config = config
        self.api_client = api_client
        # Deque for efficient rolling window of prices.
        # maxlen == long_sma_period, so the history never holds more points
        # than the longest SMA needs.
        self.price_history: Deque[float] = collections.deque(
            maxlen=self.config.long_sma_period
        )
        self.last_trade_timestamp: float = 0.0  # epoch seconds; 0.0 means "never traded"
        self.current_position_quantity: float = 0.0  # shares of config.symbol currently held
        self.previous_short_sma: Optional[float] = None  # None until first SMA computation
        self.previous_long_sma: Optional[float] = None

        self._log(f"Trading bot initialized for symbol: {self.config.symbol}")
        self._log(f"Short SMA: {self.config.short_sma_period} periods, Long SMA: {self.config.long_sma_period} periods")
        if self.config.dry_run:
            self._log("!!! DRY RUN MODE ACTIVE - NO REAL ORDERS WILL BE PLACED !!!")

    def _log(self, message: str) -> None:
        """Simple logging utility for the bot."""
        print(f"[Bot] {message}")

    def _fetch_and_store_price(self) -> Optional[float]:
        """
        Fetches the current price from the API and adds it to the price history.
        Returns the fetched price or None if failed (history unchanged on failure).
        """
        price = self.api_client.get_price(self.config.symbol)
        if price is not None:
            self.price_history.append(price)
            self._log(f"Fetched price for {self.config.symbol}: {price}. History size: {len(self.price_history)}")
            return price
        self._log(f"Failed to fetch current price for {self.config.symbol}.")
        return None

    def _calculate_sma(self, period: int) -> Optional[float]:
        """
        Calculates the Simple Moving Average (SMA) for a given period
        using the stored price history.

        Returns None when fewer than `period` prices have been collected.
        """
        if len(self.price_history) < period:
            return None
        # Get the last 'period' prices from the deque
        # Python's deque doesn't have direct slicing like list[-period:]
        # So we convert to list for slicing or iterate last 'n' elements
        recent_prices = list(self.price_history)[-period:]
        return sum(recent_prices) / period

    def _update_current_position(self) -> None:
        """
        Fetches the current position for the trading symbol from the API
        and updates the bot's internal state.

        Resets the held quantity to 0.0 first, so a failed or empty positions
        fetch leaves the bot believing it holds nothing.
        """
        positions = self.api_client.get_positions()
        self.current_position_quantity = 0.0
        if positions:
            for pos in positions:
                if pos.get("symbol") == self.config.symbol:
                    self.current_position_quantity = float(pos.get("quantity", 0))
                    break
        self._log(f"Current position in {self.config.symbol}: {self.current_position_quantity}")

    def _can_trade(self) -> bool:
        """
        Checks if enough time has passed since the last trade to place a new one.
        """
        time_since_last_trade = time.time() - self.last_trade_timestamp
        if time_since_last_trade < self.config.min_time_between_trades_seconds:
            self._log(f"Throttling: Waiting {math.ceil(self.config.min_time_between_trades_seconds - time_since_last_trade)}s before next trade.")
            return False
        return True

    def collect_initial_history(self) -> None:
        """
        Collects an initial set of price data before starting the trading strategy.
        This is crucial for calculating SMAs from the start.

        Blocks for roughly initial_history_collection_count *
        price_fetch_interval_seconds seconds. A failed fetch is logged but the
        loop still counts the attempt, so fewer points than requested may be stored.
        """
        self._log(f"Collecting initial price history ({self.config.initial_history_collection_count} points required)...")
        for i in range(self.config.initial_history_collection_count):
            if self._fetch_and_store_price() is None:
                self._log("Failed to collect initial price. Retrying...")
            # Wait before fetching next price to simulate time passing
            time.sleep(self.config.price_fetch_interval_seconds)
            self._log(f"Collected {i+1}/{self.config.initial_history_collection_count} prices.")
        self._log("Initial price history collection complete.")

    def run_strategy_iteration(self) -> None:
        """
        Executes one iteration of the SMA crossover strategy.

        Order of operations matters: position/balance are refreshed first, a
        new price is appended, SMAs are computed, crossover signals are checked
        against the PREVIOUS iteration's SMAs, and only then are the previous
        SMA values overwritten for the next iteration.
        """
        self._log("--- Running strategy iteration ---")

        # 1. Fetch current position and balance
        self._update_current_position()
        cash_balance = self.api_client.get_balance()
        if cash_balance is None:
            self._log("Could not get cash balance. Skipping iteration.")
            return

        # 2. Fetch new price and update history
        if self._fetch_and_store_price() is None:
            return  # Skip iteration if price fetch fails

        # 3. Ensure enough data for SMAs
        if len(self.price_history) < self.config.long_sma_period:
            self._log(f"Not enough price history for SMAs (need {self.config.long_sma_period}, have {len(self.price_history)}). Waiting for more data.")
            return

        # 4. Calculate SMAs
        short_sma = self._calculate_sma(self.config.short_sma_period)
        long_sma = self._calculate_sma(self.config.long_sma_period)

        if short_sma is None or long_sma is None:
            self._log("Could not calculate SMAs. Skipping iteration.")
            return

        self._log(f"Current SMAs: Short={short_sma:.2f}, Long={long_sma:.2f}")

        # If this is the first time we calculated SMAs, just store them and exit
        if self.previous_short_sma is None or self.previous_long_sma is None:
            self._log("First SMA calculation. Storing values for next iteration comparison.")
            self.previous_short_sma = short_sma
            self.previous_long_sma = long_sma
            return

        # 5. Check for crossover signals
        # Buy Signal: Short SMA crosses above Long SMA
        # NOTE(review): the >= / <= comparisons mean an exact touch
        # (short == long) counts as a cross; strict previous-side inequality
        # prevents the same signal firing twice in a row.
        if (self.previous_short_sma < self.previous_long_sma) and (short_sma >= long_sma):
            self._log("!!! BUY SIGNAL DETECTED: Short SMA crossed above Long SMA !!!")
            if self.current_position_quantity > 0:
                self._log(f"Already hold a position of {self.current_position_quantity} {self.config.symbol}. No new buy order.")
            elif not self._can_trade():
                pass  # Message already logged by _can_trade()
            else:
                buy_amount_dollars = cash_balance * self.config.buy_cash_fraction
                # Use the most recent price for calculating quantity
                current_price = self.price_history[-1]
                if current_price > 0:
                    quantity_to_buy = math.floor(buy_amount_dollars / current_price)
                    if quantity_to_buy > 0:
                        self._log(f"Attempting to BUY {quantity_to_buy} shares of {self.config.symbol} at approx ${current_price:.2f} using ${buy_amount_dollars:.2f} of cash.")
                        if not self.config.dry_run:
                            order_response = self.api_client.place_order(self.config.symbol, "BUY", quantity_to_buy)
                            if order_response:
                                self.last_trade_timestamp = time.time()
                                self._update_current_position()  # Refresh position after order
                        else:
                            self._log(f"DRY RUN: Would have placed BUY order for {quantity_to_buy} {self.config.symbol}.")
                            self.last_trade_timestamp = time.time()  # Still simulate trade delay
                    else:
                        self._log("Calculated quantity to buy is zero.")
                else:
                    self._log("Current price is zero, cannot calculate buy quantity.")

        # Sell Signal: Short SMA crosses below Long SMA
        elif (self.previous_short_sma > self.previous_long_sma) and (short_sma <= long_sma):
            self._log("!!! SELL SIGNAL DETECTED: Short SMA crossed below Long SMA !!!")
            if self.current_position_quantity == 0:
                self._log("No open position to sell. No new sell order.")
            elif not self._can_trade():
                pass  # Message already logged by _can_trade()
            else:
                # Close the whole position on a sell signal.
                quantity_to_sell = self.current_position_quantity
                self._log(f"Attempting to SELL {quantity_to_sell} shares of {self.config.symbol}.")
                if not self.config.dry_run:
                    order_response = self.api_client.place_order(self.config.symbol, "SELL", quantity_to_sell)
                    if order_response:
                        self.last_trade_timestamp = time.time()
                        self._update_current_position()  # Refresh position after order
                else:
                    self._log(f"DRY RUN: Would have placed SELL order for {quantity_to_sell} {self.config.symbol}.")
                    self.last_trade_timestamp = time.time()  # Still simulate trade delay

        else:
            self._log("No crossover signal detected.")

        # 6. Update previous SMA values for the next iteration
        self.previous_short_sma = short_sma
        self.previous_long_sma = long_sma

        self._log("--- Iteration complete ---")
|
||||
|
||||
|
||||
def main():
    """
    Main function to parse arguments, configure the bot, and run the trading loop.

    Flow: parse CLI args -> refuse to run with the placeholder API key ->
    build config/client/bot -> collect initial price history -> loop the
    strategy until --run-duration elapses (or forever when 0), Ctrl-C, or an
    unexpected error.
    """
    parser = argparse.ArgumentParser(
        description="SMA Crossover Trading Bot for Simulated Equities API."
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=DEFAULT_API_KEY,
        help=f"Your API key for the simulator. Default: '{DEFAULT_API_KEY}' (or SIM_API_KEY env var)"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=DEFAULT_BASE_URL,
        help=f"Base URL of the simulated API. Default: {DEFAULT_BASE_URL}"
    )
    parser.add_argument(
        "--symbol",
        type=str,
        default=DEFAULT_SYMBOL,
        help=f"Trading symbol (e.g., AAPL). Default: {DEFAULT_SYMBOL}"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="If set, the bot will log trade actions instead of placing real orders."
    )
    parser.add_argument(
        "--short-sma-period",
        type=int,
        default=DEFAULT_SHORT_SMA_PERIOD,
        help=f"Number of periods for the short SMA. Default: {DEFAULT_SHORT_SMA_PERIOD}"
    )
    parser.add_argument(
        "--long-sma-period",
        type=int,
        default=DEFAULT_LONG_SMA_PERIOD,
        help=f"Number of periods for the long SMA. Default: {DEFAULT_LONG_SMA_PERIOD}"
    )
    parser.add_argument(
        "--buy-cash-fraction",
        type=float,
        default=DEFAULT_BUY_CASH_FRACTION,
        help=f"Fraction of available cash to use for a BUY order (e.g., 0.95). Default: {DEFAULT_BUY_CASH_FRACTION}"
    )
    parser.add_argument(
        "--price-fetch-interval",
        type=int,
        default=DEFAULT_PRICE_FETCH_INTERVAL_SECONDS,
        help=f"Interval in seconds to fetch new price data for SMA calculation. Default: {DEFAULT_PRICE_FETCH_INTERVAL_SECONDS}"
    )
    parser.add_argument(
        "--main-loop-interval",
        type=int,
        default=DEFAULT_MAIN_LOOP_INTERVAL_SECONDS,
        help=f"Interval in seconds between strategy evaluations. Default: {DEFAULT_MAIN_LOOP_INTERVAL_SECONDS}"
    )
    parser.add_argument(
        "--min-trade-interval",
        type=int,
        default=DEFAULT_MIN_TIME_BETWEEN_TRADES_SECONDS,
        help=f"Minimum time in seconds between placing actual orders. Default: {DEFAULT_MIN_TIME_BETWEEN_TRADES_SECONDS}"
    )
    parser.add_argument(
        "--initial-history-count",
        type=int,
        default=DEFAULT_INITIAL_HISTORY_COLLECTION_COUNT,
        help=f"Number of initial price points to collect before starting strategy. Default: {DEFAULT_INITIAL_HISTORY_COLLECTION_COUNT}"
    )
    parser.add_argument(
        "--run-duration",
        type=int,
        default=300,  # Default to 5 minutes for demonstration
        help="Total duration in seconds to run the bot loop. (0 for indefinite run)."
    )

    args = parser.parse_args()

    # Refuse to start with the placeholder key — every request would fail auth.
    if args.api_key == "<YOUR_SIM_API_KEY>":
        print("WARNING: API Key is not set. Please replace <YOUR_SIM_API_KEY> in the script or set SIM_API_KEY environment variable, or pass with --api-key.")
        print("Exiting...")
        return

    config = TradingBotConfig(
        api_key=args.api_key,
        base_url=args.base_url,
        symbol=args.symbol,
        short_sma_period=args.short_sma_period,
        long_sma_period=args.long_sma_period,
        buy_cash_fraction=args.buy_cash_fraction,
        price_fetch_interval_seconds=args.price_fetch_interval,
        main_loop_interval_seconds=args.main_loop_interval,
        min_time_between_trades_seconds=args.min_trade_interval,
        initial_history_collection_count=args.initial_history_count,
        dry_run=args.dry_run,
    )

    api_client = SimulatedAPIClient(config.base_url, config.api_key)
    trading_bot = TradingBot(config, api_client)

    # Ensure enough history for SMA calculations.
    # This mutation happens after the bot is constructed, but the bot holds a
    # reference to the same (mutable) config dataclass, so
    # collect_initial_history() below sees the adjusted value.
    if config.initial_history_collection_count < config.long_sma_period:
        trading_bot._log(f"WARNING: Initial history collection count ({config.initial_history_collection_count}) is less than long SMA period ({config.long_sma_period}). Adjusting to {config.long_sma_period + 5}.")
        config.initial_history_collection_count = config.long_sma_period + 5

    # Collect initial price data
    trading_bot.collect_initial_history()

    # Main trading loop
    start_time = time.time()
    iteration = 0
    trading_bot._log(f"Starting main trading loop for {args.run_duration} seconds (0 for indefinite)...")

    try:
        while True:
            iteration += 1
            trading_bot._log(f"\n--- Main Loop Iteration {iteration} ---")
            trading_bot.run_strategy_iteration()

            # Duration check runs before the sleep so the bot never idles past
            # the requested end time.
            if args.run_duration > 0 and (time.time() - start_time) >= args.run_duration:
                trading_bot._log(f"Run duration of {args.run_duration} seconds completed. Exiting.")
                break

            trading_bot._log(f"Sleeping for {config.main_loop_interval_seconds} seconds...")
            time.sleep(config.main_loop_interval_seconds)

    except KeyboardInterrupt:
        trading_bot._log("Bot stopped manually by user (KeyboardInterrupt).")
    except Exception as e:
        trading_bot._log(f"An unexpected error occurred in the main loop: {e}")
    finally:
        trading_bot._log("Trading bot shutting down.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # --- Demonstration Run ---
    # To run this example:
    # 1. Save this script as `generated_trading_bot.py`.
    # 2. Install requests: `pip install requests`
    # 3. Replace `<YOUR_SIM_API_KEY>` with an actual API key or set the SIM_API_KEY environment variable.
    # 4. Run from your terminal:
    #    `python generated_trading_bot.py --dry-run --run-duration 60 --symbol MSFT`
    #    This will simulate a 60-second run for MSFT in dry-run mode,
    #    printing potential trades without actually executing them.
    #    For a longer run, change --run-duration (e.g., 3600 for 1 hour).
    #    Remove --dry-run to enable live trading (use with caution!).
    main()
|
||||
690
week4/community-contributions/day3-with-gemini.ipynb
Normal file
690
week4/community-contributions/day3-with-gemini.ipynb
Normal file
@@ -0,0 +1,690 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Code Generator\n",
|
||||
"\n",
|
||||
"The requirement: use a Frontier model to generate high performance C++ code from Python code\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f72dfaf-9f20-4d81-b082-018eda152c9f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U -q \"google-genai\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import io\n",
|
||||
"import sys\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from openai import OpenAI\n",
|
||||
"from google import genai\n",
|
||||
"from google.genai import types\n",
|
||||
"import anthropic\n",
|
||||
"from IPython.display import Markdown, display, update_display\n",
|
||||
"import gradio as gr\n",
|
||||
"import subprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4f672e1c-87e9-4865-b760-370fa605e614",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# environment\n",
|
||||
"\n",
|
||||
"load_dotenv(override=True)\n",
|
||||
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
||||
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
|
||||
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
|
||||
"\n",
|
||||
"if openai_api_key:\n",
|
||||
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"OpenAI API Key not set\")\n",
|
||||
" \n",
|
||||
"if anthropic_api_key:\n",
|
||||
" print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Anthropic API Key not set\")\n",
|
||||
"\n",
|
||||
"if google_api_key:\n",
|
||||
" print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
|
||||
"else:\n",
|
||||
" print(\"Google API Key not set\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# initialize\n",
|
||||
"\n",
|
||||
"openai = OpenAI()\n",
|
||||
"claude = anthropic.Anthropic()\n",
|
||||
"gemini = genai.Client()\n",
|
||||
"\n",
|
||||
"OPENAI_MODEL = \"o4-mini\"\n",
|
||||
"CLAUDE_MODEL = \"claude-3-7-sonnet-latest\"\n",
|
||||
"GEMINI_MODEL = \"gemini-2.5-flash\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6896636f-923e-4a2c-9d6c-fac07828a201",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# System prompt shared by all three models: constrain output to pure C++.
system_message = (
    "You are an assistant that reimplements Python code in high performance C++ for an M1 Mac. "
    "Respond only with C++ code; use comments sparingly and do not provide any explanation other than occasional comments. "
    "The C++ response needs to produce an identical output in the fastest possible time."
)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def user_prompt_for(python):
    """Build the user prompt asking a model to port `python` source to fast C++.

    Returns the fixed instruction preamble followed immediately by the
    Python code itself.
    """
    instructions = (
        "Rewrite this Python code in C++ with the fastest possible implementation that produces identical output in the least time. "
        "Respond only with C++ code; do not explain your work other than a few comments. "
        "Pay attention to number types to ensure no int overflows. Remember to #include all necessary C++ packages such as iomanip.\n\n"
    )
    return instructions + python
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def messages_for(python):
    """Assemble the two-message chat payload (system + user) for an
    OpenAI-style chat-completions call on the given Python source."""
    system = {"role": "system", "content": system_message}
    user = {"role": "user", "content": user_prompt_for(python)}
    return [system, user]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "71e1ba8c-5b05-4726-a9f3-8d8c6257350b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# write to a file called optimized.cpp\n",
|
||||
"\n",
|
||||
def write_output(cpp):
    """Strip any Markdown code fences from `cpp` and save the result to
    optimized.cpp in the current directory (overwriting any previous file)."""
    stripped = cpp.replace("```cpp", "")
    stripped = stripped.replace("```", "")
    with open("optimized.cpp", "w") as outfile:
        outfile.write(stripped)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7d2fea8-74c6-4421-8f1e-0e76d5b201b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def optimize_gpt(python):
    """Stream a C++ rewrite of `python` from GPT, echoing fragments to stdout
    as they arrive, then persist the complete reply to optimized.cpp."""
    stream = openai.chat.completions.create(
        model=OPENAI_MODEL,
        messages=messages_for(python),
        stream=True,
    )
    pieces = []
    for chunk in stream:
        fragment = chunk.choices[0].delta.content or ""
        pieces.append(fragment)
        print(fragment, end='', flush=True)
    write_output("".join(pieces))
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7cd84ad8-d55c-4fe0-9eeb-1895c95c4a9d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def optimize_claude(python):
    """Stream a C++ rewrite of `python` from Claude, echoing text to stdout
    as it arrives, then persist the complete reply to optimized.cpp."""
    pieces = []
    streamer = claude.messages.stream(
        model=CLAUDE_MODEL,
        max_tokens=2000,
        system=system_message,
        messages=[{"role": "user", "content": user_prompt_for(python)}],
    )
    with streamer as stream:
        for text in stream.text_stream:
            pieces.append(text)
            print(text, end="", flush=True)
    write_output("".join(pieces))
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8a35102-1c95-469b-8855-e85f4c9bdbdf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def optimize_gemini(python):
    """Ask Gemini (single non-streaming call, unlike the GPT/Claude helpers)
    for a C++ rewrite of `python`, print it, and persist it to optimized.cpp."""
    reply = gemini.models.generate_content(
        model=GEMINI_MODEL,
        contents=user_prompt_for(python),
        config=types.GenerateContentConfig(
            system_instruction=system_message,
        ),
    )
    # reply.text can be None (e.g. a blocked or empty response) -- fall back
    # to "" so write_output's .replace calls don't raise TypeError.
    text = reply.text or ""
    print(text)
    write_output(text)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a1cbb778-fa57-43de-b04b-ed523f396c38",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pi = \"\"\"\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"def calculate(iterations, param1, param2):\n",
|
||||
" result = 1.0\n",
|
||||
" for i in range(1, iterations+1):\n",
|
||||
" j = i * param1 - param2\n",
|
||||
" result -= (1/j)\n",
|
||||
" j = i * param1 + param2\n",
|
||||
" result += (1/j)\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
"start_time = time.time()\n",
|
||||
"result = calculate(100_000_000, 4, 1) * 4\n",
|
||||
"end_time = time.time()\n",
|
||||
"\n",
|
||||
"print(f\"Result: {result:.12f}\")\n",
|
||||
"print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7fe1cd4b-d2c5-4303-afed-2115a3fef200",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(pi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "105db6f9-343c-491d-8e44-3a5328b81719",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize_gpt(pi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf8f8018-f64d-425c-a0e1-d7862aa9592d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Compiling C++ and executing\n",
|
||||
"\n",
|
||||
"This next cell contains the command to compile a C++ file on my M1 Mac. \n",
|
||||
"It compiles the file `optimized.cpp` into an executable called `optimized` \n",
|
||||
"Then it runs the program called `optimized`\n",
|
||||
"\n",
|
||||
"In the next lab (day4), a student has contributed a full solution that compiles to efficient code on Mac, PC and Linux!\n",
|
||||
"\n",
|
||||
"You can wait for this, or you can google (or ask ChatGPT!) for how to do this on your platform, then replace the lines below.\n",
|
||||
"If you're not comfortable with this step, you can skip it for sure - I'll show you exactly how it performs on my Mac.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"OR alternatively: student Sandeep K.G. points out that you can run Python and C++ code online to test it out that way. Thank you Sandeep! \n",
|
||||
"> Not an exact comparison but you can still get the idea of performance difference.\n",
|
||||
"> For example here: https://www.programiz.com/cpp-programming/online-compiler/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4194e40c-04ab-4940-9d64-b4ad37c5bb40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Compile C++ and run the executable\n",
|
||||
"\n",
|
||||
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
|
||||
"!./optimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "983a11fe-e24d-4c65-8269-9802c5ef3ae6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize_claude(pi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5a766f9-3d23-4bb4-a1d4-88ec44b61ddf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Repeat for Claude - again, use the right approach for your platform\n",
|
||||
"\n",
|
||||
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
|
||||
"!./optimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "01f331f2-caac-48f6-9a03-8a228ee521bc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize_gemini(pi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5ef707a4-930e-4b8b-9443-e7e4fd309c2a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
|
||||
"!./optimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7d1eaeca-61be-4d0a-a525-dd09f52aaa0f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Python Hard Version"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c3b497b3-f569-420e-b92e-fb0f49957ce0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"python_hard = \"\"\"# Be careful to support large number sizes\n",
|
||||
"\n",
|
||||
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
|
||||
" value = seed\n",
|
||||
" while True:\n",
|
||||
" value = (a * value + c) % m\n",
|
||||
" yield value\n",
|
||||
" \n",
|
||||
"def max_subarray_sum(n, seed, min_val, max_val):\n",
|
||||
" lcg_gen = lcg(seed)\n",
|
||||
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
|
||||
" max_sum = float('-inf')\n",
|
||||
" for i in range(n):\n",
|
||||
" current_sum = 0\n",
|
||||
" for j in range(i, n):\n",
|
||||
" current_sum += random_numbers[j]\n",
|
||||
" if current_sum > max_sum:\n",
|
||||
" max_sum = current_sum\n",
|
||||
" return max_sum\n",
|
||||
"\n",
|
||||
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
|
||||
" total_sum = 0\n",
|
||||
" lcg_gen = lcg(initial_seed)\n",
|
||||
" for _ in range(20):\n",
|
||||
" seed = next(lcg_gen)\n",
|
||||
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
|
||||
" return total_sum\n",
|
||||
"\n",
|
||||
"# Parameters\n",
|
||||
"n = 10000 # Number of random numbers\n",
|
||||
"initial_seed = 42 # Initial seed for the LCG\n",
|
||||
"min_val = -10 # Minimum value of random numbers\n",
|
||||
"max_val = 10 # Maximum value of random numbers\n",
|
||||
"\n",
|
||||
"# Timing the function\n",
|
||||
"import time\n",
|
||||
"start_time = time.time()\n",
|
||||
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
|
||||
"end_time = time.time()\n",
|
||||
"\n",
|
||||
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
|
||||
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dab5e4bc-276c-4555-bd4c-12c699d5e899",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exec(python_hard)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e8d24ed5-2c15-4f55-80e7-13a3952b3cb8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize_gpt(python_hard)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e0b3d073-88a2-40b2-831c-6f0c345c256f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Replace this with the right C++ compile + execute command for your platform\n",
|
||||
"\n",
|
||||
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
|
||||
"!./optimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e9305446-1d0c-4b51-866a-b8c1e299bf5c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize_gemini(python_hard)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0c181036-8193-4fdd-aef3-fc513b218d43",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Replace this with the right C++ compile + execute command for your platform\n",
|
||||
"\n",
|
||||
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
|
||||
"!./optimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2ee75e72-9ecb-4edd-a74a-4d3a83c1eb79",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize_claude(python_hard)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4a4ab43c-7df2-4770-bd05-6bbc198a8c45",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Replace this with the right C++ compile + execute command for your platform\n",
|
||||
"\n",
|
||||
"!clang++ -O3 -std=c++17 -march=armv8.3-a -o optimized optimized.cpp\n",
|
||||
"!./optimized"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ff02ce09-0544-49a5-944d-a57b25bf9b72",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Streaming"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0be9f47d-5213-4700-b0e2-d444c7c738c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def stream_gpt(python):
    """Generator: after each streamed chunk, yield the full GPT reply so far
    with Markdown code fences removed (for progressive display in Gradio)."""
    response = openai.chat.completions.create(
        model=OPENAI_MODEL, messages=messages_for(python), stream=True
    )
    pieces = []
    for chunk in response:
        pieces.append(chunk.choices[0].delta.content or "")
        so_far = "".join(pieces)
        yield so_far.replace('```cpp\n', '').replace('```', '')
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8669f56b-8314-4582-a167-78842caea131",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def stream_claude(python):
    """Generator: after each streamed text delta, yield the full Claude reply
    so far with Markdown code fences removed (for progressive display)."""
    streamer = claude.messages.stream(
        model=CLAUDE_MODEL,
        max_tokens=2000,
        system=system_message,
        messages=[{"role": "user", "content": user_prompt_for(python)}],
    )
    pieces = []
    with streamer as stream:
        for text in stream.text_stream:
            pieces.append(text)
            yield "".join(pieces).replace('```cpp\n', '').replace('```', '')
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d48d44df-c082-4ed1-b3ea-fc2a880591c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def stream_gemini(python):
    """Generator: after each streamed chunk, yield the full Gemini reply so
    far with Markdown code fences removed (for progressive display)."""
    stream = gemini.models.generate_content_stream(
        model=GEMINI_MODEL,
        contents=user_prompt_for(python),
        config=types.GenerateContentConfig(
            system_instruction=system_message,
        ),
    )
    reply = ""
    for chunk in stream:
        # chunk.text can be None on non-text chunks (e.g. final/safety
        # metadata) -- guard so the concatenation doesn't raise TypeError.
        reply += chunk.text or ""
        yield reply.replace('```cpp\n', '').replace('```', '')
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def optimize(python, model):
    """Generator: dispatch to the chosen model's streaming converter and
    re-yield its progressive output for the Gradio UI.

    Raises ValueError if `model` is not one of "GPT", "Claude", "Gemini".
    """
    if model == "GPT":
        streamer = stream_gpt(python)
    elif model == "Claude":
        streamer = stream_claude(python)
    elif model == "Gemini":
        streamer = stream_gemini(python)
    else:
        raise ValueError("Unknown model")
    yield from streamer
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f1ddb38e-6b0a-4c37-baa4-ace0b7de887a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gr.Blocks() as ui:\n",
|
||||
" with gr.Row():\n",
|
||||
" python = gr.Textbox(label=\"Python code:\", lines=10, value=python_hard)\n",
|
||||
" cpp = gr.Textbox(label=\"C++ code:\", lines=10)\n",
|
||||
" with gr.Row():\n",
|
||||
" model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
|
||||
" convert = gr.Button(\"Convert code\")\n",
|
||||
"\n",
|
||||
" convert.click(optimize, inputs=[python, model], outputs=[cpp])\n",
|
||||
"\n",
|
||||
"ui.launch(inbrowser=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "19bf2bff-a822-4009-a539-f003b1651383",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def execute_python(code):
    """Run a Python source string and return everything it printed.

    stdout is captured for the duration of the exec and always restored;
    exceptions raised by `code` propagate to the caller (after restoration).

    Fixes vs. the original:
    - `exec(code)` ran in this function's *local* scope, so top-level defs in
      `code` landed in locals while functions defined by `code` still looked
      names up in globals -- any def calling another def raised NameError.
      Executing in one fresh globals dict makes the defs see each other.
    - Restoring `sys.stdout = sys.__stdout__` would clobber Jupyter's
      redirected stream; restore the stdout that was actually active instead.
    """
    namespace = {}
    output = io.StringIO()
    saved_stdout = sys.stdout
    sys.stdout = output
    try:
        exec(code, namespace)
    finally:
        sys.stdout = saved_stdout
    return output.getvalue()
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "77f3ab5d-fcfb-4d3f-8728-9cacbf833ea6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# M1 Mac version to compile and execute optimized C++ code:\n",
|
||||
"\n",
|
||||
def execute_cpp(code):
    """Compile `code` as C++ for Apple Silicon and run it, returning stdout.

    Writes the source via write_output (optimized.cpp), compiles with
    clang++, runs ./optimized, and returns the program's captured stdout.
    On compile or runtime failure returns the error text instead of raising,
    so the Gradio handler shows it rather than crashing.
    """
    write_output(code)
    compile_cmd = [
        "clang++", "-Ofast", "-std=c++17", "-march=armv8.5-a",
        "-mtune=apple-m1", "-mcpu=apple-m1",
        "-o", "optimized", "optimized.cpp",
    ]
    try:
        subprocess.run(compile_cmd, check=True, text=True, capture_output=True)
        run_result = subprocess.run(
            ["./optimized"], check=True, text=True, capture_output=True
        )
        return run_result.stdout
    except subprocess.CalledProcessError as e:
        return f"An error occurred:\n{e.stderr}"
    except FileNotFoundError as e:
        # clang++ (or the produced binary) missing from PATH -- report it
        # the same way instead of crashing the UI with an uncaught exception.
        return f"An error occurred:\n{e}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9a2274f1-d03b-42c0-8dcc-4ce159b18442",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"css = \"\"\"\n",
|
||||
".python {background-color: #306998;}\n",
|
||||
".cpp {background-color: #050;}\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f1303932-160c-424b-97a8-d28c816721b2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gr.Blocks(css=css) as ui:\n",
|
||||
" gr.Markdown(\"## Convert code from Python to C++\")\n",
|
||||
" with gr.Row():\n",
|
||||
" python = gr.Textbox(label=\"Python code:\", value=python_hard, lines=20)\n",
|
||||
" cpp = gr.Textbox(label=\"C++ code:\", lines=20)\n",
|
||||
" with gr.Row():\n",
|
||||
" model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
|
||||
" convert = gr.Button(\"Convert code\")\n",
|
||||
" with gr.Row():\n",
|
||||
" python_run = gr.Button(\"Run Python\")\n",
|
||||
" cpp_run = gr.Button(\"Run C++\")\n",
|
||||
" with gr.Row():\n",
|
||||
" python_out = gr.TextArea(label=\"Python result:\", elem_classes=[\"python\"])\n",
|
||||
" cpp_out = gr.TextArea(label=\"C++ result:\", elem_classes=[\"cpp\"])\n",
|
||||
"\n",
|
||||
" convert.click(optimize, inputs=[python, model], outputs=[cpp])\n",
|
||||
" python_run.click(execute_python, inputs=[python], outputs=[python_out])\n",
|
||||
" cpp_run.click(execute_cpp, inputs=[cpp], outputs=[cpp_out])\n",
|
||||
"\n",
|
||||
"ui.launch(inbrowser=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ea42883b-fdba-46ed-97be-f42e3cb41f11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user