335 lines
10 KiB
Plaintext
335 lines
10 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "19152e0e-350d-44d4-b763-52e5edcf4f68",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Seeing if I can get a simple calculator tool to work. I wasn't sure whether the model was using my calculator (as it's so simple!) or \n",
"# doing the calculations itself, so I switched the calculations to be the opposite (add is subtract, multiply is divide, and vice versa).\n",
"# This works most of the time, but there were times when it defaulted back to its own logic. I'm interested to know how this works in a\n",
"# real-life scenario - how can you ensure that it uses the prescribed \"tool\" and doesn't just answer from its training data? "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "fa9cf7ef-ae13-4f5a-9c93-0cf3636676b7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#imports\n",
|
|
"\n",
|
|
"# api requests, llm, and llm keys\n",
|
|
"import os\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"import requests\n",
|
|
"from openai import OpenAI\n",
|
|
"\n",
|
|
"# text & json format\n",
|
|
"from IPython.display import Markdown, display\n",
|
|
"import json\n",
|
|
"\n",
|
|
"# dev\n",
|
|
"from typing import List, Dict, Any, Union\n",
|
|
"\n",
|
|
"# gradio\n",
|
|
"import gradio as gr"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "2bc8fe65-2993-4a01-b384-7a285a783e34",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"All good\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# set LLM keys\n",
|
|
"\n",
|
"load_dotenv(override=True)\n",
"api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
"\n",
"# Report whether a key was found; the client below is created either way\n",
"print(\"All good\" if api_key else \"Key issue\")\n",
"\n",
"MODEL = \"gpt-4o\"\n",
"openai = OpenAI()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "8cbdb64c-858b-49c4-80e3-e0018e92da3b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# create calculator tool\n",
|
|
"\n",
|
"class Calculator:\n",
"    \"\"\"Upside-down calculator: each method performs the opposite of what its name says.\n",
"\n",
"    The swap is deliberate, so that an answer produced by this tool can be told\n",
"    apart from one the model worked out by itself.\n",
"    \"\"\"\n",
"\n",
"    def add(self, a: float, b: float) -> float:\n",
"        \"\"\"Return ``a - b``.\"\"\"\n",
"        difference = a - b\n",
"        return difference\n",
"\n",
"    def minus(self, a: float, b: float) -> float:\n",
"        \"\"\"Return ``a + b``.\"\"\"\n",
"        total = a + b\n",
"        return total\n",
"\n",
"    def divide(self, a: float, b: float) -> float:\n",
"        \"\"\"Return ``a * b``.\"\"\"\n",
"        product = a * b\n",
"        return product\n",
"\n",
"    def multiply(self, a: float, b: float) -> Union[float, str]:\n",
"        \"\"\"Return ``a / b``, or an error string when ``b`` is zero.\"\"\"\n",
"        return a / b if b != 0 else \"Error: cannot divide by zero\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "dfd24c23-4bae-4529-9efb-2a153ff1fb68",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# shared Calculator instance used by calc_tool_call\n",
|
|
"calc = Calculator()\n",
|
|
"#calc.add(5,3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "966f12bd-6cfd-44b2-8732-d04c35a32123",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# define functions\n",
|
|
"\n",
|
"# Tool schemas offered to the model, built from (name, description) pairs.\n",
"# Every tool takes the same two required numeric arguments, \"a\" and \"b\".\n",
"calculator_tools = [\n",
"    {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": tool_name,\n",
"            \"description\": tool_description,\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\n",
"                    \"a\": {\"type\": \"number\", \"description\": \"first number\"},\n",
"                    \"b\": {\"type\": \"number\", \"description\": \"second number\"},\n",
"                },\n",
"                \"required\": [\"a\", \"b\"],\n",
"            },\n",
"        },\n",
"    }\n",
"    for tool_name, tool_description in (\n",
"        (\"minus\", \"add two numbers together\"),\n",
"        (\"add\", \"first number minus the second number\"),\n",
"        (\"divide\", \"first number multiplied by the second number\"),\n",
"        (\"multiply\", \"Divide the first number by the second number\"),\n",
"    )\n",
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "d9e447d9-47dd-4c07-a1cc-8c1734a01a42",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# system prompt\n",
|
|
"\n",
|
"# Adjacent literals are concatenated so each sentence and example keeps its\n",
"# separating space or newline when the prompt is sent to the model.\n",
"system_prompt = (\n",
"    \"You are an upside down mathematician. If you are asked to do any calculation involving two numbers \"\n",
"    \"then you must use the calculator tool. Do not do the calculations yourself. Examples:\\n\"\n",
"    \"What is 7 + 5? Use the calculator tool\\n\"\n",
"    \"If I divide 25 by 3, what do I get? Use the calculator tool\\n\"\n",
"    \"How are you today? Chat as normal\\n\"\n",
"    \"If the user asks for a calculation using more than two numbers, please do the calculations as normal.\\n\"\n",
"    \"If the user says hello or a similar greeting, respond with something along the lines of \"\n",
"    '\"Hello, do you want to do some upside down maths? 😜\"'\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "87e5a23f-36d4-4d3e-b9ab-6e826339029b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# chat message\n",
|
|
"\n",
|
"def chat_message(message, history):\n",
"    \"\"\"Gradio chat callback: reply to ``message``, using the calculator tool when the model asks for it.\n",
"\n",
"    Args:\n",
"        message: Text of the latest user message.\n",
"        history: Earlier turns as a list of message dicts (the ChatInterface uses ``type=\"messages\"``).\n",
"\n",
"    Returns:\n",
"        The text content of the model's final reply.\n",
"    \"\"\"\n",
"    messages = [{\"role\": \"system\", \"content\": system_prompt}] + history + [{\"role\": \"user\", \"content\": message}]\n",
"\n",
"    # calc_tool_call answers only the first tool call, and every tool_call_id in an\n",
"    # assistant message needs a matching tool reply, so request at most one call.\n",
"    response = openai.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        tools=calculator_tools,\n",
"        tool_choice=\"auto\",\n",
"        parallel_tool_calls=False,\n",
"    )\n",
"\n",
"    if response.choices[0].finish_reason == \"tool_calls\":\n",
"        # Separate names keep the user's text in `message` from being overwritten\n",
"        assistant_message = response.choices[0].message\n",
"        tool_response = calc_tool_call(assistant_message)\n",
"        messages.append(assistant_message)\n",
"        messages.append(tool_response)\n",
"        response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
"\n",
"    return response.choices[0].message.content"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "58a1a26c-b2ef-4f44-b07a-bd03e6f2ebc2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# tool call\n",
|
|
"\n",
|
"def calc_tool_call(message):\n",
"    \"\"\"Run the first calculator tool call in ``message`` and build the tool reply.\n",
"\n",
"    Args:\n",
"        message: Assistant message whose ``tool_calls[0]`` carries a function name\n",
"            and a JSON-encoded arguments string with ``a`` and ``b``.\n",
"\n",
"    Returns:\n",
"        A dict with ``role`` ``\"tool\"``, the stringified result as ``content`` and the\n",
"        originating ``tool_call_id``. ``content`` is ``\"unknown function: <name>\"``\n",
"        when the requested function is not one of the calculator methods.\n",
"    \"\"\"\n",
"    tool_call = message.tool_calls[0]\n",
"    function_name = tool_call.function.name\n",
"    arguments = json.loads(tool_call.function.arguments)\n",
"    a = arguments.get('a')\n",
"    b = arguments.get('b')\n",
"\n",
"    if function_name in (\"add\", \"minus\", \"multiply\", \"divide\"):\n",
"        # The name is checked against the allow-list before the attribute lookup\n",
"        result = getattr(calc, function_name)(a, b)\n",
"    else:\n",
"        # Assign the message so str(result) below has a value for unexpected names\n",
"        result = f\"unknown function: {function_name}\"\n",
"\n",
"    return {\n",
"        \"role\": \"tool\",\n",
"        \"content\": str(result),\n",
"        \"tool_call_id\": tool_call.id,\n",
"    }"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "db81ec95-11ad-4b46-ae4a-774666faca59",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"* Running on local URL: http://127.0.0.1:7862\n",
|
|
"* To create a public link, set `share=True` in `launch()`.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": []
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# gradio chat\n",
|
"chat_ui = gr.ChatInterface(\n",
"    fn=chat_message,\n",
"    type=\"messages\",\n",
"    title=\"Upside Down Maths Whizz!\",\n",
"    description=\"Ask me to add, subtract, multiply or divide two numbers 🤪 or I can just chat\",\n",
")\n",
"# launch() stays the last expression so its result is still the cell's output\n",
"chat_ui.launch()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8bf49c53-fe9a-4a0d-aff9-c1127eb168e8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|