{
"cells": [
{
"cell_type": "markdown",
"id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5",
"metadata": {},
"source": [
"# Week 2 exercise\n",
"\n",
"## MathXpert with tools integration\n",
"\n",
"- Provides the freedom to explore all the models available from the providers\n",
"- Handling of multiple tools calling simultaneously\n",
"- Efficiently run tools in parallel\n",
"- Tool response, i.e. the `plot_function`, that does not require going back to the LLM\n",
"- Uses the inbuilt logging package to allow the control of the verbosity of the logging, set to a higher level, like INFO, to reduce the noisy output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1070317-3ed9-4659-abe3-828943230e03",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import logging\n",
"from enum import StrEnum\n",
"from getpass import getpass\n",
"from types import SimpleNamespace\n",
"from typing import Callable\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import ipywidgets as widgets\n",
"from IPython.display import display, clear_output, Latex\n",
"import gradio as gr\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99901b80",
"metadata": {},
"outputs": [],
"source": [
"logging.basicConfig(level=logging.WARNING)\n",
"\n",
"logger = logging.getLogger('mathxpert')\n",
"logger.setLevel(logging.DEBUG)"
]
},
{
"cell_type": "markdown",
"id": "f169118a-645e-44e1-9a98-4f561adfbb08",
"metadata": {},
"source": [
"## Free Cloud Providers\n",
"\n",
"Grab your free API Keys from these generous sites:\n",
"\n",
"- https://openrouter.ai/\n",
"- https://ollama.com/\n",
"\n",
">**NOTE**: If you do not have a key for any provider, simply press ENTER to move on"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
"metadata": {},
"outputs": [],
"source": [
"class Provider(StrEnum):\n",
" OLLAMA = 'Ollama'\n",
" OPENROUTER = 'OpenRouter'\n",
"\n",
"clients: dict[Provider, OpenAI] = {}\n",
"models: dict[Provider, list[str]] = {\n",
" Provider.OLLAMA: [],\n",
" Provider.OPENROUTER: [],\n",
"}\n",
"\n",
"DEFAULT_PROVIDER = Provider.OLLAMA\n",
"\n",
"selection_state: dict[Provider, str | None] = {\n",
" Provider.OLLAMA: 'gpt-oss:20b',\n",
" Provider.OPENROUTER: 'openai/gpt-oss-20b:free',\n",
"}\n",
"\n",
"def get_secret_in_google_colab(env_name: str) -> str:\n",
" try:\n",
" from google.colab import userdata\n",
" return userdata.get(env_name)\n",
" except Exception:\n",
" return ''\n",
" \n",
"\n",
"def get_secret(env_name: str) -> str:\n",
" '''Gets the value from the environment(s), otherwise ask the user for it if not set'''\n",
" key = os.environ.get(env_name) or get_secret_in_google_colab(env_name)\n",
"\n",
" if not key:\n",
" key = getpass(f'Enter {env_name}:').strip()\n",
"\n",
" if key:\n",
" logger.info(f'✅ {env_name} provided')\n",
" else:\n",
" logger.warning(f'❌ {env_name} not provided')\n",
" return key.strip()\n",
"\n",
"\n",
"if api_key := get_secret('OLLAMA_API_KEY'):\n",
" clients[Provider.OLLAMA] = OpenAI(api_key=api_key, base_url='https://ollama.com/v1')\n",
"\n",
"if api_key := get_secret('OPENROUTER_API_KEY'):\n",
" clients[Provider.OPENROUTER] = OpenAI(api_key=api_key, base_url='https://openrouter.ai/api/v1')\n",
"\n",
"available_providers = [str(p) for p in clients.keys()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aae1579b-7a02-459d-81c6-0f775d2a1410",
"metadata": {},
"outputs": [],
"source": [
"selected_provider, selected_model, client = '', '', None\n",
"\n",
"\n",
"def get_desired_value_or_first_item(desire, options) -> str | None:\n",
" logger.debug(f'Pick {desire} from {options}')\n",
" selected = desire if desire in options else None\n",
" if selected:\n",
" return selected\n",
"\n",
" return options[0] if options else None\n",
" \n",
"try:\n",
" selected_provider = get_desired_value_or_first_item(DEFAULT_PROVIDER, available_providers)\n",
" client = clients.get(selected_provider)\n",
"except Exception:\n",
" logger.warning(f'❌ no provider configured and everything else from here will FAIL 🤦, I know you know this already.')\n",
"\n",
"def load_models_if_needed(client: OpenAI, selected_provider):\n",
" global selected_model, models\n",
"\n",
" if client and not models.get(selected_provider):\n",
" logging.info(f'📡 Fetching {selected_provider} models...')\n",
" \n",
" models[selected_provider] = [model.id for model in client.models.list()]\n",
" selected_model = get_desired_value_or_first_item(\n",
" selection_state[selected_provider], \n",
" models[selected_provider],\n",
" )\n",
"\n",
"load_models_if_needed(client, selected_provider)\n",
"\n",
"logger.info(f' Provider: {selected_provider} Model: {selected_model}, Client: {client}')"
]
},
{
"cell_type": "markdown",
"id": "e04675c2-1b81-4187-868c-c7112cd77e37",
"metadata": {},
"source": [
"## Prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
"metadata": {},
"outputs": [],
"source": [
"def get_messages(question: str) -> list[dict[str, str]]:\n",
" \"\"\"Generate messages for the chat models.\"\"\"\n",
"\n",
" system_prompt = r'''\n",
" You are MathXpert, an expert Mathematician who makes math fun to learn by relating concepts to real \n",
" practical usage to whip up the interest in learners.\n",
" \n",
" Explain step-by-step thoroughly how to solve a math problem. \n",
" - ALWAYS use `$$...$$` for mathematical expressions.\n",
" - NEVER use square brackets `[...]` to delimit math.\n",
" - Example: Instead of \"[x = 2]\", write \"$$x = 2$$\".\n",
" - You may use `\\\\[4pt]` inside matrices for spacing.\n",
" '''\n",
"\n",
" return [\n",
" {'role': 'system', 'content': system_prompt },\n",
" {'role': 'user', 'content': question},\n",
" ]"
]
},
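{
"cell_type": "markdown",
"id": "b7d41f2a-8c3e-4f5a-9d21-1a2b3c4d5e6f",
"metadata": {},
"source": [
"A quick, optional sanity check: inspect the message list that will be sent to the model. The question below is just an illustrative input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8e52a3b-9d4f-4a6b-8e32-2b3c4d5e6f7a",
"metadata": {},
"outputs": [],
"source": [
"get_messages('What is 2 + 2?')"
]
},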
{
"cell_type": "markdown",
"id": "caa51866-f433-4b9a-ab20-fff5fc3b7d63",
"metadata": {},
"source": [
"## Tools"
]
},
{
"cell_type": "markdown",
"id": "a24c659a-5937-43b1-bb95-c0342f2786a9",
"metadata": {},
"source": [
"### Tools Definitions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f302f47-9a67-4410-ba16-56fa5a731c66",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"from openai.types.shared_params import FunctionDefinition\n",
"import sympy as sp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import io\n",
"import base64\n",
"import random\n",
"\n",
"class ToolInput(BaseModel):\n",
" pass\n",
" \n",
"class GetCurrentDateTimeInput(ToolInput):\n",
" timezone: str = Field(default=\"UTC\", description=\"Timezone name, e.g., 'UTC' or 'Africa/Accra'\")\n",
"\n",
"\n",
"def get_current_datetime(req: GetCurrentDateTimeInput):\n",
" '''Returns the current date and time in the specified timezone.'''\n",
" from zoneinfo import ZoneInfo\n",
"\n",
" try:\n",
" from datetime import datetime\n",
" tz = ZoneInfo(req.timezone)\n",
" dt = datetime.now(tz)\n",
" return {\n",
" \"date\": dt.strftime(\"%Y-%m-%d\"),\n",
" \"time\": dt.strftime(\"%H:%M:%S %Z\"),\n",
" } \n",
" except:\n",
" return {\"error\": f\"Invalid timezone: {req.timezone}\"}\n",
"\n",
"\n",
"class GetTemperatureInput(ToolInput):\n",
" pass\n",
"\n",
"def get_temperature(req: GetTemperatureInput) -> float:\n",
" '''Returns the current temperature in degree celsius'''\n",
" return random.randint(-30, 70)\n",
"\n",
"\n",
"class PlotFunctionInput(ToolInput):\n",
" expression: str = Field(description=\"Mathematical expression to plot, e.g., 'sin(x)'\")\n",
" x_min: float = Field(default=-10, description=\"Minimum x value\")\n",
" x_max: float = Field(default=10, description=\"Maximum x value\")\n",
"\n",
"\n",
"def plot_function(req: PlotFunctionInput) -> dict[str, any]:\n",
" '''Plots a mathematical function and returns image data.'''\n",
" try:\n",
" x = sp.symbols('x')\n",
" expr = sp.sympify(req.expression)\n",
" lambdified = sp.lambdify(x, expr, 'numpy')\n",
" \n",
" x_vals = np.linspace(req.x_min, req.x_max, 400)\n",
" y_vals = lambdified(x_vals)\n",
" \n",
" plt.figure(figsize=(10, 6))\n",
" plt.plot(x_vals, y_vals, 'b-', linewidth=2)\n",
" plt.grid(True, alpha=0.3)\n",
" plt.title(f\"Plot of ${sp.latex(expr)}$\", fontsize=14)\n",
" plt.xlabel('x', fontsize=12)\n",
" plt.ylabel('f(x)', fontsize=12)\n",
" \n",
"\n",
" buf = io.BytesIO()\n",
" plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')\n",
" plt.close()\n",
" buf.seek(0)\n",
" img_str = base64.b64encode(buf.read()).decode()\n",
" \n",
" return {\n",
" \"plot_image\": f\"data:image/png;base64,{img_str}\",\n",
" \"expression\": req.expression,\n",
" \"x_range\": [req.x_min, req.x_max]\n",
" }\n",
" except Exception as e:\n",
" return {\"error\": f\"Could not plot function: {str(e)}\"}\n",
"\n"
]
},
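{
"cell_type": "markdown",
"id": "d9f63b4c-0e5a-4b7c-9f43-3c4d5e6f7a8b",
"metadata": {},
"source": [
"Optional: exercise the tools directly, outside the LLM loop, to confirm they behave as expected. The expression and range below are just illustrative inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea074c5d-1f6b-4c8d-af54-4d5e6f7a8b9c",
"metadata": {},
"outputs": [],
"source": [
"# Call each tool with its pydantic input model\n",
"print(get_current_datetime(GetCurrentDateTimeInput(timezone='UTC')))\n",
"print(get_temperature(GetTemperatureInput()))\n",
"\n",
"# plot_function returns a base64 data URI; print only the keys to avoid dumping it\n",
"result = plot_function(PlotFunctionInput(expression='sin(x)', x_min=-5, x_max=5))\n",
"print(list(result.keys()))"
]
},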
{
"cell_type": "markdown",
"id": "fae3ef71-f6cd-4894-ae55-9f4f8dd2a1cd",
"metadata": {},
"source": [
"### Tools registration & execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f18bc9f-f8d1-4208-a3d7-e4e911034572",
"metadata": {},
"outputs": [],
"source": [
"from concurrent.futures import ThreadPoolExecutor\n",
"\n",
"class ToolManager:\n",
" def __init__(self):\n",
" self._tools = []\n",
" self._tools_map: dict[str, tuple[Callable, ToolInput]] = {}\n",
"\n",
" def register_tool[T: ToolInput](self, fn: Callable, fn_input: T):\n",
" self._tools.append({\n",
" \"type\": \"function\",\n",
" \"function\": FunctionDefinition(\n",
" name=fn.__name__,\n",
" description=fn.__doc__,\n",
" parameters=fn_input.model_json_schema() if fn_input else None,\n",
" )\n",
" })\n",
" \n",
" self._tools_map[fn.__name__] = (fn, fn_input)\n",
"\n",
" def _run_single_tool(self, tool_call) -> dict[str, str] | None:\n",
" if not tool_call.id:\n",
" return None\n",
" \n",
" fn, fn_input = self._tools_map.get(tool_call.function.name)\n",
" args = tool_call.function.arguments\n",
" try:\n",
" if args:\n",
" result = fn(fn_input(**json.loads(args))) if fn_input else fn()\n",
" else:\n",
" result = fn(fn_input()) if fn_input else fn()\n",
" \n",
" logger.debug(f'Tool run result: {result}')\n",
" \n",
" return {\n",
" 'role': 'tool',\n",
" 'tool_call_id': tool_call.id,\n",
" 'content': json.dumps(result),\n",
" }\n",
" except Exception as e:\n",
" logger.error(f'Tool execution failed: {e}', extra={'name': tool_call.function.name})\n",
" return None\n",
"\n",
" def run(self, tool_calls) -> list[dict[str, str]]:\n",
" if not tool_calls:\n",
" return []\n",
"\n",
" logger.debug(tool_calls)\n",
"\n",
" tool_messages = []\n",
" \n",
" with ThreadPoolExecutor() as executor:\n",
" futures = [executor.submit(self._run_single_tool, tool_call) for tool_call in tool_calls]\n",
" \n",
" for future in futures:\n",
" result = future.result()\n",
" if result:\n",
" tool_messages.append(result)\n",
" \n",
" return tool_messages\n",
"\n",
" @property\n",
" def tools(self) -> list[any]:\n",
" return self._tools\n",
"\n",
" def dump_tools(self) -> str:\n",
" return json.dumps(self._tools, indent=True)\n",
"\n",
" \n",
"tool_manager = ToolManager()\n",
"\n",
"tool_manager.register_tool(get_current_datetime, GetCurrentDateTimeInput)\n",
"tool_manager.register_tool(get_temperature, GetTemperatureInput)\n",
"tool_manager.register_tool(plot_function, PlotFunctionInput)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b2e0634-de5d-45f6-a8d4-569e04d14a00",
"metadata": {},
"outputs": [],
"source": [
"logger.debug(tool_manager.dump_tools())"
]
},
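{
"cell_type": "markdown",
"id": "fb185d6e-2a7c-4d9e-b065-5e6f7a8b9c0d",
"metadata": {},
"source": [
"Optional: exercise `ToolManager.run` with hand-built tool calls, using `SimpleNamespace` to mimic the shape of OpenAI tool calls (`id`, `function.name`, `function.arguments`), to confirm that several tools execute in one parallel batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c296e7f-3b8d-4eaf-9176-6f7a8b9c0d1e",
"metadata": {},
"outputs": [],
"source": [
"# Hand-built tool calls mimicking the OpenAI tool-call shape\n",
"fake_calls = [\n",
"    SimpleNamespace(id='call_1', function=SimpleNamespace(name='get_current_datetime', arguments='{\"timezone\": \"UTC\"}')),\n",
"    SimpleNamespace(id='call_2', function=SimpleNamespace(name='get_temperature', arguments='')),\n",
"]\n",
"\n",
"for msg in tool_manager.run(fake_calls):\n",
"    print(msg['tool_call_id'], msg['content'])"
]
},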
{
"cell_type": "markdown",
"id": "bde4cd2a-b681-4b78-917c-d970c264b151",
"metadata": {},
"source": [
"## Interaction with LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538",
"metadata": {},
"outputs": [],
"source": [
"# handle = display(None, display_id=True)\n",
"\n",
"def ask(client: OpenAI | None, model: str, question: str, max_tool_turns=5):\n",
" if client is None:\n",
" logger.warning('You should have provided the API Keys you know. Fix 🔧 this and try again ♻️.')\n",
" return\n",
"\n",
" try:\n",
" logger.debug(f'# Tools: {len(tool_manager.tools)}')\n",
"\n",
" messages = get_messages(question=question)\n",
"\n",
" for turn in range(max_tool_turns):\n",
" logger.debug(f'Turn: {turn}')\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" tools=tool_manager.tools,\n",
" stream=True,\n",
" )\n",
" \n",
" current_message = {}\n",
" tool_calls_accumulator = {}\n",
" \n",
" output = ''\n",
" call_id = None\n",
" \n",
" for chunk in response:\n",
" delta = chunk.choices[0].delta\n",
"\n",
" logger.debug(f' ✨ {chunk.choices[0]}')\n",
" if content := delta.content:\n",
" output += content\n",
" yield output\n",
"\n",
" if tool_calls := delta.tool_calls:\n",
" for tool_chunk in tool_calls:\n",
" print('x' * 50)\n",
" print(tool_chunk)\n",
"\n",
" if tool_chunk.id and call_id != tool_chunk.id:\n",
" call_id = tool_chunk.id\n",
"\n",
" print(f'Call ID: {call_id}')\n",
" # Streams of arguments don't come with the call id\n",
" # if not call_id:\n",
" # continue\n",
"\n",
" if call_id not in tool_calls_accumulator:\n",
" # tool_calls_accumulator[call_id] = {\n",
" # 'id': call_id,\n",
" # 'function': {'name': '', 'arguments': ''}\n",
" # }\n",
" tool_calls_accumulator[call_id] = SimpleNamespace(\n",
" id=call_id,\n",
" function=SimpleNamespace(name='', arguments='')\n",
" )\n",
"\n",
" if tool_chunk.function.name:\n",
" tool_calls_accumulator[call_id].function.name += tool_chunk.function.name\n",
" \n",
" if tool_chunk.function.arguments:\n",
" tool_calls_accumulator[call_id].function.arguments += tool_chunk.function.arguments\n",
"\n",
" if finish_reason := chunk.choices[0].finish_reason:\n",
" logger.debug('🧠 LLM interaction ended. Reason: {finish_reason}')\n",
"\n",
" final_tool_calls = list(tool_calls_accumulator.values())\n",
" if final_tool_calls:\n",
" logger.debug(f'Final tools to call {final_tool_calls}')\n",
"\n",
" tool_call_message = {\n",
" 'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': json.loads(json.dumps(final_tool_calls, default=lambda o: o.__dict__))\n",
" }\n",
"\n",
" messages.append(tool_call_message)\n",
" tool_messages = tool_manager.run(final_tool_calls)\n",
"\n",
" if tool_messages:\n",
" for tool_msg in tool_messages:\n",
" try:\n",
" data = json.loads(tool_msg['content'])\n",
" if 'plot_image' in data:\n",
" logger.debug('We have a plot')\n",
" yield f'<img src=\"{data[\"plot_image\"]}\" style=\"max-width: 100%; height: auto; border: 1px solid #ccc; border-radius: 5px;\">'\n",
" return\n",
" except:\n",
" pass\n",
" messages.extend(tool_messages)\n",
" else:\n",
" return\n",
" \n",
" except Exception as e:\n",
" logger.error(f'🔥 An error occurred during the interaction with the LLM: {e}', exc_info=True)\n",
" return str(e)"
]
},
{
"cell_type": "markdown",
"id": "eda786d3-5add-4bd1-804d-13eff60c3d1a",
"metadata": {},
"source": [
"### Verify streaming behaviour"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09bc9a11-adb4-4a9c-9c77-73b2b5a665cf",
"metadata": {},
"outputs": [],
"source": [
"# print(selected_provider, selected_model)\n",
"# print(client)\n",
"# for o in ask(client, selected_model, 'What is the time?'):\n",
"# for o in ask(client, selected_model, 'What is the temperature?'):\n",
"# for o in ask(client, selected_model, 'What is the time and the temperature?'):\n",
"# for o in ask(client, selected_model, 'Plot a for the expression sin(x)'):\n",
"for o in ask(client, selected_model, 'Plot a graph of y = x**2'):\n",
" print(o)"
]
},
{
"cell_type": "markdown",
"id": "27230463",
"metadata": {},
"source": [
"## Build Gradio UI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50fc3577",
"metadata": {},
"outputs": [],
"source": [
"def chat(message: str, history: list[dict], selected_provider: str, model_selector: str):\n",
" # NOTE: I'm not interesting in maintaining a conversation\n",
" response = ask(client, selected_model, message)\n",
"\n",
" for chunk in response:\n",
" yield chunk\n",
"\n",
"def on_provider_change(change):\n",
" global selected_provider, client, models\n",
" logger.info(f'Provider changed to {change}')\n",
" selected_provider = change\n",
" client = clients.get(selected_provider)\n",
" load_models_if_needed(client, selected_provider)\n",
"\n",
" return gr.Dropdown(\n",
" choices=models.get(selected_provider, []),\n",
" value=selection_state[selected_provider],\n",
" interactive=True,\n",
" )\n",
"\n",
"\n",
"def on_model_change(change):\n",
" global selected_provider, selected_model, selection_state\n",
"\n",
" selected_model = change\n",
" selection_state[selected_provider] = selected_model\n",
" logger.info(f'👉 Selected model: {selected_model}')\n",
"\n",
"\n",
"with gr.Blocks(title='MathXpert', fill_width=True, \n",
" \n",
" ) as ui:\n",
" def get_value_if_exist(v, ls) -> str:\n",
" print(ls)\n",
" selected = v if v in ls else None\n",
" if selected:\n",
" return selected\n",
"\n",
" return ls[0] if ls else None\n",
"\n",
" with gr.Row():\n",
" provider_selector = gr.Dropdown(\n",
" choices=available_providers, \n",
" value=get_desired_value_or_first_item(selected_provider, available_providers),\n",
" label='Provider',\n",
" )\n",
" model_selector = gr.Dropdown(\n",
" choices=models[selected_provider],\n",
" value=get_desired_value_or_first_item(selection_state[selected_provider], models[selected_provider]),\n",
" label='Model',\n",
" )\n",
" \n",
" provider_selector.change(fn=on_provider_change, inputs=provider_selector, outputs=model_selector)\n",
" model_selector.change(fn=on_model_change, inputs=model_selector)\n",
"\n",
" examples = [\n",
" ['Where can substitutions be applied in real life?', None, None],\n",
" ['Give 1 differential equation question and solve it', None, None],\n",
" ['Plot x**2 - 3x', None, None],\n",
" ['What is the time now?', None, None],\n",
" ['What is the temperature?', None, None],\n",
" ['Tell me the time and the temperature now', None, None],\n",
" ]\n",
"\n",
" \n",
" gr.ChatInterface(\n",
" fn=chat, \n",
" type='messages', \n",
" chatbot=gr.Chatbot(type='messages', height='75vh', resizable=True),\n",
" additional_inputs=[provider_selector, model_selector],\n",
" examples=examples,\n",
" )\n",
"\n",
"ui.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}