{ "cells": [ { "cell_type": "markdown", "id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5", "metadata": {}, "source": [ "# Week 2 exercise\n", "\n", "## MathXpert with tools integration\n", "\n", "- Provides the freedom to explore all the models available from the providers\n", "- Handling of multiple tools calling simultaneously\n", "- Efficiently run tools in parallel\n", "- Tool response, i.e. the `plot_function`, that does not require going back to the LLM\n", "- Uses the inbuilt logging package to allow the control of the verbosity of the logging, set to a higher level, like INFO, to reduce the noisy output" ] }, { "cell_type": "code", "execution_count": null, "id": "c1070317-3ed9-4659-abe3-828943230e03", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import logging\n", "from enum import StrEnum\n", "from getpass import getpass\n", "from types import SimpleNamespace\n", "from typing import Callable\n", "\n", "from dotenv import load_dotenv\n", "from openai import OpenAI\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, Latex\n", "import gradio as gr\n", "\n", "load_dotenv(override=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "99901b80", "metadata": {}, "outputs": [], "source": [ "logging.basicConfig(level=logging.WARNING)\n", "\n", "logger = logging.getLogger('mathxpert')\n", "logger.setLevel(logging.DEBUG)" ] }, { "cell_type": "markdown", "id": "f169118a-645e-44e1-9a98-4f561adfbb08", "metadata": {}, "source": [ "## Free Cloud Providers\n", "\n", "Grab your free API Keys from these generous sites:\n", "\n", "- https://openrouter.ai/\n", "- https://ollama.com/\n", "\n", ">**NOTE**: If you do not have a key for any provider, simply press ENTER to move on" ] }, { "cell_type": "code", "execution_count": null, "id": "4a456906-915a-4bfd-bb9d-57e505c5093f", "metadata": {}, "outputs": [], "source": [ "class Provider(StrEnum):\n", " OLLAMA = 'Ollama'\n", " OPENROUTER = 'OpenRouter'\n", "\n", "clients: dict[Provider, OpenAI] = {}\n", "models: dict[Provider, list[str]] = {\n", " Provider.OLLAMA: [],\n", " Provider.OPENROUTER: [],\n", "}\n", "\n", "DEFAULT_PROVIDER = Provider.OLLAMA\n", "\n", "selection_state: dict[Provider, str | None] = {\n", " Provider.OLLAMA: 'gpt-oss:20b',\n", " Provider.OPENROUTER: 'openai/gpt-oss-20b:free',\n", "}\n", "\n", "def get_secret_in_google_colab(env_name: str) -> str:\n", " try:\n", " from google.colab import userdata\n", " return userdata.get(env_name)\n", " except Exception:\n", " return ''\n", " \n", "\n", "def get_secret(env_name: str) -> str:\n", " '''Gets the value from the environment(s), otherwise ask the user for it if not set'''\n", " key = os.environ.get(env_name) or get_secret_in_google_colab(env_name)\n", "\n", " if not key:\n", " key = getpass(f'Enter {env_name}:').strip()\n", "\n", " if key:\n", " logger.info(f'✅ {env_name} provided')\n", " else:\n", " logger.warning(f'❌ {env_name} not provided')\n", " return key.strip()\n", "\n", "\n", "if api_key := get_secret('OLLAMA_API_KEY'):\n", " clients[Provider.OLLAMA] = OpenAI(api_key=api_key, base_url='https://ollama.com/v1')\n", "\n", "if api_key := get_secret('OPENROUTER_API_KEY'):\n", " clients[Provider.OPENROUTER] = OpenAI(api_key=api_key, base_url='https://openrouter.ai/api/v1')\n", "\n", "available_providers = [str(p) for p in clients.keys()]" ] }, { "cell_type": "code", "execution_count": null, "id": "aae1579b-7a02-459d-81c6-0f775d2a1410", "metadata": {}, "outputs": [], "source": [ "selected_provider, selected_model, client = '', '', None\n", "\n", "\n", "def get_desired_value_or_first_item(desire, options) -> str | None:\n", " logger.debug(f'Pick {desire} from {options}')\n", " selected = desire if desire in options else None\n", " if selected:\n", " return selected\n", "\n", " return options[0] if options else None\n", " \n", "try:\n", " selected_provider = get_desired_value_or_first_item(DEFAULT_PROVIDER, available_providers)\n", " client = clients.get(selected_provider)\n", "except Exception:\n", " logger.warning(f'❌ no provider configured and everything else from here will FAIL đŸ¤Ļ, I know you know this already.')\n", "\n", "def load_models_if_needed(client: OpenAI, selected_provider):\n", " global selected_model, models\n", "\n", " if client and not models.get(selected_provider):\n", " logging.info(f'📡 Fetching {selected_provider} models...')\n", " \n", " models[selected_provider] = [model.id for model in client.models.list()]\n", " selected_model = get_desired_value_or_first_item(\n", " selection_state[selected_provider], \n", " models[selected_provider],\n", " )\n", "\n", "load_models_if_needed(client, selected_provider)\n", "\n", "logger.info(f'â„šī¸ Provider: {selected_provider} Model: {selected_model}, Client: {client}')" ] }, { "cell_type": "markdown", "id": "e04675c2-1b81-4187-868c-c7112cd77e37", "metadata": {}, "source": [ "## Prompt" ] }, { "cell_type": "code", "execution_count": null, "id": "a8d7923c-5f28-4c30-8556-342d7c8497c1", "metadata": {}, "outputs": [], "source": [ "def get_messages(question: str) -> list[dict[str, str]]:\n", " \"\"\"Generate messages for the chat models.\"\"\"\n", "\n", " system_prompt = r'''\n", " You are MathXpert, an expert Mathematician who makes math fun to learn by relating concepts to real \n", " practical usage to whip up the interest in learners.\n", " \n", " Explain step-by-step thoroughly how to solve a math problem. \n", " - ALWAYS use `$$...$$` for mathematical expressions.\n", " - NEVER use square brackets `[...]` to delimit math.\n", " - Example: Instead of \"[x = 2]\", write \"$$x = 2$$\".\n", " - You may use `\\\\[4pt]` inside matrices for spacing.\n", " '''\n", "\n", " return [\n", " {'role': 'system', 'content': system_prompt },\n", " {'role': 'user', 'content': question},\n", " ]" ] }, { "cell_type": "markdown", "id": "caa51866-f433-4b9a-ab20-fff5fc3b7d63", "metadata": {}, "source": [ "## Tools" ] }, { "cell_type": "markdown", "id": "a24c659a-5937-43b1-bb95-c0342f2786a9", "metadata": {}, "source": [ "### Tools Definitions" ] }, { "cell_type": "code", "execution_count": null, "id": "3f302f47-9a67-4410-ba16-56fa5a731c66", "metadata": {}, "outputs": [], "source": [ "from pydantic import BaseModel, Field\n", "from openai.types.shared_params import FunctionDefinition\n", "import sympy as sp\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import io\n", "import base64\n", "import random\n", "\n", "class ToolInput(BaseModel):\n", " pass\n", " \n", "class GetCurrentDateTimeInput(ToolInput):\n", " timezone: str = Field(default=\"UTC\", description=\"Timezone name, e.g., 'UTC' or 'Africa/Accra'\")\n", "\n", "\n", "def get_current_datetime(req: GetCurrentDateTimeInput):\n", " '''Returns the current date and time in the specified timezone.'''\n", " from zoneinfo import ZoneInfo\n", "\n", " try:\n", " from datetime import datetime\n", " tz = ZoneInfo(req.timezone)\n", " dt = datetime.now(tz)\n", " return {\n", " \"date\": dt.strftime(\"%Y-%m-%d\"),\n", " \"time\": dt.strftime(\"%H:%M:%S %Z\"),\n", " } \n", " except:\n", " return {\"error\": f\"Invalid timezone: {req.timezone}\"}\n", "\n", "\n", "class GetTemperatureInput(ToolInput):\n", " pass\n", "\n", "def get_temperature(req: GetTemperatureInput) -> float:\n", " '''Returns the current temperature in degree celsius'''\n", " return random.randint(-30, 70)\n", "\n", "\n", "class PlotFunctionInput(ToolInput):\n", " expression: str = Field(description=\"Mathematical expression to plot, e.g., 'sin(x)'\")\n", " x_min: float = Field(default=-10, description=\"Minimum x value\")\n", " x_max: float = Field(default=10, description=\"Maximum x value\")\n", "\n", "\n", "def plot_function(req: PlotFunctionInput) -> dict[str, any]:\n", " '''Plots a mathematical function and returns image data.'''\n", " try:\n", " x = sp.symbols('x')\n", " expr = sp.sympify(req.expression)\n", " lambdified = sp.lambdify(x, expr, 'numpy')\n", " \n", " x_vals = np.linspace(req.x_min, req.x_max, 400)\n", " y_vals = lambdified(x_vals)\n", " \n", " plt.figure(figsize=(10, 6))\n", " plt.plot(x_vals, y_vals, 'b-', linewidth=2)\n", " plt.grid(True, alpha=0.3)\n", " plt.title(f\"Plot of ${sp.latex(expr)}$\", fontsize=14)\n", " plt.xlabel('x', fontsize=12)\n", " plt.ylabel('f(x)', fontsize=12)\n", " \n", "\n", " buf = io.BytesIO()\n", " plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')\n", " plt.close()\n", " buf.seek(0)\n", " img_str = base64.b64encode(buf.read()).decode()\n", " \n", " return {\n", " \"plot_image\": f\"data:image/png;base64,{img_str}\",\n", " \"expression\": req.expression,\n", " \"x_range\": [req.x_min, req.x_max]\n", " }\n", " except Exception as e:\n", " return {\"error\": f\"Could not plot function: {str(e)}\"}\n", "\n" ] }, { "cell_type": "markdown", "id": "fae3ef71-f6cd-4894-ae55-9f4f8dd2a1cd", "metadata": {}, "source": [ "### Tools registration & execution" ] }, { "cell_type": "code", "execution_count": null, "id": "4f18bc9f-f8d1-4208-a3d7-e4e911034572", "metadata": {}, "outputs": [], "source": [ "from concurrent.futures import ThreadPoolExecutor\n", "\n", "class ToolManager:\n", " def __init__(self):\n", " self._tools = []\n", " self._tools_map: dict[str, tuple[Callable, ToolInput]] = {}\n", "\n", " def register_tool[T: ToolInput](self, fn: Callable, fn_input: T):\n", " self._tools.append({\n", " \"type\": \"function\",\n", " \"function\": FunctionDefinition(\n", " name=fn.__name__,\n", " description=fn.__doc__,\n", " parameters=fn_input.model_json_schema() if fn_input else None,\n", " )\n", " })\n", " \n", " self._tools_map[fn.__name__] = (fn, fn_input)\n", "\n", " def _run_single_tool(self, tool_call) -> dict[str, str] | None:\n", " if not tool_call.id:\n", " return None\n", " \n", " fn, fn_input = self._tools_map.get(tool_call.function.name)\n", " args = tool_call.function.arguments\n", " try:\n", " if args:\n", " result = fn(fn_input(**json.loads(args))) if fn_input else fn()\n", " else:\n", " result = fn(fn_input()) if fn_input else fn()\n", " \n", " logger.debug(f'Tool run result: {result}')\n", " \n", " return {\n", " 'role': 'tool',\n", " 'tool_call_id': tool_call.id,\n", " 'content': json.dumps(result),\n", " }\n", " except Exception as e:\n", " logger.error(f'Tool execution failed: {e}', extra={'name': tool_call.function.name})\n", " return None\n", "\n", " def run(self, tool_calls) -> list[dict[str, str]]:\n", " if not tool_calls:\n", " return []\n", "\n", " logger.debug(tool_calls)\n", "\n", " tool_messages = []\n", " \n", " with ThreadPoolExecutor() as executor:\n", " futures = [executor.submit(self._run_single_tool, tool_call) for tool_call in tool_calls]\n", " \n", " for future in futures:\n", " result = future.result()\n", " if result:\n", " tool_messages.append(result)\n", " \n", " return tool_messages\n", "\n", " @property\n", " def tools(self) -> list[any]:\n", " return self._tools\n", "\n", " def dump_tools(self) -> str:\n", " return json.dumps(self._tools, indent=True)\n", "\n", " \n", "tool_manager = ToolManager()\n", "\n", "tool_manager.register_tool(get_current_datetime, GetCurrentDateTimeInput)\n", "tool_manager.register_tool(get_temperature, GetTemperatureInput)\n", "tool_manager.register_tool(plot_function, PlotFunctionInput)" ] }, { "cell_type": "code", "execution_count": null, "id": "9b2e0634-de5d-45f6-a8d4-569e04d14a00", "metadata": {}, "outputs": [], "source": [ "logger.debug(tool_manager.dump_tools())" ] }, { "cell_type": "markdown", "id": "bde4cd2a-b681-4b78-917c-d970c264b151", "metadata": {}, "source": [ "## Interaction with LLM" ] }, { "cell_type": "code", "execution_count": null, "id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538", "metadata": {}, "outputs": [], "source": [ "# handle = display(None, display_id=True)\n", "\n", "def ask(client: OpenAI | None, model: str, question: str, max_tool_turns=5):\n", " if client is None:\n", " logger.warning('You should have provided the API Keys you know. Fix 🔧 this and try again â™ģī¸.')\n", " return\n", "\n", " try:\n", " logger.debug(f'# Tools: {len(tool_manager.tools)}')\n", "\n", " messages = get_messages(question=question)\n", "\n", " for turn in range(max_tool_turns):\n", " logger.debug(f'Turn: {turn}')\n", " response = client.chat.completions.create(\n", " model=model,\n", " messages=messages,\n", " tools=tool_manager.tools,\n", " stream=True,\n", " )\n", " \n", " current_message = {}\n", " tool_calls_accumulator = {}\n", " \n", " output = ''\n", " call_id = None\n", " \n", " for chunk in response:\n", " delta = chunk.choices[0].delta\n", "\n", " logger.debug(f' ✨ {chunk.choices[0]}')\n", " if content := delta.content:\n", " output += content\n", " yield output\n", "\n", " if tool_calls := delta.tool_calls:\n", " for tool_chunk in tool_calls:\n", " print('x' * 50)\n", " print(tool_chunk)\n", "\n", " if tool_chunk.id and call_id != tool_chunk.id:\n", " call_id = tool_chunk.id\n", "\n", " print(f'Call ID: {call_id}')\n", " # Streams of arguments don't come with the call id\n", " # if not call_id:\n", " # continue\n", "\n", " if call_id not in tool_calls_accumulator:\n", " # tool_calls_accumulator[call_id] = {\n", " # 'id': call_id,\n", " # 'function': {'name': '', 'arguments': ''}\n", " # }\n", " tool_calls_accumulator[call_id] = SimpleNamespace(\n", " id=call_id,\n", " function=SimpleNamespace(name='', arguments='')\n", " )\n", "\n", " if tool_chunk.function.name:\n", " tool_calls_accumulator[call_id].function.name += tool_chunk.function.name\n", " \n", " if tool_chunk.function.arguments:\n", " tool_calls_accumulator[call_id].function.arguments += tool_chunk.function.arguments\n", "\n", " if finish_reason := chunk.choices[0].finish_reason:\n", " logger.debug('🧠 LLM interaction ended. Reason: {finish_reason}')\n", "\n", " final_tool_calls = list(tool_calls_accumulator.values())\n", " if final_tool_calls:\n", " logger.debug(f'Final tools to call {final_tool_calls}')\n", "\n", " tool_call_message = {\n", " 'role': 'assistant',\n", " 'content': None,\n", " 'tool_calls': json.loads(json.dumps(final_tool_calls, default=lambda o: o.__dict__))\n", " }\n", "\n", " messages.append(tool_call_message)\n", " tool_messages = tool_manager.run(final_tool_calls)\n", "\n", " if tool_messages:\n", " for tool_msg in tool_messages:\n", " try:\n", " data = json.loads(tool_msg['content'])\n", " if 'plot_image' in data:\n", " logger.debug('We have a plot')\n", " yield f''\n", " return\n", " except:\n", " pass\n", " messages.extend(tool_messages)\n", " else:\n", " return\n", " \n", " except Exception as e:\n", " logger.error(f'đŸ”Ĩ An error occurred during the interaction with the LLM: {e}', exc_info=True)\n", " return str(e)" ] }, { "cell_type": "markdown", "id": "eda786d3-5add-4bd1-804d-13eff60c3d1a", "metadata": {}, "source": [ "### Verify streaming behaviour" ] }, { "cell_type": "code", "execution_count": null, "id": "09bc9a11-adb4-4a9c-9c77-73b2b5a665cf", "metadata": {}, "outputs": [], "source": [ "# print(selected_provider, selected_model)\n", "# print(client)\n", "# for o in ask(client, selected_model, 'What is the time?'):\n", "# for o in ask(client, selected_model, 'What is the temperature?'):\n", "# for o in ask(client, selected_model, 'What is the time and the temperature?'):\n", "# for o in ask(client, selected_model, 'Plot a for the expression sin(x)'):\n", "for o in ask(client, selected_model, 'Plot a graph of y = x**2'):\n", " print(o)" ] }, { "cell_type": "markdown", "id": "27230463", "metadata": {}, "source": [ "## Build Gradio UI" ] }, { "cell_type": "code", "execution_count": null, "id": "50fc3577", "metadata": {}, "outputs": [], "source": [ "def chat(message: str, history: list[dict], selected_provider: str, model_selector: str):\n", " # NOTE: I'm not interesting in maintaining a conversation\n", " response = ask(client, selected_model, message)\n", "\n", " for chunk in response:\n", " yield chunk\n", "\n", "def on_provider_change(change):\n", " global selected_provider, client, models\n", " logger.info(f'Provider changed to {change}')\n", " selected_provider = change\n", " client = clients.get(selected_provider)\n", " load_models_if_needed(client, selected_provider)\n", "\n", " return gr.Dropdown(\n", " choices=models.get(selected_provider, []),\n", " value=selection_state[selected_provider],\n", " interactive=True,\n", " )\n", "\n", "\n", "def on_model_change(change):\n", " global selected_provider, selected_model, selection_state\n", "\n", " selected_model = change\n", " selection_state[selected_provider] = selected_model\n", " logger.info(f'👉 Selected model: {selected_model}')\n", "\n", "\n", "with gr.Blocks(title='MathXpert', fill_width=True, \n", " \n", " ) as ui:\n", " def get_value_if_exist(v, ls) -> str:\n", " print(ls)\n", " selected = v if v in ls else None\n", " if selected:\n", " return selected\n", "\n", " return ls[0] if ls else None\n", "\n", " with gr.Row():\n", " provider_selector = gr.Dropdown(\n", " choices=available_providers, \n", " value=get_desired_value_or_first_item(selected_provider, available_providers),\n", " label='Provider',\n", " )\n", " model_selector = gr.Dropdown(\n", " choices=models[selected_provider],\n", " value=get_desired_value_or_first_item(selection_state[selected_provider], models[selected_provider]),\n", " label='Model',\n", " )\n", " \n", " provider_selector.change(fn=on_provider_change, inputs=provider_selector, outputs=model_selector)\n", " model_selector.change(fn=on_model_change, inputs=model_selector)\n", "\n", " examples = [\n", " ['Where can substitutions be applied in real life?', None, None],\n", " ['Give 1 differential equation question and solve it', None, None],\n", " ['Plot x**2 - 3x', None, None],\n", " ['What is the time now?', None, None],\n", " ['What is the temperature?', None, None],\n", " ['Tell me the time and the temperature now', None, None],\n", " ]\n", "\n", " \n", " gr.ChatInterface(\n", " fn=chat, \n", " type='messages', \n", " chatbot=gr.Chatbot(type='messages', height='75vh', resizable=True),\n", " additional_inputs=[provider_selector, model_selector],\n", " examples=examples,\n", " )\n", "\n", "ui.launch()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }