Merge pull request #787 from ranskills/ranskills/week2

Bootcamp(Ransford) - MathXpert with tools integration - week 2
This commit is contained in:
Ed Donner
2025-10-22 09:18:28 -04:00
committed by GitHub

View File

@@ -0,0 +1,657 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5",
"metadata": {},
"source": [
"# Week 2 exercise\n",
"\n",
"## MathXpert with tools integration\n",
"\n",
"- Provides the freedom to explore all the models available from the providers\n",
"- Handles multiple tool calls requested at the same time\n",
"- Runs tools efficiently in parallel\n",
"- Supports tool responses, e.g. from `plot_function`, that do not require going back to the LLM\n",
"- Uses the built-in `logging` package to control the verbosity of the output; set the logger to a higher level, like INFO, to reduce the noise"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1070317-3ed9-4659-abe3-828943230e03",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import logging\n",
"from enum import StrEnum\n",
"from getpass import getpass\n",
"from types import SimpleNamespace\n",
"from typing import Callable\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import ipywidgets as widgets\n",
"from IPython.display import display, clear_output, Latex\n",
"import gradio as gr\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99901b80",
"metadata": {},
"outputs": [],
"source": [
"# The notebook's own logger: DEBUG shows everything, use logging.INFO for less noise.\n",
"logger = logging.getLogger('mathxpert')\n",
"logger.setLevel(logging.DEBUG)\n",
"\n",
"# Root logger: everything else only reports warnings and errors.\n",
"logging.basicConfig(level=logging.WARNING)"
]
},
{
"cell_type": "markdown",
"id": "f169118a-645e-44e1-9a98-4f561adfbb08",
"metadata": {},
"source": [
"## Free Cloud Providers\n",
"\n",
"Grab your free API Keys from these generous sites:\n",
"\n",
"- https://openrouter.ai/\n",
"- https://ollama.com/\n",
"\n",
">**NOTE**: If you do not have a key for any provider, simply press ENTER to move on"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
"metadata": {},
"outputs": [],
"source": [
"class Provider(StrEnum):\n",
"    '''Cloud providers reachable through an OpenAI-compatible API.'''\n",
"    OLLAMA = 'Ollama'\n",
"    OPENROUTER = 'OpenRouter'\n",
"\n",
"\n",
"# One OpenAI-compatible client per provider for which an API key was supplied.\n",
"clients: dict[Provider, OpenAI] = {}\n",
"# Model ids per provider; empty until they are fetched from the provider.\n",
"models: dict[Provider, list[str]] = {\n",
"    Provider.OLLAMA: [],\n",
"    Provider.OPENROUTER: [],\n",
"}\n",
"\n",
"DEFAULT_PROVIDER = Provider.OLLAMA\n",
"\n",
"# Preferred model per provider.\n",
"selection_state: dict[Provider, str | None] = {\n",
"    Provider.OLLAMA: 'gpt-oss:20b',\n",
"    Provider.OPENROUTER: 'openai/gpt-oss-20b:free',\n",
"}\n",
"\n",
"\n",
"def get_secret_in_google_colab(env_name: str) -> str:\n",
"    '''Returns the Google Colab secret `env_name`, or '' when it cannot be read.'''\n",
"    try:\n",
"        from google.colab import userdata\n",
"        return userdata.get(env_name) or ''\n",
"    except Exception:\n",
"        # Best effort: not running in Colab, or the secret is missing/inaccessible.\n",
"        return ''\n",
"\n",
"\n",
"def get_secret(env_name: str) -> str:\n",
"    '''Gets the value from the environment(s), otherwise ask the user for it if not set.\n",
"\n",
"    Returns the whitespace-stripped secret, or '' when none was provided.\n",
"    '''\n",
"    # Strip before testing so that a whitespace-only value counts as missing.\n",
"    key = (os.environ.get(env_name) or get_secret_in_google_colab(env_name) or '').strip()\n",
"\n",
"    if not key:\n",
"        key = getpass(f'Enter {env_name}:').strip()\n",
"\n",
"    if key:\n",
"        logger.info(f'✅ {env_name} provided')\n",
"    else:\n",
"        logger.warning(f'❌ {env_name} not provided')\n",
"    return key\n",
"\n",
"\n",
"if api_key := get_secret('OLLAMA_API_KEY'):\n",
"    clients[Provider.OLLAMA] = OpenAI(api_key=api_key, base_url='https://ollama.com/v1')\n",
"\n",
"if api_key := get_secret('OPENROUTER_API_KEY'):\n",
"    clients[Provider.OPENROUTER] = OpenAI(api_key=api_key, base_url='https://openrouter.ai/api/v1')\n",
"\n",
"available_providers = [str(p) for p in clients]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aae1579b-7a02-459d-81c6-0f775d2a1410",
"metadata": {},
"outputs": [],
"source": [
"selected_provider, selected_model, client = '', '', None\n",
"\n",
"\n",
"def get_desired_value_or_first_item(desire, options) -> str | None:\n",
"    '''Returns `desire` when it is one of `options`, else the first option (None if there is none).'''\n",
"    logger.debug(f'Pick {desire} from {options}')\n",
"    if desire and desire in options:\n",
"        return desire\n",
"\n",
"    return options[0] if options else None\n",
"\n",
"\n",
"selected_provider = get_desired_value_or_first_item(DEFAULT_PROVIDER, available_providers)\n",
"client = clients.get(selected_provider)\n",
"if client is None:\n",
"    # Neither lookup above raises, so the missing provider has to be detected explicitly.\n",
"    logger.warning('❌ no provider configured and everything else from here will FAIL 🤦, I know you know this already.')\n",
"\n",
"\n",
"def load_models_if_needed(client: OpenAI, selected_provider):\n",
"    '''Fetches the provider's model ids once, then picks the model to use for that provider.\n",
"\n",
"    Does nothing without a client. Updates the global `selected_model` to the provider's\n",
"    preferred model when it is available, otherwise to the first model returned.\n",
"    '''\n",
"    global selected_model\n",
"\n",
"    if client is None:\n",
"        return\n",
"\n",
"    if not models.get(selected_provider):\n",
"        logger.info(f'📡 Fetching {selected_provider} models...')\n",
"        try:\n",
"            models[selected_provider] = [model.id for model in client.models.list()]\n",
"        except Exception as e:\n",
"            # Keep the notebook/UI usable when the provider cannot be reached.\n",
"            logger.error(f'Could not fetch {selected_provider} models: {e}')\n",
"            models[selected_provider] = []\n",
"\n",
"    # Re-select on every call: the cached list may belong to a provider we just switched back to.\n",
"    selected_model = get_desired_value_or_first_item(\n",
"        selection_state[selected_provider],\n",
"        models[selected_provider],\n",
"    )\n",
"\n",
"\n",
"load_models_if_needed(client, selected_provider)\n",
"\n",
"logger.info(f' Provider: {selected_provider} Model: {selected_model}, Client: {client}')"
]
},
{
"cell_type": "markdown",
"id": "e04675c2-1b81-4187-868c-c7112cd77e37",
"metadata": {},
"source": [
"## Prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
"metadata": {},
"outputs": [],
"source": [
"def get_messages(question: str) -> list[dict[str, str]]:\n",
"    \"\"\"Build the chat payload: the MathXpert system prompt followed by the user's question.\"\"\"\n",
"    system_prompt = r'''\n",
" You are MathXpert, an expert Mathematician who makes math fun to learn by relating concepts to real \n",
" practical usage to whip up the interest in learners.\n",
" \n",
" Explain step-by-step thoroughly how to solve a math problem. \n",
" - ALWAYS use `$$...$$` for mathematical expressions.\n",
" - NEVER use square brackets `[...]` to delimit math.\n",
" - Example: Instead of \"[x = 2]\", write \"$$x = 2$$\".\n",
" - You may use `\\\\[4pt]` inside matrices for spacing.\n",
" '''\n",
"\n",
"    system_message = {'role': 'system', 'content': system_prompt}\n",
"    user_message = {'role': 'user', 'content': question}\n",
"    return [system_message, user_message]"
]
},
{
"cell_type": "markdown",
"id": "caa51866-f433-4b9a-ab20-fff5fc3b7d63",
"metadata": {},
"source": [
"## Tools"
]
},
{
"cell_type": "markdown",
"id": "a24c659a-5937-43b1-bb95-c0342f2786a9",
"metadata": {},
"source": [
"### Tools Definitions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f302f47-9a67-4410-ba16-56fa5a731c66",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"from openai.types.shared_params import FunctionDefinition\n",
"import sympy as sp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import io\n",
"import base64\n",
"import random\n",
"from typing import Any\n",
"\n",
"\n",
"# Base class shared by the argument models of all tools.\n",
"class ToolInput(BaseModel):\n",
"    pass\n",
"\n",
"\n",
"class GetCurrentDateTimeInput(ToolInput):\n",
"    timezone: str = Field(default=\"UTC\", description=\"Timezone name, e.g., 'UTC' or 'Africa/Accra'\")\n",
"\n",
"\n",
"def get_current_datetime(req: GetCurrentDateTimeInput):\n",
"    '''Returns the current date and time in the specified timezone.'''\n",
"    from datetime import datetime\n",
"    from zoneinfo import ZoneInfo\n",
"\n",
"    try:\n",
"        tz = ZoneInfo(req.timezone)\n",
"    except (KeyError, ValueError, OSError):\n",
"        # ZoneInfoNotFoundError is a KeyError; malformed timezone keys raise ValueError.\n",
"        return {\"error\": f\"Invalid timezone: {req.timezone}\"}\n",
"\n",
"    dt = datetime.now(tz)\n",
"    return {\n",
"        \"date\": dt.strftime(\"%Y-%m-%d\"),\n",
"        \"time\": dt.strftime(\"%H:%M:%S %Z\"),\n",
"    }\n",
"\n",
"\n",
"class GetTemperatureInput(ToolInput):\n",
"    pass\n",
"\n",
"\n",
"def get_temperature(req: GetTemperatureInput) -> int:\n",
"    '''Returns the current temperature in degree celsius'''\n",
"    # Demo tool: a random whole number stands in for a real measurement.\n",
"    return random.randint(-30, 70)\n",
"\n",
"\n",
"class PlotFunctionInput(ToolInput):\n",
"    expression: str = Field(description=\"Mathematical expression to plot, e.g., 'sin(x)'\")\n",
"    x_min: float = Field(default=-10, description=\"Minimum x value\")\n",
"    x_max: float = Field(default=10, description=\"Maximum x value\")\n",
"\n",
"\n",
"def plot_function(req: PlotFunctionInput) -> dict[str, Any]:\n",
"    '''Plots a mathematical function and returns image data.'''\n",
"    fig = None\n",
"    try:\n",
"        x = sp.symbols('x')\n",
"        # NOTE(security): sympify evaluates its input and the expression is written by the LLM,\n",
"        # so it must be treated as untrusted; do not expose this tool beyond a local demo as is.\n",
"        expr = sp.sympify(req.expression)\n",
"        lambdified = sp.lambdify(x, expr, 'numpy')\n",
"\n",
"        x_vals = np.linspace(req.x_min, req.x_max, 400)\n",
"        # A constant expression such as '5' evaluates to a scalar; repeat it for every x.\n",
"        y_vals = np.broadcast_to(np.asarray(lambdified(x_vals)), x_vals.shape)\n",
"\n",
"        # Draw on an explicit figure instead of pyplot's implicit current figure,\n",
"        # because tools are executed in worker threads.\n",
"        fig, ax = plt.subplots(figsize=(10, 6))\n",
"        ax.plot(x_vals, y_vals, 'b-', linewidth=2)\n",
"        ax.grid(True, alpha=0.3)\n",
"        ax.set_title(f\"Plot of ${sp.latex(expr)}$\", fontsize=14)\n",
"        ax.set_xlabel('x', fontsize=12)\n",
"        ax.set_ylabel('f(x)', fontsize=12)\n",
"\n",
"        buf = io.BytesIO()\n",
"        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')\n",
"        img_str = base64.b64encode(buf.getvalue()).decode()\n",
"\n",
"        return {\n",
"            \"plot_image\": f\"data:image/png;base64,{img_str}\",\n",
"            \"expression\": req.expression,\n",
"            \"x_range\": [req.x_min, req.x_max]\n",
"        }\n",
"    except Exception as e:\n",
"        return {\"error\": f\"Could not plot function: {str(e)}\"}\n",
"    finally:\n",
"        # Release the figure even when plotting or saving failed.\n",
"        if fig is not None:\n",
"            plt.close(fig)\n"
]
},
{
"cell_type": "markdown",
"id": "fae3ef71-f6cd-4894-ae55-9f4f8dd2a1cd",
"metadata": {},
"source": [
"### Tools registration & execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f18bc9f-f8d1-4208-a3d7-e4e911034572",
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ThreadPoolExecutor\n",
"from typing import Any\n",
"\n",
"\n",
"class ToolManager:\n",
"    '''Keeps the tool definitions offered to the LLM and executes the calls it requests.'''\n",
"\n",
"    def __init__(self):\n",
"        # Definitions passed as the `tools` argument of a chat completion request.\n",
"        self._tools: list[dict[str, Any]] = []\n",
"        # Tool name -> (callable, input model class or None).\n",
"        self._tools_map: dict[str, tuple[Callable, type[ToolInput] | None]] = {}\n",
"\n",
"    def register_tool[T: ToolInput](self, fn: Callable, fn_input: type[T] | None):\n",
"        '''Registers `fn` as a tool named after the function.\n",
"\n",
"        The function's docstring becomes the tool description and the JSON schema of\n",
"        `fn_input` (a ToolInput subclass, or None for no arguments) its parameters.\n",
"        '''\n",
"        self._tools.append({\n",
"            \"type\": \"function\",\n",
"            \"function\": FunctionDefinition(\n",
"                name=fn.__name__,\n",
"                description=fn.__doc__,\n",
"                parameters=fn_input.model_json_schema() if fn_input else None,\n",
"            )\n",
"        })\n",
"\n",
"        self._tools_map[fn.__name__] = (fn, fn_input)\n",
"\n",
"    def _run_single_tool(self, tool_call) -> dict[str, str] | None:\n",
"        '''Runs one tool call and returns the `tool` message that answers it.\n",
"\n",
"        Returns None only for a call without an id. Any failure (unknown tool, invalid\n",
"        arguments, error inside the tool) is reported as an `error` payload, so that every\n",
"        call with an id still gets a response.\n",
"        '''\n",
"        if not tool_call.id:\n",
"            return None\n",
"\n",
"        name = tool_call.function.name\n",
"        try:\n",
"            if name not in self._tools_map:\n",
"                raise LookupError(f'Unknown tool: {name}')\n",
"\n",
"            fn, fn_input = self._tools_map[name]\n",
"            # Arguments arrive as a JSON string; it may be empty or 'null' for tools without arguments.\n",
"            kwargs = json.loads(tool_call.function.arguments or '{}') or {}\n",
"            result = fn(fn_input(**kwargs)) if fn_input else fn()\n",
"\n",
"            logger.debug(f'Tool run result: {result}')\n",
"            content = json.dumps(result)\n",
"        except Exception as e:\n",
"            # 'name' is a reserved LogRecord attribute, so a different key is used in `extra`.\n",
"            logger.error(f'Tool execution failed: {e}', extra={'tool_name': name})\n",
"            content = json.dumps({'error': f'Tool execution failed: {e}'})\n",
"\n",
"        return {\n",
"            'role': 'tool',\n",
"            'tool_call_id': tool_call.id,\n",
"            'content': content,\n",
"        }\n",
"\n",
"    def run(self, tool_calls) -> list[dict[str, str]]:\n",
"        '''Runs all `tool_calls` in a thread pool and returns their messages in call order.'''\n",
"        if not tool_calls:\n",
"            return []\n",
"\n",
"        logger.debug(tool_calls)\n",
"\n",
"        with ThreadPoolExecutor() as executor:\n",
"            # map() yields the results in the order of the submitted calls.\n",
"            results = list(executor.map(self._run_single_tool, tool_calls))\n",
"\n",
"        return [message for message in results if message]\n",
"\n",
"    @property\n",
"    def tools(self) -> list[Any]:\n",
"        '''The registered tool definitions.'''\n",
"        return self._tools\n",
"\n",
"    def dump_tools(self) -> str:\n",
"        '''Returns the tool definitions as indented JSON text.'''\n",
"        return json.dumps(self._tools, indent=1)\n",
"\n",
"\n",
"tool_manager = ToolManager()\n",
"\n",
"tool_manager.register_tool(get_current_datetime, GetCurrentDateTimeInput)\n",
"tool_manager.register_tool(get_temperature, GetTemperatureInput)\n",
"tool_manager.register_tool(plot_function, PlotFunctionInput)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b2e0634-de5d-45f6-a8d4-569e04d14a00",
"metadata": {},
"outputs": [],
"source": [
"# Log the registered tool definitions as indented JSON.\n",
"logger.debug(json.dumps(tool_manager.tools, indent=True))"
]
},
{
"cell_type": "markdown",
"id": "bde4cd2a-b681-4b78-917c-d970c264b151",
"metadata": {},
"source": [
"## Interaction with LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538",
"metadata": {},
"outputs": [],
"source": [
"def _find_plot_image(tool_messages) -> str | None:\n",
"    '''Returns the `plot_image` value of the first tool message that has one, else None.'''\n",
"    for tool_msg in tool_messages:\n",
"        try:\n",
"            data = json.loads(tool_msg['content'])\n",
"        except (json.JSONDecodeError, TypeError):\n",
"            continue\n",
"        # Tool results are not always objects, e.g. get_temperature returns a bare number.\n",
"        if isinstance(data, dict) and 'plot_image' in data:\n",
"            return data['plot_image']\n",
"    return None\n",
"\n",
"\n",
"def ask(client: OpenAI | None, model: str, question: str, max_tool_turns=5):\n",
"    '''Streams the answer to `question`, running the tools the LLM asks for.\n",
"\n",
"    Generator that yields the accumulated answer text each time it grows. At most\n",
"    `max_tool_turns` requests are sent to the LLM. A plot produced by a tool is yielded\n",
"    as an `<img>` tag without going back to the LLM. Errors are logged and yielded as text.\n",
"    '''\n",
"    if client is None:\n",
"        logger.warning('You should have provided the API Keys you know. Fix 🔧 this and try again ♻️.')\n",
"        return\n",
"\n",
"    try:\n",
"        logger.debug(f'# Tools: {len(tool_manager.tools)}')\n",
"\n",
"        messages = get_messages(question=question)\n",
"\n",
"        for turn in range(max_tool_turns):\n",
"            logger.debug(f'Turn: {turn}')\n",
"            response = client.chat.completions.create(\n",
"                model=model,\n",
"                messages=messages,\n",
"                tools=tool_manager.tools,\n",
"                stream=True,\n",
"            )\n",
"\n",
"            tool_calls_accumulator = {}\n",
"            output = ''\n",
"            call_id = None\n",
"\n",
"            for chunk in response:\n",
"                # Guard against chunks that carry no choices at all.\n",
"                if not chunk.choices:\n",
"                    continue\n",
"\n",
"                choice = chunk.choices[0]\n",
"                delta = choice.delta\n",
"\n",
"                logger.debug(f' ✨ {choice}')\n",
"                if content := delta.content:\n",
"                    output += content\n",
"                    yield output\n",
"\n",
"                for tool_chunk in delta.tool_calls or []:\n",
"                    logger.debug(f'Tool chunk: {tool_chunk}')\n",
"\n",
"                    # Streams of arguments don't come with the call id, so they are\n",
"                    # attributed to the most recently announced call.\n",
"                    if tool_chunk.id and call_id != tool_chunk.id:\n",
"                        call_id = tool_chunk.id\n",
"\n",
"                    if call_id not in tool_calls_accumulator:\n",
"                        tool_calls_accumulator[call_id] = SimpleNamespace(\n",
"                            id=call_id,\n",
"                            type='function',\n",
"                            function=SimpleNamespace(name='', arguments=''),\n",
"                        )\n",
"\n",
"                    fn_delta = tool_chunk.function\n",
"                    if fn_delta and fn_delta.name:\n",
"                        tool_calls_accumulator[call_id].function.name += fn_delta.name\n",
"\n",
"                    if fn_delta and fn_delta.arguments:\n",
"                        tool_calls_accumulator[call_id].function.arguments += fn_delta.arguments\n",
"\n",
"                if finish_reason := choice.finish_reason:\n",
"                    logger.debug(f'🧠 LLM interaction ended. Reason: {finish_reason}')\n",
"\n",
"            final_tool_calls = list(tool_calls_accumulator.values())\n",
"            if not final_tool_calls:\n",
"                # No tool requested: the streamed text was the final answer.\n",
"                return\n",
"\n",
"            logger.debug(f'Final tools to call {final_tool_calls}')\n",
"\n",
"            messages.append({\n",
"                'role': 'assistant',\n",
"                'content': output or None,\n",
"                # Turn the SimpleNamespace accumulators into plain, JSON-compatible dicts.\n",
"                'tool_calls': json.loads(json.dumps(final_tool_calls, default=vars)),\n",
"            })\n",
"            tool_messages = tool_manager.run(final_tool_calls)\n",
"\n",
"            if plot_image := _find_plot_image(tool_messages):\n",
"                logger.debug('We have a plot')\n",
"                yield f'<img src=\"{plot_image}\" style=\"max-width: 100%; height: auto; border: 1px solid #ccc; border-radius: 5px;\">'\n",
"                return\n",
"\n",
"            # The chat API expects one tool message per announced tool call, so report\n",
"            # calls that produced no message as failed instead of leaving them unanswered.\n",
"            answered = {tool_msg['tool_call_id'] for tool_msg in tool_messages}\n",
"            for call in final_tool_calls:\n",
"                if call.id and call.id not in answered:\n",
"                    tool_messages.append({\n",
"                        'role': 'tool',\n",
"                        'tool_call_id': call.id,\n",
"                        'content': json.dumps({'error': 'Tool execution failed'}),\n",
"                    })\n",
"\n",
"            messages.extend(tool_messages)\n",
"\n",
"        logger.warning(f'Stopped after {max_tool_turns} tool turns without a final answer')\n",
"\n",
"    except Exception as e:\n",
"        logger.error(f'🔥 An error occurred during the interaction with the LLM: {e}', exc_info=True)\n",
"        # The return value of a generator is not seen by a for-loop consumer; yield it instead.\n",
"        yield str(e)"
]
},
{
"cell_type": "markdown",
"id": "eda786d3-5add-4bd1-804d-13eff60c3d1a",
"metadata": {},
"source": [
"### Verify streaming behaviour"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09bc9a11-adb4-4a9c-9c77-73b2b5a665cf",
"metadata": {},
"outputs": [],
"source": [
"# Smoke test: print every partial answer streamed for a question that needs the plotting tool.\n",
"# Other questions worth trying: 'What is the time?', 'What is the temperature?',\n",
"# 'What is the time and the temperature?'\n",
"for partial_answer in ask(client, selected_model, 'Plot a graph of y = x**2'):\n",
"    print(partial_answer)"
]
},
{
"cell_type": "markdown",
"id": "27230463",
"metadata": {},
"source": [
"## Build Gradio UI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50fc3577",
"metadata": {},
"outputs": [],
"source": [
"def chat(message: str, history: list[dict], selected_provider: str, model_selector: str):\n",
"    '''Streams the answer to `message` using the provider and model picked in the UI.\n",
"\n",
"    NOTE: `history` is ignored on purpose, I'm not interested in maintaining a conversation.\n",
"    '''\n",
"    # The examples below pass None for both selectors; fall back to the current global selection.\n",
"    active_client = clients.get(selected_provider) or client\n",
"    active_model = model_selector or selected_model\n",
"\n",
"    yield from ask(active_client, active_model, message)\n",
"\n",
"\n",
"def on_provider_change(change):\n",
"    '''Makes `change` the active provider and returns the model dropdown updated for it.'''\n",
"    global selected_provider, selected_model, client\n",
"    logger.info(f'Provider changed to {change}')\n",
"    selected_provider = change\n",
"    client = clients.get(selected_provider)\n",
"    load_models_if_needed(client, selected_provider)\n",
"\n",
"    provider_models = models.get(selected_provider, [])\n",
"    # The remembered model may not be among the choices; fall back to the first one offered.\n",
"    selected_model = get_desired_value_or_first_item(\n",
"        selection_state.get(selected_provider),\n",
"        provider_models,\n",
"    )\n",
"\n",
"    return gr.Dropdown(\n",
"        choices=provider_models,\n",
"        value=selected_model,\n",
"        interactive=True,\n",
"    )\n",
"\n",
"\n",
"def on_model_change(change):\n",
"    '''Remembers `change` as the selected model of the active provider.'''\n",
"    global selected_model\n",
"\n",
"    if not change:\n",
"        # Nothing selected (e.g. the dropdown was cleared): keep the previous model.\n",
"        return\n",
"\n",
"    selected_model = change\n",
"    selection_state[selected_provider] = selected_model\n",
"    logger.info(f'👉 Selected model: {selected_model}')\n",
"\n",
"\n",
"with gr.Blocks(title='MathXpert', fill_width=True) as ui:\n",
"    with gr.Row():\n",
"        provider_selector = gr.Dropdown(\n",
"            choices=available_providers,\n",
"            value=get_desired_value_or_first_item(selected_provider, available_providers),\n",
"            label='Provider',\n",
"        )\n",
"        # .get() keeps the UI buildable even when no provider is configured.\n",
"        initial_models = models.get(selected_provider, [])\n",
"        model_selector = gr.Dropdown(\n",
"            choices=initial_models,\n",
"            value=get_desired_value_or_first_item(selection_state.get(selected_provider), initial_models),\n",
"            label='Model',\n",
"        )\n",
"\n",
"    provider_selector.change(fn=on_provider_change, inputs=provider_selector, outputs=model_selector)\n",
"    model_selector.change(fn=on_model_change, inputs=model_selector)\n",
"\n",
"    # Each example: [message, provider, model]; None leaves the choice to the current selection.\n",
"    examples = [\n",
"        ['Where can substitutions be applied in real life?', None, None],\n",
"        ['Give 1 differential equation question and solve it', None, None],\n",
"        ['Plot x**2 - 3x', None, None],\n",
"        ['What is the time now?', None, None],\n",
"        ['What is the temperature?', None, None],\n",
"        ['Tell me the time and the temperature now', None, None],\n",
"    ]\n",
"\n",
"    gr.ChatInterface(\n",
"        fn=chat,\n",
"        type='messages',\n",
"        chatbot=gr.Chatbot(type='messages', height='75vh', resizable=True),\n",
"        additional_inputs=[provider_selector, model_selector],\n",
"        examples=examples,\n",
"    )\n",
"\n",
"ui.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}