520 lines
21 KiB
Plaintext
520 lines
21 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Week 2 Day 4 Exercise - Enhanced Airline AI Assistant\n",
|
|
"\n",
|
|
"\n",
|
|
"This notebook extends the basic airline assistant with a tool to set ticket prices.\n",
|
|
"\n",
|
|
"### Key Features:\n",
|
|
"- **Get Ticket Price**: Query current ticket prices for destinations\n",
|
|
"- **Set Ticket Price**: Update ticket prices for destinations\n",
|
|
"- **Database Integration**: Uses SQLite for persistent storage\n",
|
|
"- **Multiple Tool Support**: Handles both get and set operations\n",
|
|
"- **Gradio Interface**: User-friendly chat interface\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Import necessary libraries\n",
"import json\n",
"import os\n",
"import sqlite3\n",
"from contextlib import closing\n",
"\n",
"import gradio as gr\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"OpenAI API Key exists and begins sk-proj-\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Initialize OpenAI client\n",
"load_dotenv(override=True)\n",
"\n",
"# Report whether a key was found, showing only its first 8 characters\n",
"openai_api_key = os.environ.get('OPENAI_API_KEY')\n",
"key_status = (\n",
"    f\"OpenAI API Key exists and begins {openai_api_key[:8]}\"\n",
"    if openai_api_key\n",
"    else \"OpenAI API Key not set\"\n",
")\n",
"print(key_status)\n",
"\n",
"# Model used for every chat completion request in this notebook\n",
"MODEL = \"gpt-4o-mini\"\n",
"# Client instance; the name `openai` refers to this object in later cells\n",
"openai = OpenAI()\n",
"\n",
"# System message for the assistant: persona, brevity rule and tool capabilities\n",
"system_message = \"\"\"\n",
"You are a helpful assistant for an Airline called FlightAI.\n",
"Give short, courteous answers, no more than 1 sentence.\n",
"Always be accurate. If you don't know the answer, say so.\n",
"You can get ticket prices and set ticket prices for different cities.\n",
"\"\"\"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Database setup complete!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Database setup\n",
"DB = \"prices.db\"\n",
"\n",
"def setup_database():\n",
"    \"\"\"Create the `prices` table in the SQLite file at `DB` if it does not already exist.\n",
"\n",
"    A `sqlite3.Connection` used as a context manager only commits or rolls back the\n",
"    transaction; it does not close the connection. `closing` releases the file handle.\n",
"    \"\"\"\n",
"    with closing(sqlite3.connect(DB)) as conn:\n",
"        with conn:  # commit on success, roll back on error\n",
"            conn.execute('CREATE TABLE IF NOT EXISTS prices (city TEXT PRIMARY KEY, price REAL)')\n",
"    print(\"✅ Database setup complete!\")\n",
"\n",
"# Setup the database\n",
"setup_database()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🧪 Testing tool functions:\n",
|
|
"DATABASE TOOL CALLED: Getting price for London\n",
|
|
"No price data available for this city\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Tool functions\n",
|
|
"def get_ticket_price(city):\n",
"    \"\"\"Return a message with the stored return-ticket price for `city`.\n",
"\n",
"    The lookup key is the city name stripped and lower-cased, so \"London\",\n",
"    \" london \" and \"LONDON\" all match (`set_ticket_price` stores names lower-cased).\n",
"    A missing/empty city or an unknown city yields a \"no data\" message rather\n",
"    than an exception, because the result is passed straight back to the LLM.\n",
"    \"\"\"\n",
"    print(f\"DATABASE TOOL CALLED: Getting price for {city}\", flush=True)\n",
"    if not city or not str(city).strip():\n",
"        return \"No price data available for this city\"\n",
"    key = str(city).strip().lower()\n",
"    # closing() is required: a Connection's own context manager does not close it\n",
"    with closing(sqlite3.connect(DB)) as conn:\n",
"        row = conn.execute('SELECT price FROM prices WHERE city = ?', (key,)).fetchone()\n",
"    return f\"Ticket price to {city} is ${row[0]}\" if row else \"No price data available for this city\"\n",
|
|
"\n",
|
|
"def set_ticket_price(city, price):\n",
"    \"\"\"Insert or update the return-ticket price for `city` and return a status message.\n",
"\n",
"    The city is stored stripped and lower-cased so lookups are case-insensitive.\n",
"    Invalid input (empty city; price that is not a finite, non-negative number)\n",
"    is not written; an explanatory message is returned instead so the LLM can\n",
"    relay the problem to the user.\n",
"    \"\"\"\n",
"    print(f\"DATABASE TOOL CALLED: Setting price for {city} to ${price}\", flush=True)\n",
"    if not city or not str(city).strip():\n",
"        return \"Cannot set price: no destination city was given\"\n",
"    try:\n",
"        value = float(price)\n",
"    except (TypeError, ValueError):\n",
"        value = float(\"nan\")\n",
"    # bool is a subclass of int, so reject it explicitly; the range check also\n",
"    # rejects NaN (every comparison with NaN is False) and infinity.\n",
"    if isinstance(price, bool) or not (0 <= value < float(\"inf\")):\n",
"        return f\"Cannot set price for {city}: price must be a non-negative number, got {price!r}\"\n",
"    key = str(city).strip().lower()\n",
"    # closing() is required: a Connection's own context manager does not close it\n",
"    with closing(sqlite3.connect(DB)) as conn:\n",
"        with conn:  # commit on success, roll back on error\n",
"            conn.execute(\n",
"                'INSERT INTO prices (city, price) VALUES (?, ?) '\n",
"                'ON CONFLICT(city) DO UPDATE SET price = excluded.price',\n",
"                (key, value),\n",
"            )\n",
"    return f\"Successfully set ticket price to {city} to ${price}\"\n",
|
|
"\n",
|
|
"# Test the functions\n",
"print(\"🧪 Testing tool functions:\")\n",
"# Prints \"No price data...\" only against a fresh prices.db; the file persists between\n",
"# runs, so once the sample data has been loaded this shows the stored London price.\n",
"print(get_ticket_price(\"London\"))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🔧 Tools configured:\n",
|
|
" - get_ticket_price: Get the price of a return ticket to the destination city.\n",
|
|
" - set_ticket_price: Set the price of a return ticket to a destination city.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Tool definitions for OpenAI (JSON-schema descriptions of the two Python tools)\n",
"def _object_schema(**properties):\n",
"    \"\"\"Return a strict JSON-schema object whose given properties are all required.\n",
"\n",
"    Keyword order is preserved, so `required` lists properties in the order passed.\n",
"    \"\"\"\n",
"    return {\n",
"        \"type\": \"object\",\n",
"        \"properties\": properties,\n",
"        \"required\": list(properties),\n",
"        \"additionalProperties\": False,\n",
"    }\n",
"\n",
"get_price_function = {\n",
"    \"name\": \"get_ticket_price\",\n",
"    \"description\": \"Get the price of a return ticket to the destination city.\",\n",
"    \"parameters\": _object_schema(\n",
"        destination_city={\"type\": \"string\", \"description\": \"The city that the customer wants to travel to\"},\n",
"    ),\n",
"}\n",
"\n",
"set_price_function = {\n",
"    \"name\": \"set_ticket_price\",\n",
"    \"description\": \"Set the price of a return ticket to a destination city.\",\n",
"    \"parameters\": _object_schema(\n",
"        destination_city={\"type\": \"string\", \"description\": \"The city to set the price for\"},\n",
"        price={\"type\": \"number\", \"description\": \"The new price for the ticket\"},\n",
"    ),\n",
"}\n",
"\n",
"# List of available tools, each wrapped in the {\"type\": \"function\"} envelope\n",
"tools = [{\"type\": \"function\", \"function\": spec} for spec in (get_price_function, set_price_function)]\n",
"\n",
"print(\"🔧 Tools configured:\")\n",
"for spec in (get_price_function, set_price_function):\n",
"    print(f\" - {spec['name']}: {spec['description']}\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Tool call handler configured!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Tool call handler\n",
"def handle_tool_calls(message):\n",
"    \"\"\"Execute every tool call in an assistant `message` and return the tool replies.\n",
"\n",
"    Returns a list with exactly one {\"role\": \"tool\", ...} dict per entry in\n",
"    `message.tool_calls`, in the same order. The Chat Completions API expects each\n",
"    `tool_call_id` to be answered, so unknown tool names and unparseable arguments\n",
"    produce an error reply instead of being silently skipped or raising.\n",
"    \"\"\"\n",
"    responses = []\n",
"    for tool_call in message.tool_calls:\n",
"        name = tool_call.function.name\n",
"        try:\n",
"            arguments = json.loads(tool_call.function.arguments or \"{}\")\n",
"        except json.JSONDecodeError:\n",
"            arguments = None\n",
"        if not isinstance(arguments, dict):\n",
"            content = f\"Error: invalid arguments for tool {name}\"\n",
"        elif name == \"get_ticket_price\":\n",
"            content = get_ticket_price(arguments.get('destination_city'))\n",
"        elif name == \"set_ticket_price\":\n",
"            content = set_ticket_price(arguments.get('destination_city'), arguments.get('price'))\n",
"        else:\n",
"            content = f\"Error: unknown tool {name}\"\n",
"        responses.append({\n",
"            \"role\": \"tool\",\n",
"            \"content\": content,\n",
"            \"tool_call_id\": tool_call.id\n",
"        })\n",
"    return responses\n",
"\n",
"print(\"✅ Tool call handler configured!\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Chat function configured!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Main chat function\n",
"def chat(message, history, max_tool_rounds=5):\n",
"    \"\"\"Answer one user turn, letting the model call the price tools as needed.\n",
"\n",
"    message: the new user message.\n",
"    history: prior turns as dicts with \"role\" and \"content\" keys (Gradio\n",
"        messages format); only those two keys are forwarded to OpenAI.\n",
"    max_tool_rounds: upper bound on tool-call round trips, so a model that keeps\n",
"        requesting tools cannot loop forever.\n",
"\n",
"    Returns the assistant's text reply, or a fallback apology when the model\n",
"    produced no text (e.g. it was still requesting tools when the cap was hit).\n",
"    \"\"\"\n",
"    history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n",
"    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
"    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
"\n",
"    # Answer tool calls until the model returns a normal reply or the cap is reached\n",
"    for _ in range(max_tool_rounds):\n",
"        if response.choices[0].finish_reason != \"tool_calls\":\n",
"            break\n",
"        assistant_message = response.choices[0].message\n",
"        messages.append(assistant_message)\n",
"        messages.extend(handle_tool_calls(assistant_message))\n",
"        response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
"\n",
"    reply = response.choices[0].message.content\n",
"    return reply if reply else \"Sorry, I wasn't able to complete that request.\"\n",
"\n",
"print(\"✅ Chat function configured!\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"DATABASE TOOL CALLED: Setting price for london to $799\n",
|
|
"DATABASE TOOL CALLED: Setting price for paris to $899\n",
|
|
"DATABASE TOOL CALLED: Setting price for tokyo to $1420\n",
|
|
"DATABASE TOOL CALLED: Setting price for sydney to $2999\n",
|
|
"DATABASE TOOL CALLED: Setting price for new york to $1099\n",
|
|
"DATABASE TOOL CALLED: Setting price for los angeles to $1299\n",
|
|
"DATABASE TOOL CALLED: Setting price for san francisco to $1199\n",
|
|
"DATABASE TOOL CALLED: Setting price for chicago to $999\n",
|
|
"DATABASE TOOL CALLED: Setting price for houston to $1399\n",
|
|
"DATABASE TOOL CALLED: Setting price for miami to $1499\n",
|
|
"DATABASE TOOL CALLED: Setting price for washington to $1199\n",
|
|
"DATABASE TOOL CALLED: Setting price for boston to $1299\n",
|
|
"DATABASE TOOL CALLED: Setting price for philadelphia to $1099\n",
|
|
"DATABASE TOOL CALLED: Setting price for seattle to $1399\n",
|
|
"DATABASE TOOL CALLED: Setting price for san diego to $1299\n",
|
|
"DATABASE TOOL CALLED: Setting price for san jose to $1199\n",
|
|
"DATABASE TOOL CALLED: Setting price for austin to $1099\n",
|
|
"DATABASE TOOL CALLED: Setting price for san antonio to $1399\n",
|
|
"DATABASE TOOL CALLED: Setting price for nairobi to $1099\n",
|
|
"DATABASE TOOL CALLED: Setting price for cape town to $1299\n",
|
|
"DATABASE TOOL CALLED: Setting price for durban to $1199\n",
|
|
"DATABASE TOOL CALLED: Setting price for johannesburg to $1399\n",
|
|
"DATABASE TOOL CALLED: Setting price for pretoria to $1099\n",
|
|
"DATABASE TOOL CALLED: Setting price for bloemfontein to $1299\n",
|
|
"DATABASE TOOL CALLED: Setting price for polokwane to $1199\n",
|
|
"DATABASE TOOL CALLED: Setting price for port elizabeth to $1199\n",
|
|
"DATABASE TOOL CALLED: Setting price for port shepstone to $1399\n",
|
|
"DATABASE TOOL CALLED: Setting price for port saint john to $1099\n",
|
|
"✅ Sample data initialized!\n",
|
|
"\n",
|
|
"🧪 Testing the setup:\n",
|
|
"DATABASE TOOL CALLED: Getting price for London\n",
|
|
"Ticket price to London is $799.0\n",
|
|
"DATABASE TOOL CALLED: Getting price for Tokyo\n",
|
|
"Ticket price to Tokyo is $1420.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Initialize sample data\n",
"# Each city is listed once: repeated keys in a dict literal are legal but silently\n",
"# keep only the last value, which hides which price was intended.\n",
"SAMPLE_TICKET_PRICES = {\n",
"    \"london\": 799, \"paris\": 899, \"tokyo\": 1420, \"sydney\": 2999,\n",
"    \"new york\": 1099, \"los angeles\": 1299, \"san francisco\": 1199, \"chicago\": 999,\n",
"    \"houston\": 1399, \"miami\": 1499, \"washington\": 1199, \"boston\": 1299,\n",
"    \"philadelphia\": 1099, \"seattle\": 1399, \"san diego\": 1299, \"san jose\": 1199,\n",
"    \"austin\": 1099, \"san antonio\": 1399, \"nairobi\": 1099, \"cape town\": 1299,\n",
"    \"durban\": 1199, \"johannesburg\": 1399, \"pretoria\": 1099, \"bloemfontein\": 1299,\n",
"    \"polokwane\": 1199, \"port elizabeth\": 1199, \"port shepstone\": 1399, \"port saint john\": 1099,\n",
"}\n",
"\n",
"def initialize_sample_data(ticket_prices=None):\n",
"    \"\"\"Write sample ticket prices into the database via `set_ticket_price`.\n",
"\n",
"    ticket_prices: optional mapping of city name -> price; defaults to\n",
"        SAMPLE_TICKET_PRICES. Existing rows for the same cities are overwritten,\n",
"        since set_ticket_price performs an insert-or-update.\n",
"    \"\"\"\n",
"    if ticket_prices is None:\n",
"        ticket_prices = SAMPLE_TICKET_PRICES\n",
"    for city, price in ticket_prices.items():\n",
"        set_ticket_price(city, price)\n",
"    print(\"✅ Sample data initialized!\")\n",
"\n",
"# Initialize sample data\n",
"initialize_sample_data()\n",
"\n",
"# Test the setup\n",
"print(\"\\n🧪 Testing the setup:\")\n",
"print(get_ticket_price(\"London\"))\n",
"print(get_ticket_price(\"Tokyo\"))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Launch the Enhanced Airline Assistant\n",
|
|
"\n",
|
|
"The assistant now supports both getting and setting ticket prices!\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🚀 Launching FlightAI Assistant with enhanced capabilities...\n",
|
|
"📋 Available commands:\n",
|
|
" - 'What's the price to London?' (get price)\n",
|
|
" - 'Set the price to New York to $1200' (set price)\n",
|
|
" - 'Update Tokyo price to $1500' (set price)\n",
|
|
" - 'How much does it cost to go to Paris?' (get price)\n",
|
|
"* Running on local URL: http://127.0.0.1:7882\n",
|
|
"* To create a public link, set `share=True` in `launch()`.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div><iframe src=\"http://127.0.0.1:7882/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": []
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"DATABASE TOOL CALLED: Getting price for Paris\n",
|
|
"DATABASE TOOL CALLED: Setting price for Berlin to $9023\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Launch the Gradio interface\n",
"startup_banner = [\n",
"    \"🚀 Launching FlightAI Assistant with enhanced capabilities...\",\n",
"    \"📋 Available commands:\",\n",
"    \" - 'What's the price to London?' (get price)\",\n",
"    \" - 'Set the price to New York to $1200' (set price)\",\n",
"    \" - 'Update Tokyo price to $1500' (set price)\",\n",
"    \" - 'How much does it cost to go to Paris?' (get price)\",\n",
"]\n",
"print(\"\\n\".join(startup_banner))\n",
"\n",
"# Clickable example prompts shown in the chat UI\n",
"example_prompts = [\n",
"    \"What's the price to London?\",\n",
"    \"Set the price to New York to $1200\",\n",
"    \"How much does it cost to go to Paris?\",\n",
"    \"Update Tokyo price to $1500\",\n",
"]\n",
"\n",
"interface = gr.ChatInterface(\n",
"    fn=chat,\n",
"    type=\"messages\",  # history is passed as role/content dicts, which chat() reads\n",
"    title=\"FlightAI Assistant - Enhanced\",\n",
"    description=\"Ask me about ticket prices or set new prices for destinations!\",\n",
"    examples=example_prompts,\n",
")\n",
"\n",
"interface.launch()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Key Implementation Features\n",
|
|
"\n",
|
|
"### 🔧 **Enhanced Tool Support**\n",
|
|
"- **Get Ticket Price**: Query current prices from database\n",
|
|
"- **Set Ticket Price**: Update prices in database\n",
|
|
"- **Multiple Tool Calls**: Keeps answering tool requests until the model replies with text, so several get/set calls can be chained in one turn\n",
|
|
"- **Database Integration**: Persistent SQLite storage\n",
|
|
"\n",
|
|
"### 🎯 **Tool Function Definitions**\n",
|
|
"```python\n",
|
|
"# Get Price Tool\n",
|
|
"get_price_function = {\n",
|
|
" \"name\": \"get_ticket_price\",\n",
|
|
" \"description\": \"Get the price of a return ticket to the destination city.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"destination_city\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The city that the customer wants to travel to\",\n",
|
|
" },\n",
|
|
" },\n",
|
|
" \"required\": [\"destination_city\"],\n",
|
|
" \"additionalProperties\": False\n",
|
|
" }\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Set Price Tool \n",
|
|
"set_price_function = {\n",
|
|
" \"name\": \"set_ticket_price\", \n",
|
|
" \"description\": \"Set the price of a return ticket to a destination city.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"destination_city\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The city to set the price for\",\n",
|
|
" },\n",
|
|
" \"price\": {\n",
|
|
" \"type\": \"number\", \n",
|
|
" \"description\": \"The new price for the ticket\",\n",
|
|
" },\n",
|
|
" },\n",
|
|
" \"required\": [\"destination_city\", \"price\"],\n",
|
|
" \"additionalProperties\": False\n",
|
|
" }\n",
|
|
"}\n",
|
|
"```\n",
|
|
"\n",
|
|
"### 🚀 **Usage Examples**\n",
|
|
"- **Get Price**: \"What's the price to London?\"\n",
|
|
"- **Set Price**: \"Set the price to New York to $1200\"\n",
|
|
"- **Update Price**: \"Update Tokyo price to $1500\"\n",
|
|
"- **Query Multiple**: \"What are the prices to London and Paris?\"\n",
|
|
"\n",
|
|
"### 💾 **Database Schema**\n",
|
|
"```sql\n",
|
|
"CREATE TABLE prices (\n",
|
|
" city TEXT PRIMARY KEY,\n",
|
|
" price REAL\n",
|
|
")\n",
|
|
"```\n",
|
|
"\n",
|
|
"This implementation demonstrates advanced tool integration with OpenAI's function calling capabilities!\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|