{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "989184c3-676b-4a68-8841-387ba0776e1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "import ollama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0ac9605-d28a-4c19-97e3-1dd3f9ac99ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt for the FlightAI assistant. Every fragment ends with a space so the\n",
    "# concatenated sentences do not run together (e.g. \"politely.Always\").\n",
    "system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
    "system_message += \"Give short, courteous answers, no more than 1 sentence. Respond to greetings and general conversation politely. \"\n",
    "system_message += \"Always be accurate. If you don't know the answer, say so. \"\n",
    "system_message += \"When a user asks for information that requires external data or action, use the available tools to get that information. Specifically, use the get_ticket_price tool to answer questions about ticket prices.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "533e6edf-454a-493d-b0a7-dbc29a5f3930",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_without_tools(message, history):\n",
    "    \"\"\"Baseline chat handler: answer with the local llama3.2 model, without tools.\n",
    "\n",
    "    `history` is the Gradio messages-format history (a list of role/content dicts).\n",
    "    It is named differently from the tool-enabled `chat` defined further down so the\n",
    "    two handlers do not silently shadow each other.\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
    "    response = ollama.chat(model=\"llama3.2\", messages=messages)\n",
    "    return response['message']['content']\n",
    "\n",
    "gr.ChatInterface(fn=chat_without_tools, type=\"messages\").launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac22d421-a241-4c1f-bac4-db2150099ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "ticket_prices = {\"london\": \"$799\", \"paris\": \"$899\", \"tokyo\": \"$1400\", \"berlin\": \"$499\"}\n",
    "\n",
    "def get_ticket_price(destination_city):\n",
    "    \"\"\"Look up the return-ticket price for a destination city (backend of the price tool).\n",
    "\n",
    "    The city name comes from model-generated tool arguments, so matching is\n",
    "    case-insensitive, ignores surrounding whitespace, and tolerates a missing value.\n",
    "    Returns the price string from `ticket_prices`, or 'Unknown' if the city is not listed.\n",
    "    \"\"\"\n",
    "    print(f\"Tool get_ticket_price called for {destination_city}\")\n",
    "    city = str(destination_city or \"\").strip().lower()\n",
    "    return ticket_prices.get(city, \"Unknown\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0381b1-375c-44ac-8757-2fdde2c76541",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function-calling schema describing get_ticket_price to the model.\n",
    "price_function = {\n",
"    \"name\": \"get_ticket_price\",\n",
    "    \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"destination_city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city that the customer wants to travel to\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"destination_city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce5a7fd0-1ce1-4b53-873e-f55d1e39d847",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the schema from the previous cell in the {\"type\": \"function\", ...} envelope used\n",
    "# by the `tools` argument of ollama.chat, rather than duplicating the schema inline.\n",
    "tools = [{\"type\": \"function\", \"function\": price_function}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06eab709-3f05-4697-a6a8-5f5bc1f442a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_call(message):\n",
    "    \"\"\"Execute the tool call in an assistant `message` and build the tool reply.\n",
    "\n",
    "    Only `message.tool_calls[0]` is executed; any further tool calls in the same\n",
    "    message are ignored.\n",
    "\n",
    "    Returns `(response, city)`: `response` is a {\"role\": \"tool\", ...} message whose\n",
    "    content is a JSON object holding the city and its price, and `city` is the\n",
    "    destination argument the model supplied (None if it omitted it).\n",
    "    \"\"\"\n",
    "    tool_call = message.tool_calls[0]\n",
    "    arguments = tool_call.function.arguments or {}\n",
    "    # ollama returns the arguments as a mapping; OpenAI-style clients return a JSON string.\n",
    "    if isinstance(arguments, str):\n",
    "        arguments = json.loads(arguments)\n",
    "    city = arguments.get('destination_city')\n",
    "    # A call without a city gets 'Unknown' instead of crashing inside the price lookup.\n",
    "    price = get_ticket_price(city) if city else \"Unknown\"\n",
    "    response = {\n",
    "        \"role\": \"tool\",\n",
    "        \"content\": json.dumps({\"destination_city\": city, \"price\": price}),\n",
    "    }\n",
    "    return response, city"
   ]
  },
  {
"cell_type": "code",
   "execution_count": null,
   "id": "d7f9af23-0683-40c3-a70b-0a385754688c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history):\n",
    "    \"\"\"Tool-enabled chat handler for the Gradio ChatInterface.\n",
    "\n",
    "    Sends the conversation to llama3.2 with `tools` available. If the model requests\n",
    "    a tool, the request is handled by `handle_tool_call`, the assistant turn and the\n",
    "    tool result are appended to the transcript, and the model is queried once more\n",
    "    (without tools) for the final text answer, which is returned.\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
    "    response = ollama.chat(model=\"llama3.2\", messages=messages, tools=tools)\n",
    "    assistant_message = response['message']\n",
    "    if assistant_message.get('tool_calls'):\n",
    "        # Keep the assistant's tool request in the transcript, followed by the tool's answer.\n",
    "        tool_response, _city = handle_tool_call(assistant_message)\n",
    "        messages.append(assistant_message)\n",
    "        messages.append(tool_response)\n",
    "        response = ollama.chat(model=\"llama3.2\", messages=messages)\n",
    "    return response['message']['content']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcfa39e2-92ce-48df-b735-f9bbfe638c81",
   "metadata": {},
   "outputs": [],
   "source": [
    "gr.ChatInterface(fn=chat, type=\"messages\").launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}