{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Week 2 Day 4 Exercise - Enhanced Airline AI Assistant\n", "\n", "\n", "This notebook extends the basic airline assistant with a tool to set ticket prices.\n", "\n", "### Key Features:\n", "- **Get Ticket Price**: Query current ticket prices for destinations\n", "- **Set Ticket Price**: Update ticket prices for destinations \n", "- **Database Integration**: Uses SQLite for persistent storage\n", "- **Multiple Tool Support**: Handles both get and set operations\n", "- **Gradio Interface**: User-friendly chat interface\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import necessary libraries\n", "import os\n", "import json\n", "import sqlite3\n", "from dotenv import load_dotenv\n", "from openai import OpenAI\n", "import gradio as gr\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OpenAI API Key exists and begins sk-proj-\n" ] } ], "source": [ "# Initialize OpenAI client\n", "load_dotenv(override=True)\n", "\n", "openai_api_key = os.getenv('OPENAI_API_KEY')\n", "if openai_api_key:\n", " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", "else:\n", " print(\"OpenAI API Key not set\")\n", " \n", "MODEL = \"gpt-4o-mini\"\n", "openai = OpenAI()\n", "\n", "# System message for the assistant\n", "system_message = \"\"\"\n", "You are a helpful assistant for an Airline called FlightAI.\n", "Give short, courteous answers, no more than 1 sentence.\n", "Always be accurate. 
If you don't know the answer, say so.\n", "You can get ticket prices and set ticket prices for different cities.\n", "\"\"\"\n" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Database setup complete!\n" ] } ], "source": [
 "# Database setup\n",
 "from contextlib import closing\n",
 "\n",
 "DB = \"prices.db\"\n",
 "\n",
 "def setup_database():\n",
 "    \"\"\"Create the prices table (city TEXT PRIMARY KEY, price REAL) if it does not exist yet.\"\"\"\n",
 "    # sqlite3's connection context manager only commits/rolls back; closing() also releases the connection.\n",
 "    with closing(sqlite3.connect(DB)) as conn:\n",
 "        with conn:\n",
 "            conn.execute('CREATE TABLE IF NOT EXISTS prices (city TEXT PRIMARY KEY, price REAL)')\n",
 "    print(\"✅ Database setup complete!\")\n",
 "\n",
 "# Setup the database\n",
 "setup_database()\n"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🧪 Testing tool functions:\n", "DATABASE TOOL CALLED: Getting price for London\n", "No price data available for this city\n" ] } ], "source": [
 "# Tool functions\n",
 "def _normalize_city(city):\n",
 "    \"\"\"Return the database key for a city name: surrounding whitespace removed, lower-cased.\"\"\"\n",
 "    return city.strip().lower()\n",
 "\n",
 "def get_ticket_price(city):\n",
 "    \"\"\"Look up the stored ticket price for a destination city (case-insensitive).\n",
 "\n",
 "    Returns a sentence with the price, or a 'no data' message if the city is not in the table.\n",
 "    \"\"\"\n",
 "    print(f\"DATABASE TOOL CALLED: Getting price for {city}\", flush=True)\n",
 "    with closing(sqlite3.connect(DB)) as conn:\n",
 "        row = conn.execute('SELECT price FROM prices WHERE city = ?', (_normalize_city(city),)).fetchone()\n",
 "    return f\"Ticket price to {city} is ${row[0]}\" if row else \"No price data available for this city\"\n",
 "\n",
 "def set_ticket_price(city, price):\n",
 "    \"\"\"Insert or update the ticket price for a destination city.\n",
 "\n",
 "    The price usually comes straight from the LLM's tool arguments, so anything that is\n",
 "    not a non-negative number is rejected with an error message instead of being stored.\n",
 "    \"\"\"\n",
 "    print(f\"DATABASE TOOL CALLED: Setting price for {city} to ${price}\", flush=True)\n",
 "    if isinstance(price, bool) or not isinstance(price, (int, float)) or price < 0:\n",
 "        return f\"Cannot set ticket price to {city}: price must be a non-negative number, got {price!r}\"\n",
 "    with closing(sqlite3.connect(DB)) as conn:\n",
 "        with conn:  # commits on success, rolls back on error\n",
 "            conn.execute(\n",
 "                'INSERT INTO prices (city, price) VALUES (?, ?) '\n",
 "                'ON CONFLICT(city) DO UPDATE SET price = excluded.price',\n",
 "                (_normalize_city(city), price),\n",
 "            )\n",
 "    return f\"Successfully set ticket price to {city} to ${price}\"\n",
 "\n",
 "# Test the functions\n",
 "print(\"🧪 Testing tool functions:\")\n",
 "print(get_ticket_price(\"London\"))  # No data on a fresh database; shows the stored price on re-runs\n"
 ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🔧 Tools configured:\n", " - get_ticket_price: Get the price of a return ticket to the destination city.\n", " - set_ticket_price: Set the price of a return ticket to a destination city.\n" ] } ], "source": [
 "# Tool definitions for OpenAI\n",
 "get_price_function = {\n",
 "    \"name\": \"get_ticket_price\",\n",
 "    \"description\": \"Get the price of a return ticket to the destination city.\",\n",
 "    \"parameters\": {\n",
 "        \"type\": \"object\",\n",
 "        \"properties\": {\n",
 "            \"destination_city\": {\n",
 "                \"type\": \"string\",\n",
 "                \"description\": \"The city that the customer wants to travel to\",\n",
 "            },\n",
 "        },\n",
 "        \"required\": [\"destination_city\"],\n",
 "        \"additionalProperties\": False\n",
 "    }\n",
 "}\n",
 "\n",
 "set_price_function = {\n",
 "    \"name\": \"set_ticket_price\",\n",
 "    \"description\": \"Set the price of a return ticket to a destination city.\",\n",
 "    \"parameters\": {\n",
 "        \"type\": \"object\",\n",
 "        \"properties\": {\n",
 "            \"destination_city\": {\n",
 "                \"type\": \"string\",\n",
 "                \"description\": \"The city to set the price for\",\n",
 "            },\n",
 "            \"price\": {\n",
 "                \"type\": \"number\",\n",
 "                \"description\": \"The new price for the ticket\",\n",
 "            },\n",
 "        },\n",
 "        \"required\": [\"destination_city\", \"price\"],\n",
 "        \"additionalProperties\": False\n",
 "    }\n",
 "}\n",
 "\n",
 "# List of available tools\n",
 "tools = [\n",
 "    {\"type\": \"function\", \"function\": get_price_function},\n",
 "    {\"type\": \"function\", \"function\": set_price_function}\n",
 "]\n",
 "\n",
 "print(\"🔧 Tools configured:\")\n",
 "print(f\" - {get_price_function['name']}: {get_price_function['description']}\")\n",
 "print(f\" - {set_price_function['name']}: {set_price_function['description']}\")\n"
 ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Tool call handler configured!\n" ] } ], "source": [
 "# Tool call handler\n",
 "def _run_tool(name, arguments):\n",
 "    \"\"\"Run one known tool with already-parsed arguments; always returns a string for the LLM.\"\"\"\n",
 "    if name not in (\"get_ticket_price\", \"set_ticket_price\"):\n",
 "        return f\"Unknown tool: {name}\"\n",
 "    city = arguments.get('destination_city')\n",
 "    if not isinstance(city, str) or not city.strip():\n",
 "        return f\"{name} requires a non-empty destination_city\"\n",
 "    if name == \"get_ticket_price\":\n",
 "        return get_ticket_price(city)\n",
 "    price = arguments.get('price')\n",
 "    if price is None:\n",
 "        return \"set_ticket_price requires a price\"\n",
 "    return set_ticket_price(city, price)\n",
 "\n",
 "def handle_tool_calls(message):\n",
 "    \"\"\"Run every tool call in an assistant message and return the matching tool-role replies.\n",
 "\n",
 "    Exactly one {\"role\": \"tool\", ...} reply is produced per entry in message.tool_calls, in order,\n",
 "    because the follow-up chat completion request expects a reply for each tool_call_id.\n",
 "    Malformed arguments and unknown tool names therefore yield an error string instead of\n",
 "    being skipped or raising.\n",
 "    \"\"\"\n",
 "    responses = []\n",
 "    for tool_call in message.tool_calls:\n",
 "        try:\n",
 "            arguments = json.loads(tool_call.function.arguments or \"{}\")\n",
 "        except json.JSONDecodeError:\n",
 "            arguments = None\n",
 "        if isinstance(arguments, dict):\n",
 "            content = _run_tool(tool_call.function.name, arguments)\n",
 "        else:\n",
 "            content = f\"Invalid arguments for {tool_call.function.name}: expected a JSON object\"\n",
 "        responses.append({\n",
 "            \"role\": \"tool\",\n",
 "            \"content\": content,\n",
 "            \"tool_call_id\": tool_call.id\n",
 "        })\n",
 "    return responses\n",
 "\n",
 "print(\"✅ Tool call handler configured!\")\n"
 ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Chat function configured!\n" ] } ], "source": [
 "# Main chat function\n",
 "def chat(message, history):\n",
 "    \"\"\"Main chat function that handles tool calls\"\"\"\n",
 "    history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n",
 "    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
 "    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
 "\n",
 "    # Handle tool calls in a loop to support 
multiple consecutive tool calls\n",
 "    # Capped at max_tool_rounds so a model that keeps requesting tools cannot loop forever.\n",
 "    max_tool_rounds = 5\n",
 "    rounds = 0\n",
 "    while response.choices[0].finish_reason == \"tool_calls\" and rounds < max_tool_rounds:\n",
 "        rounds += 1\n",
 "        message = response.choices[0].message\n",
 "        responses = handle_tool_calls(message)\n",
 "        messages.append(message)\n",
 "        messages.extend(responses)\n",
 "        response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
 "\n",
 "    # content may be None if the cap was reached while the model was still requesting tools.\n",
 "    return response.choices[0].message.content or \"Sorry, I couldn't complete that request.\"\n",
 "\n",
 "print(\"✅ Chat function configured!\")\n"
 ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "DATABASE TOOL CALLED: Setting price for london to $799\n",
 "DATABASE TOOL CALLED: Setting price for paris to $899\n",
 "DATABASE TOOL CALLED: Setting price for tokyo to $1420\n",
 "DATABASE TOOL CALLED: Setting price for sydney to $2999\n",
 "DATABASE TOOL CALLED: Setting price for new york to $1099\n",
 "DATABASE TOOL CALLED: Setting price for los angeles to $1299\n",
 "DATABASE TOOL CALLED: Setting price for san francisco to $1199\n",
 "DATABASE TOOL CALLED: Setting price for chicago to $999\n",
 "DATABASE TOOL CALLED: Setting price for houston to $1399\n",
 "DATABASE TOOL CALLED: Setting price for miami to $1499\n",
 "DATABASE TOOL CALLED: Setting price for washington to $1199\n",
 "DATABASE TOOL CALLED: Setting price for boston to $1299\n",
 "DATABASE TOOL CALLED: Setting price for philadelphia to $1099\n",
 "DATABASE TOOL CALLED: Setting price for seattle to $1399\n",
 "DATABASE TOOL CALLED: Setting price for san diego to $1299\n",
 "DATABASE TOOL CALLED: Setting price for san jose to $1199\n",
 "DATABASE TOOL CALLED: Setting price for austin to $1099\n",
 "DATABASE TOOL CALLED: Setting price for san antonio to $1399\n",
 "DATABASE TOOL CALLED: Setting price for nairobi to $1099\n",
 "DATABASE TOOL CALLED: Setting price for cape town to $1299\n",
 "DATABASE TOOL CALLED: Setting price for durban to $1199\n",
 "DATABASE TOOL CALLED: Setting price for johannesburg to $1399\n",
 "DATABASE TOOL CALLED: 
Setting price for pretoria to $1099\n",
 "DATABASE TOOL CALLED: Setting price for bloemfontein to $1299\n",
 "DATABASE TOOL CALLED: Setting price for polokwane to $1199\n",
 "DATABASE TOOL CALLED: Setting price for port elizabeth to $1199\n",
 "DATABASE TOOL CALLED: Setting price for port shepstone to $1399\n",
 "DATABASE TOOL CALLED: Setting price for port saint john to $1099\n",
 "✅ Sample data initialized!\n",
 "\n",
 "🧪 Testing the setup:\n",
 "DATABASE TOOL CALLED: Getting price for London\n",
 "Ticket price to London is $799.0\n",
 "DATABASE TOOL CALLED: Getting price for Tokyo\n",
 "Ticket price to Tokyo is $1420.0\n"
 ] } ], "source": [
 "# Initialize sample data\n",
 "def initialize_sample_data():\n",
 "    \"\"\"Upsert a fixed set of sample ticket prices (city -> price) via set_ticket_price.\n",
 "\n",
 "    Re-running this cell resets these cities to the sample prices.\n",
 "    \"\"\"\n",
 "    # One entry per city: a repeated key in a dict literal silently overrides the earlier value.\n",
 "    ticket_prices = {\n",
 "        \"london\": 799, \"paris\": 899, \"tokyo\": 1420, \"sydney\": 2999,\n",
 "        \"new york\": 1099, \"los angeles\": 1299, \"san francisco\": 1199, \"chicago\": 999,\n",
 "        \"houston\": 1399, \"miami\": 1499, \"washington\": 1199, \"boston\": 1299,\n",
 "        \"philadelphia\": 1099, \"seattle\": 1399, \"san diego\": 1299, \"san jose\": 1199,\n",
 "        \"austin\": 1099, \"san antonio\": 1399,\n",
 "        \"nairobi\": 1099, \"cape town\": 1299, \"durban\": 1199, \"johannesburg\": 1399,\n",
 "        \"pretoria\": 1099, \"bloemfontein\": 1299, \"polokwane\": 1199,\n",
 "        \"port elizabeth\": 1199, \"port shepstone\": 1399, \"port saint john\": 1099,\n",
 "    }\n",
 "    for city, price in ticket_prices.items():\n",
 "        set_ticket_price(city, price)\n",
 "    print(\"✅ Sample data initialized!\")\n",
 "\n",
 "# Initialize sample data\n",
 "initialize_sample_data()\n",
 "\n",
 "# Test the setup\n",
 "print(\"\\n🧪 Testing the setup:\")\n",
 "print(get_ticket_price(\"London\"))\n",
 "print(get_ticket_price(\"Tokyo\"))\n"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Launch the Enhanced 
Airline Assistant\n", "\n", "The assistant now supports both getting and setting ticket prices!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "๐Ÿš€ Launching FlightAI Assistant with enhanced capabilities...\n", "๐Ÿ“‹ Available commands:\n", " - 'What's the price to London?' (get price)\n", " - 'Set the price to New York to $1200' (set price)\n", " - 'Update Tokyo price to $1500' (set price)\n", " - 'How much does it cost to go to Paris?' (get price)\n", "* Running on local URL: http://127.0.0.1:7882\n", "* To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "name": "stdout", "output_type": "stream", "text": [ "DATABASE TOOL CALLED: Getting price for Paris\n", "DATABASE TOOL CALLED: Setting price for Berlin to $9023\n" ] } ], "source": [ "# Launch the Gradio interface\n", "print(\"๐Ÿš€ Launching FlightAI Assistant with enhanced capabilities...\")\n", "print(\"๐Ÿ“‹ Available commands:\")\n", "print(\" - 'What's the price to London?' (get price)\")\n", "print(\" - 'Set the price to New York to $1200' (set price)\")\n", "print(\" - 'Update Tokyo price to $1500' (set price)\")\n", "print(\" - 'How much does it cost to go to Paris?' (get price)\")\n", "\n", "interface = gr.ChatInterface(\n", " fn=chat, \n", " type=\"messages\",\n", " title=\"FlightAI Assistant - Enhanced\",\n", " description=\"Ask me about ticket prices or set new prices for destinations!\",\n", " examples=[\n", " \"What's the price to London?\",\n", " \"Set the price to New York to $1200\",\n", " \"How much does it cost to go to Paris?\",\n", " \"Update Tokyo price to $1500\"\n", " ]\n", ")\n", "\n", "interface.launch()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Key Implementation Features\n", "\n", "### ๐Ÿ”ง **Enhanced Tool Support**\n", "- **Get Ticket Price**: Query current prices from database\n", "- **Set Ticket Price**: Update prices in database\n", "- **Multiple Tool Calls**: Handles both operations in sequence\n", "- **Database Integration**: Persistent SQLite storage\n", "\n", "### ๐ŸŽฏ **Tool Function Definitions**\n", "```python\n", "# Get Price Tool\n", "get_price_function = {\n", " \"name\": \"get_ticket_price\",\n", " \"description\": \"Get the price of a return ticket to the destination city.\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"destination_city\": {\n", " \"type\": \"string\",\n", " 
\"description\": \"The city that the customer wants to travel to\",\n",
 "            },\n",
 "        },\n",
 "        \"required\": [\"destination_city\"],\n",
 "        \"additionalProperties\": False\n",
 "    }\n",
 "}\n",
 "\n",
 "# Set Price Tool \n",
 "set_price_function = {\n",
 "    \"name\": \"set_ticket_price\", \n",
 "    \"description\": \"Set the price of a return ticket to a destination city.\",\n",
 "    \"parameters\": {\n",
 "        \"type\": \"object\",\n",
 "        \"properties\": {\n",
 "            \"destination_city\": {\n",
 "                \"type\": \"string\",\n",
 "                \"description\": \"The city to set the price for\",\n",
 "            },\n",
 "            \"price\": {\n",
 "                \"type\": \"number\", \n",
 "                \"description\": \"The new price for the ticket\",\n",
 "            },\n",
 "        },\n",
 "        \"required\": [\"destination_city\", \"price\"],\n",
 "        \"additionalProperties\": False\n",
 "    }\n",
 "}\n",
 "```\n",
 "\n",
 "### 🚀 **Usage Examples**\n",
 "- **Get Price**: \"What's the price to London?\"\n",
 "- **Set Price**: \"Set the price to New York to $1200\"\n",
 "- **Update Price**: \"Update Tokyo price to $1500\"\n",
 "- **Query Multiple**: \"What are the prices to London and Paris?\"\n",
 "\n",
 "### 💾 **Database Schema**\n",
 "```sql\n",
 "CREATE TABLE prices (\n",
 "    city TEXT PRIMARY KEY,\n",
 "    price REAL\n",
 ")\n",
 "```\n",
 "\n",
 "This implementation demonstrates advanced tool integration with OpenAI's function calling capabilities!\n" ] } ],
 "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 2 }