diff --git a/week2/community-contributions/ranskills-week2-mathxpert-with-tools.ipynb b/week2/community-contributions/ranskills-week2-mathxpert-with-tools.ipynb new file mode 100644 index 0000000..3891b8e --- /dev/null +++ b/week2/community-contributions/ranskills-week2-mathxpert-with-tools.ipynb @@ -0,0 +1,657 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5", + "metadata": {}, + "source": [ + "# Week 2 exercise\n", + "\n", + "## MathXpert with tools integration\n", + "\n", + "- Provides the freedom to explore all the models available from the providers\n", + "- Handling of multiple tools calling simultaneously\n", + "- Efficiently run tools in parallel\n", + "- Tool response, i.e. the `plot_function`, that does not require going back to the LLM\n", + "- Uses the inbuilt logging package to allow the control of the verbosity of the logging, set to a higher level, like INFO, to reduce the noisy output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1070317-3ed9-4659-abe3-828943230e03", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import logging\n", + "from enum import StrEnum\n", + "from getpass import getpass\n", + "from types import SimpleNamespace\n", + "from typing import Callable\n", + "\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import ipywidgets as widgets\n", + "from IPython.display import display, clear_output, Latex\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99901b80", + "metadata": {}, + "outputs": [], + "source": [ + "logging.basicConfig(level=logging.WARNING)\n", + "\n", + "logger = logging.getLogger('mathxpert')\n", + "logger.setLevel(logging.DEBUG)" + ] + }, + { + "cell_type": "markdown", + "id": "f169118a-645e-44e1-9a98-4f561adfbb08", + "metadata": {}, + "source": [ + "## Free Cloud Providers\n", + "\n", 
# ## Free Cloud Providers
#
# Grab your free API keys from these generous sites:
#   - https://openrouter.ai/
#   - https://ollama.com/
#
# NOTE: If you do not have a key for any provider, simply press ENTER to move on.

class Provider(StrEnum):
    """Supported LLM providers; the value doubles as the display name."""
    OLLAMA = 'Ollama'
    OPENROUTER = 'OpenRouter'

# One OpenAI-compatible client per provider that supplied a key.
clients: dict[Provider, OpenAI] = {}
# Model ids per provider, fetched lazily on first use.
models: dict[Provider, list[str]] = {
    Provider.OLLAMA: [],
    Provider.OPENROUTER: [],
}

DEFAULT_PROVIDER = Provider.OLLAMA

# Last model picked for each provider (these are the initial preferences).
selection_state: dict[Provider, str | None] = {
    Provider.OLLAMA: 'gpt-oss:20b',
    Provider.OPENROUTER: 'openai/gpt-oss-20b:free',
}

def get_secret_in_google_colab(env_name: str) -> str:
    """Return the secret from Google Colab's userdata store, or '' outside Colab.

    FIX: `userdata.get` can return None; coerce to '' so the declared
    `-> str` contract holds and callers can safely `.strip()`.
    """
    try:
        from google.colab import userdata
        return userdata.get(env_name) or ''
    except Exception:
        return ''


def get_secret(env_name: str) -> str:
    '''Gets the value from the environment(s), otherwise ask the user for it if not set'''
    key = os.environ.get(env_name) or get_secret_in_google_colab(env_name)

    if not key:
        key = getpass(f'Enter {env_name}:').strip()

    # Lazy %-style args: the message is only formatted if the level is enabled.
    if key:
        logger.info('✅ %s provided', env_name)
    else:
        logger.warning('❌ %s not provided', env_name)
    return key.strip()


if api_key := get_secret('OLLAMA_API_KEY'):
    clients[Provider.OLLAMA] = OpenAI(api_key=api_key, base_url='https://ollama.com/v1')

if api_key := get_secret('OPENROUTER_API_KEY'):
    clients[Provider.OPENROUTER] = OpenAI(api_key=api_key, base_url='https://openrouter.ai/api/v1')

# Iterating a dict yields its keys; `.keys()` was redundant.
available_providers = [str(p) for p in clients]
selected_provider, selected_model, client = '', '', None


def get_desired_value_or_first_item(desire, options) -> str | None:
    """Return `desire` if it is in `options`, else the first option, else None."""
    logger.debug('Pick %s from %s', desire, options)
    if desire in options:
        return desire
    return options[0] if options else None


try:
    selected_provider = get_desired_value_or_first_item(DEFAULT_PROVIDER, available_providers)
    client = clients.get(selected_provider)
except Exception:
    logger.warning('❌ no provider configured and everything else from here will FAIL đŸ¤Ļ, I know you know this already.')


def load_models_if_needed(client: OpenAI, selected_provider):
    """Fetch and cache the provider's model list once, then pick the preferred model.

    Mutates the module-level `selected_model` and `models` cache.
    """
    global selected_model, models

    if client and not models.get(selected_provider):
        # BUG FIX: this called `logging.info(...)` on the root logger (level
        # WARNING, so the message was silently dropped) instead of the
        # module's `logger` used everywhere else.
        logger.info('📡 Fetching %s models...', selected_provider)

        models[selected_provider] = [model.id for model in client.models.list()]
        selected_model = get_desired_value_or_first_item(
            selection_state[selected_provider],
            models[selected_provider],
        )


load_models_if_needed(client, selected_provider)

logger.info(f'â„šī¸ Provider: {selected_provider} Model: {selected_model}, Client: {client}')

# ## Prompt

def get_messages(question: str) -> list[dict[str, str]]:
    """Generate messages for the chat models."""

    system_prompt = r'''
    You are MathXpert, an expert Mathematician who makes math fun to learn by relating concepts to real 
    practical usage to whip up the interest in learners.
    
    Explain step-by-step thoroughly how to solve a math problem. 
    - ALWAYS use `$$...$$` for mathematical expressions.
    - NEVER use square brackets `[...]` to delimit math.
    - Example: Instead of "[x = 2]", write "$$x = 2$$".
    - You may use `\\[4pt]` inside matrices for spacing.
    '''

    return [
        {'role': 'system', 'content': system_prompt },
        {'role': 'user', 'content': question},
    ]

# ## Tools — definitions

from pydantic import BaseModel, Field
from openai.types.shared_params import FunctionDefinition
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
import io
import base64
import random

class ToolInput(BaseModel):
    # Marker base class for all tool input schemas.
    pass

class GetCurrentDateTimeInput(ToolInput):
    timezone: str = Field(default="UTC", description="Timezone name, e.g., 'UTC' or 'Africa/Accra'")


def get_current_datetime(req: GetCurrentDateTimeInput):
    '''Returns the current date and time in the specified timezone.'''
    from zoneinfo import ZoneInfo
    from datetime import datetime

    try:
        tz = ZoneInfo(req.timezone)
        dt = datetime.now(tz)
        return {
            "date": dt.strftime("%Y-%m-%d"),
            "time": dt.strftime("%H:%M:%S %Z"),
        }
    except Exception:
        # FIX: was a bare `except:`, which also swallows KeyboardInterrupt
        # and SystemExit; only ordinary errors mean "bad timezone".
        return {"error": f"Invalid timezone: {req.timezone}"}
class GetTemperatureInput(ToolInput):
    # No parameters: the temperature tool takes no input.
    pass

def get_temperature(req: GetTemperatureInput) -> float:
    '''Returns the current temperature in degree celsius'''
    # Stub sensor. FIX: cast so the return actually matches the `-> float`
    # annotation (random.randint returns an int).
    return float(random.randint(-30, 70))


class PlotFunctionInput(ToolInput):
    expression: str = Field(description="Mathematical expression to plot, e.g., 'sin(x)'")
    x_min: float = Field(default=-10, description="Minimum x value")
    x_max: float = Field(default=10, description="Maximum x value")


def plot_function(req: PlotFunctionInput) -> dict[str, object]:
    '''Plots a mathematical function and returns image data.'''
    # Return annotation FIX: `dict[str, any]` used the builtin function
    # `any`, which is not a type; `object` is the honest value type here
    # (strings, floats, lists).
    try:
        x = sp.symbols('x')
        expr = sp.sympify(req.expression)
        lambdified = sp.lambdify(x, expr, 'numpy')

        x_vals = np.linspace(req.x_min, req.x_max, 400)
        y_vals = lambdified(x_vals)

        plt.figure(figsize=(10, 6))
        plt.plot(x_vals, y_vals, 'b-', linewidth=2)
        plt.grid(True, alpha=0.3)
        plt.title(f"Plot of ${sp.latex(expr)}$", fontsize=14)
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)

        # Render to an in-memory PNG and embed it as a data URI so the
        # result can travel inside a JSON tool message.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        plt.close()
        buf.seek(0)
        img_str = base64.b64encode(buf.read()).decode()

        return {
            "plot_image": f"data:image/png;base64,{img_str}",
            "expression": req.expression,
            "x_range": [req.x_min, req.x_max]
        }
    except Exception as e:
        return {"error": f"Could not plot function: {str(e)}"}
[], + "source": [ + "from concurrent.futures import ThreadPoolExecutor\n", + "\n", + "class ToolManager:\n", + " def __init__(self):\n", + " self._tools = []\n", + " self._tools_map: dict[str, tuple[Callable, ToolInput]] = {}\n", + "\n", + " def register_tool[T: ToolInput](self, fn: Callable, fn_input: T):\n", + " self._tools.append({\n", + " \"type\": \"function\",\n", + " \"function\": FunctionDefinition(\n", + " name=fn.__name__,\n", + " description=fn.__doc__,\n", + " parameters=fn_input.model_json_schema() if fn_input else None,\n", + " )\n", + " })\n", + " \n", + " self._tools_map[fn.__name__] = (fn, fn_input)\n", + "\n", + " def _run_single_tool(self, tool_call) -> dict[str, str] | None:\n", + " if not tool_call.id:\n", + " return None\n", + " \n", + " fn, fn_input = self._tools_map.get(tool_call.function.name)\n", + " args = tool_call.function.arguments\n", + " try:\n", + " if args:\n", + " result = fn(fn_input(**json.loads(args))) if fn_input else fn()\n", + " else:\n", + " result = fn(fn_input()) if fn_input else fn()\n", + " \n", + " logger.debug(f'Tool run result: {result}')\n", + " \n", + " return {\n", + " 'role': 'tool',\n", + " 'tool_call_id': tool_call.id,\n", + " 'content': json.dumps(result),\n", + " }\n", + " except Exception as e:\n", + " logger.error(f'Tool execution failed: {e}', extra={'name': tool_call.function.name})\n", + " return None\n", + "\n", + " def run(self, tool_calls) -> list[dict[str, str]]:\n", + " if not tool_calls:\n", + " return []\n", + "\n", + " logger.debug(tool_calls)\n", + "\n", + " tool_messages = []\n", + " \n", + " with ThreadPoolExecutor() as executor:\n", + " futures = [executor.submit(self._run_single_tool, tool_call) for tool_call in tool_calls]\n", + " \n", + " for future in futures:\n", + " result = future.result()\n", + " if result:\n", + " tool_messages.append(result)\n", + " \n", + " return tool_messages\n", + "\n", + " @property\n", + " def tools(self) -> list[any]:\n", + " return self._tools\n", + "\n", + 
" def dump_tools(self) -> str:\n", + " return json.dumps(self._tools, indent=True)\n", + "\n", + " \n", + "tool_manager = ToolManager()\n", + "\n", + "tool_manager.register_tool(get_current_datetime, GetCurrentDateTimeInput)\n", + "tool_manager.register_tool(get_temperature, GetTemperatureInput)\n", + "tool_manager.register_tool(plot_function, PlotFunctionInput)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b2e0634-de5d-45f6-a8d4-569e04d14a00", + "metadata": {}, + "outputs": [], + "source": [ + "logger.debug(tool_manager.dump_tools())" + ] + }, + { + "cell_type": "markdown", + "id": "bde4cd2a-b681-4b78-917c-d970c264b151", + "metadata": {}, + "source": [ + "## Interaction with LLM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538", + "metadata": {}, + "outputs": [], + "source": [ + "# handle = display(None, display_id=True)\n", + "\n", + "def ask(client: OpenAI | None, model: str, question: str, max_tool_turns=5):\n", + " if client is None:\n", + " logger.warning('You should have provided the API Keys you know. 
Fix 🔧 this and try again â™ģī¸.')\n", + " return\n", + "\n", + " try:\n", + " logger.debug(f'# Tools: {len(tool_manager.tools)}')\n", + "\n", + " messages = get_messages(question=question)\n", + "\n", + " for turn in range(max_tool_turns):\n", + " logger.debug(f'Turn: {turn}')\n", + " response = client.chat.completions.create(\n", + " model=model,\n", + " messages=messages,\n", + " tools=tool_manager.tools,\n", + " stream=True,\n", + " )\n", + " \n", + " current_message = {}\n", + " tool_calls_accumulator = {}\n", + " \n", + " output = ''\n", + " call_id = None\n", + " \n", + " for chunk in response:\n", + " delta = chunk.choices[0].delta\n", + "\n", + " logger.debug(f' ✨ {chunk.choices[0]}')\n", + " if content := delta.content:\n", + " output += content\n", + " yield output\n", + "\n", + " if tool_calls := delta.tool_calls:\n", + " for tool_chunk in tool_calls:\n", + " print('x' * 50)\n", + " print(tool_chunk)\n", + "\n", + " if tool_chunk.id and call_id != tool_chunk.id:\n", + " call_id = tool_chunk.id\n", + "\n", + " print(f'Call ID: {call_id}')\n", + " # Streams of arguments don't come with the call id\n", + " # if not call_id:\n", + " # continue\n", + "\n", + " if call_id not in tool_calls_accumulator:\n", + " # tool_calls_accumulator[call_id] = {\n", + " # 'id': call_id,\n", + " # 'function': {'name': '', 'arguments': ''}\n", + " # }\n", + " tool_calls_accumulator[call_id] = SimpleNamespace(\n", + " id=call_id,\n", + " function=SimpleNamespace(name='', arguments='')\n", + " )\n", + "\n", + " if tool_chunk.function.name:\n", + " tool_calls_accumulator[call_id].function.name += tool_chunk.function.name\n", + " \n", + " if tool_chunk.function.arguments:\n", + " tool_calls_accumulator[call_id].function.arguments += tool_chunk.function.arguments\n", + "\n", + " if finish_reason := chunk.choices[0].finish_reason:\n", + " logger.debug('🧠 LLM interaction ended. 
Reason: {finish_reason}')\n", + "\n", + " final_tool_calls = list(tool_calls_accumulator.values())\n", + " if final_tool_calls:\n", + " logger.debug(f'Final tools to call {final_tool_calls}')\n", + "\n", + " tool_call_message = {\n", + " 'role': 'assistant',\n", + " 'content': None,\n", + " 'tool_calls': json.loads(json.dumps(final_tool_calls, default=lambda o: o.__dict__))\n", + " }\n", + "\n", + " messages.append(tool_call_message)\n", + " tool_messages = tool_manager.run(final_tool_calls)\n", + "\n", + " if tool_messages:\n", + " for tool_msg in tool_messages:\n", + " try:\n", + " data = json.loads(tool_msg['content'])\n", + " if 'plot_image' in data:\n", + " logger.debug('We have a plot')\n", + " yield f''\n", + " return\n", + " except:\n", + " pass\n", + " messages.extend(tool_messages)\n", + " else:\n", + " return\n", + " \n", + " except Exception as e:\n", + " logger.error(f'đŸ”Ĩ An error occurred during the interaction with the LLM: {e}', exc_info=True)\n", + " return str(e)" + ] + }, + { + "cell_type": "markdown", + "id": "eda786d3-5add-4bd1-804d-13eff60c3d1a", + "metadata": {}, + "source": [ + "### Verify streaming behaviour" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09bc9a11-adb4-4a9c-9c77-73b2b5a665cf", + "metadata": {}, + "outputs": [], + "source": [ + "# print(selected_provider, selected_model)\n", + "# print(client)\n", + "# for o in ask(client, selected_model, 'What is the time?'):\n", + "# for o in ask(client, selected_model, 'What is the temperature?'):\n", + "# for o in ask(client, selected_model, 'What is the time and the temperature?'):\n", + "# for o in ask(client, selected_model, 'Plot a for the expression sin(x)'):\n", + "for o in ask(client, selected_model, 'Plot a graph of y = x**2'):\n", + " print(o)" + ] + }, + { + "cell_type": "markdown", + "id": "27230463", + "metadata": {}, + "source": [ + "## Build Gradio UI" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50fc3577", + "metadata": 
# ## Build Gradio UI

def chat(message: str, history: list[dict], selected_provider: str, model_selector: str):
    """Gradio chat handler: stream the answer for a single message.

    NOTE: deliberately stateless — I'm not interested in maintaining a
    conversation, so `history` is ignored.
    """
    for chunk in ask(client, selected_model, message):
        yield chunk


def on_provider_change(change):
    """Switch the active provider and refresh the model dropdown choices."""
    global selected_provider, client, models
    logger.info(f'Provider changed to {change}')
    selected_provider = change
    client = clients.get(selected_provider)
    load_models_if_needed(client, selected_provider)

    return gr.Dropdown(
        choices=models.get(selected_provider, []),
        value=selection_state[selected_provider],
        interactive=True,
    )


def on_model_change(change):
    """Remember the chosen model for the current provider."""
    global selected_provider, selected_model, selection_state

    selected_model = change
    selection_state[selected_provider] = selected_model
    logger.info(f'👉 Selected model: {selected_model}')


with gr.Blocks(title='MathXpert', fill_width=True) as ui:
    # FIX: removed the dead inner helper `get_value_if_exist` — an unused
    # duplicate of get_desired_value_or_first_item with a stray print().

    with gr.Row():
        provider_selector = gr.Dropdown(
            choices=available_providers,
            value=get_desired_value_or_first_item(selected_provider, available_providers),
            label='Provider',
        )
        model_selector = gr.Dropdown(
            choices=models[selected_provider],
            value=get_desired_value_or_first_item(selection_state[selected_provider], models[selected_provider]),
            label='Model',
        )

    provider_selector.change(fn=on_provider_change, inputs=provider_selector, outputs=model_selector)
    model_selector.change(fn=on_model_change, inputs=model_selector)

    examples = [
        ['Where can substitutions be applied in real life?', None, None],
        ['Give 1 differential equation question and solve it', None, None],
        ['Plot x**2 - 3x', None, None],
        ['What is the time now?', None, None],
        ['What is the temperature?', None, None],
        ['Tell me the time and the temperature now', None, None],
    ]

    gr.ChatInterface(
        fn=chat,
        type='messages',
        chatbot=gr.Chatbot(type='messages', height='75vh', resizable=True),
        additional_inputs=[provider_selector, model_selector],
        examples=examples,
    )

ui.launch()