{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5",
   "metadata": {},
   "source": [
    "# End of week 1 exercise\n",
    "\n",
    "## Dynamically pick an LLM provider to let MathXpert answer your math questions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1070317-3ed9-4659-abe3-828943230e03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from enum import StrEnum\n",
    "from getpass import getpass\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, Markdown, Latex\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f169118a-645e-44e1-9a98-4f561adfbb08",
   "metadata": {},
   "source": [
    "## Free Cloud Providers\n",
    "\n",
    "Grab your free API Keys from these generous sites:\n",
    "\n",
    "- https://openrouter.ai/\n",
    "- https://ollama.com/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Provider(StrEnum):\n",
    "    OLLAMA = 'Ollama'\n",
    "    OPENROUTER = 'OpenRouter'\n",
    "\n",
    "# Model requested from each provider.\n",
    "models: dict[Provider, str] = {\n",
    "    Provider.OLLAMA: 'gpt-oss:120b-cloud',\n",
    "    Provider.OPENROUTER: 'qwen/qwen3-4b:free'\n",
    "}\n",
    "\n",
    "def get_api_key(env_name: str) -> str:\n",
    "    '''Return the API key stored in the environment variable `env_name`.\n",
    "\n",
    "    If the variable is unset or empty, the user is prompted for the key with\n",
    "    `getpass` (so it is not echoed). A status line is printed either way.\n",
    "\n",
    "    Returns:\n",
    "        The key, or an empty string when the user entered nothing.\n",
    "    '''\n",
    "    key = os.environ.get(env_name)\n",
    "    if not key:\n",
    "        key = getpass(f'Enter {env_name}:').strip()\n",
    "\n",
    "    if key:\n",
    "        print(f'✅ {env_name} provided')\n",
    "    else:\n",
    "        print(f'❌ {env_name} not provided')\n",
    "    return key\n",
    "\n",
    "\n",
    "# Only providers for which a key was supplied get a client.\n",
    "providers: dict[Provider, OpenAI] = {}\n",
    "\n",
    "if api_key := get_api_key('OLLAMA_API_KEY'):\n",
    "    providers[Provider.OLLAMA] = OpenAI(base_url='https://ollama.com/v1', api_key=api_key)\n",
    "\n",
    "if api_key := get_api_key('OPENROUTER_API_KEY'):\n",
    "    providers[Provider.OPENROUTER] = OpenAI(base_url='https://openrouter.ai/api/v1', api_key=api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_messages(question: str) -> list[dict[str, str]]:\n",
    "    \"\"\"Build the chat messages (system prompt + user question) sent to the model.\"\"\"\n",
    "\n",
    "    system_prompt = '''\n",
    "    You are MathXpert, an expert Mathematician who makes math fun to learn by relating concepts to real \n",
    "    practical usage to whip up the interest in learners.\n",
    "\n",
    "    Explain step-by-step thoroughly how to solve a math problem. Respond in **LaTeX**\n",
    "    '''\n",
    "\n",
    "    return [\n",
    "        {'role': 'system', 'content': system_prompt},\n",
    "        {'role': 'user', 'content': question},\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef72c85e",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_messages('Explain how to solve a differentiation problem')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae1579b-7a02-459d-81c6-0f775d2a1410",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail with a clear message instead of a bare StopIteration when no key was supplied.\n",
    "if not providers:\n",
    "    raise RuntimeError(\n",
    "        'No LLM provider configured. Provide OLLAMA_API_KEY or OPENROUTER_API_KEY and re-run the cells above.'\n",
    "    )\n",
    "\n",
    "# Start with the first configured provider; the dropdown callback keeps both names in sync.\n",
    "selected_provider, client = next(iter(providers.items()))\n",
    "\n",
    "def on_provider_change(change):\n",
    "    \"\"\"Switch the active provider and client to the dropdown's new value.\"\"\"\n",
    "    global selected_provider, client\n",
    "\n",
    "    selected_provider = change['new']\n",
    "    client = providers.get(selected_provider)\n",
    "\n",
    "\n",
    "provider_selector = widgets.Dropdown(\n",
    "    options=list(providers.keys()),\n",
    "    description='Select an LLM provider:',\n",
    "    style={'description_width': 'initial'},\n",
    ")\n",
    "\n",
    "provider_selector.observe(on_provider_change, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display handle created once; `ask` updates it in place as chunks arrive.\n",
    "handle = display(None, display_id=True)\n",
    "\n",
    "def ask(client: OpenAI, model: str, question: str) -> None:\n",
    "    \"\"\"Stream the model's answer to `question` into the shared display handle.\n",
    "\n",
    "    Args:\n",
    "        client: OpenAI-compatible client for the selected provider.\n",
    "        model: Model name to request from that provider.\n",
    "        question: The user's question.\n",
    "\n",
    "    Any exception raised while building or streaming the response is caught\n",
    "    and printed rather than propagated.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        prompt = get_messages(question=question)\n",
    "        response = client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=prompt,\n",
    "            stream=True,\n",
    "        )\n",
    "\n",
    "        output = ''\n",
    "        for chunk in response:\n",
    "            # A streamed chunk may carry an empty `choices` list; indexing it would raise IndexError.\n",
    "            if not chunk.choices:\n",
    "                continue\n",
    "            output += chunk.choices[0].delta.content or ''\n",
    "\n",
    "            handle.update(Latex(output))\n",
    "    except Exception as e:\n",
    "        clear_output(wait=True)\n",
    "        print(f'🔥 An error occurred: {e}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09bc9a11-adb4-4a9c-9c77-73b2b5a665cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(provider_selector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e01069b2-fd2c-446f-b385-09c1d9624225",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_label = \"Ask your question (Type 'q' to quit): \"\n",
    "question = input(input_label)\n",
    "\n",
    "# Keep answering until the user types 'q' or 'quit'.\n",
    "while question.strip().lower() not in ['quit', 'q']:\n",
    "    clear_output(wait=True)\n",
    "    model = models[selected_provider]\n",
    "    print(selected_provider, model)\n",
    "    ask(client, model, question)\n",
    "\n",
    "    question = input(input_label)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}