Files
LLM_Engineering_OLD/week2/community-contributions/week2-assignment-Joshua/airline_assistant_exercise.ipynb
2025-10-22 06:01:29 +03:00

520 lines
21 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Week 2 Day 4 Exercise - Enhanced Airline AI Assistant\n",
"\n",
"\n",
"This notebook extends the basic airline assistant with a tool to set ticket prices.\n",
"\n",
"### Key Features:\n",
"- **Get Ticket Price**: Query current ticket prices for destinations\n",
"- **Set Ticket Price**: Update ticket prices for destinations \n",
"- **Database Integration**: Uses SQLite for persistent storage\n",
"- **Multiple Tool Support**: Handles both get and set operations\n",
"- **Gradio Interface**: User-friendly chat interface\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries\n",
"import json\n",
"import os\n",
"import sqlite3\n",
"from contextlib import closing\n",
"\n",
"import gradio as gr\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpenAI API Key exists and begins sk-proj-\n"
]
}
],
"source": [
"# Initialize OpenAI client\n",
"# Load settings from .env; override=True lets .env values replace variables\n",
"# that are already set in the process environment.\n",
"load_dotenv(override=True)\n",
"\n",
"# Confirm a key was found while only echoing its first 8 characters.\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"print(\n",
"    f\"OpenAI API Key exists and begins {openai_api_key[:8]}\"\n",
"    if openai_api_key\n",
"    else \"OpenAI API Key not set\"\n",
")\n",
"\n",
"MODEL = \"gpt-4o-mini\"\n",
"# The name `openai` refers to this client object (only `OpenAI` is imported from\n",
"# the package); later cells call `openai.chat.completions.create(...)` on it.\n",
"openai = OpenAI()\n",
"\n",
"# System prompt placed first in every conversation sent to the model.\n",
"system_message = \"\"\"\n",
"You are a helpful assistant for an Airline called FlightAI.\n",
"Give short, courteous answers, no more than 1 sentence.\n",
"Always be accurate. If you don't know the answer, say so.\n",
"You can get ticket prices and set ticket prices for different cities.\n",
"\"\"\"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Database setup complete!\n"
]
}
],
"source": [
"# Database setup\n",
"DB = \"prices.db\"  # relative path, resolved against the kernel's working directory\n",
"\n",
"def setup_database():\n",
"    \"\"\"Create the `prices` table (city TEXT PRIMARY KEY, price REAL) in DB if missing.\n",
"\n",
"    Safe to call repeatedly (CREATE TABLE IF NOT EXISTS); prints a confirmation.\n",
"    \"\"\"\n",
"    # `with sqlite3.connect(...)` only wraps a transaction -- it never closes the\n",
"    # connection -- so closing() is what actually releases the file handle.\n",
"    with closing(sqlite3.connect(DB)) as conn:\n",
"        conn.execute('CREATE TABLE IF NOT EXISTS prices (city TEXT PRIMARY KEY, price REAL)')\n",
"        conn.commit()\n",
"    print(\"✅ Database setup complete!\")\n",
"\n",
"# Setup the database\n",
"setup_database()\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧪 Testing tool functions:\n",
"DATABASE TOOL CALLED: Getting price for London\n",
"No price data available for this city\n"
]
}
],
"source": [
"# Tool functions\n",
"def _normalize_city(city):\n",
"    \"\"\"Return the lookup key for a city name: stripped and lower-cased ('' for None).\"\"\"\n",
"    return (city or \"\").strip().lower()\n",
"\n",
"\n",
"def get_ticket_price(city):\n",
"    \"\"\"Get the price of a ticket to a destination city.\n",
"\n",
"    Args:\n",
"        city: Destination name from the user/LLM; matched case-insensitively and\n",
"            ignoring surrounding whitespace.\n",
"\n",
"    Returns:\n",
"        \"Ticket price to <city> is $<price>\" when the city has a stored price,\n",
"        otherwise \"No price data available for this city\" (also for a blank/None city).\n",
"    \"\"\"\n",
"    print(f\"DATABASE TOOL CALLED: Getting price for {city}\", flush=True)\n",
"    key = _normalize_city(city)\n",
"    if not key:\n",
"        return \"No price data available for this city\"\n",
"    # closing() is required: a sqlite3 connection's own context manager never closes it.\n",
"    with closing(sqlite3.connect(DB)) as conn:\n",
"        row = conn.execute('SELECT price FROM prices WHERE city = ?', (key,)).fetchone()\n",
"    return f\"Ticket price to {city} is ${row[0]}\" if row else \"No price data available for this city\"\n",
"\n",
"\n",
"def set_ticket_price(city, price):\n",
"    \"\"\"Set (insert or overwrite) the price of a ticket to a destination city.\n",
"\n",
"    Args:\n",
"        city: Destination name; stored stripped and lower-cased so that\n",
"            get_ticket_price lookups are case-insensitive.\n",
"        price: New price; must be a finite, non-negative int or float.\n",
"\n",
"    Returns:\n",
"        A confirmation sentence, or an error sentence (and no write) when the city\n",
"        is blank or the price is invalid. Errors are returned rather than raised so\n",
"        the model can relay them to the user.\n",
"    \"\"\"\n",
"    print(f\"DATABASE TOOL CALLED: Setting price for {city} to ${price}\", flush=True)\n",
"    key = _normalize_city(city)\n",
"    if not key:\n",
"        return \"Cannot set ticket price: no destination city was given\"\n",
"    # bool is a subclass of int, so reject it explicitly; the chained comparison\n",
"    # also rejects negative values, NaN and infinity.\n",
"    if isinstance(price, bool) or not isinstance(price, (int, float)) or not 0 <= price < float(\"inf\"):\n",
"        return f\"Cannot set ticket price to {city}: price must be a non-negative number\"\n",
"    with closing(sqlite3.connect(DB)) as conn:\n",
"        # Upsert; `excluded.price` is the value from the row we tried to insert.\n",
"        conn.execute(\n",
"            'INSERT INTO prices (city, price) VALUES (?, ?) '\n",
"            'ON CONFLICT(city) DO UPDATE SET price = excluded.price',\n",
"            (key, price),\n",
"        )\n",
"        conn.commit()\n",
"    return f\"Successfully set ticket price to {city} to ${price}\"\n",
"\n",
"\n",
"# Smoke test: on a fresh database this prints the \"no data\" message; after the\n",
"# sample-data cell has run it prints London's price instead.\n",
"print(\"🧪 Testing tool functions:\")\n",
"print(get_ticket_price(\"London\"))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🔧 Tools configured:\n",
" - get_ticket_price: Get the price of a return ticket to the destination city.\n",
" - set_ticket_price: Set the price of a return ticket to a destination city.\n"
]
}
],
"source": [
"# Tool definitions for OpenAI\n",
"def _function_spec(name, description, properties):\n",
"    \"\"\"Build a function-calling spec whose listed properties are all required.\"\"\"\n",
"    return {\n",
"        \"name\": name,\n",
"        \"description\": description,\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": properties,\n",
"            \"required\": list(properties),\n",
"            \"additionalProperties\": False,\n",
"        },\n",
"    }\n",
"\n",
"\n",
"get_price_function = _function_spec(\n",
"    \"get_ticket_price\",\n",
"    \"Get the price of a return ticket to the destination city.\",\n",
"    {\n",
"        \"destination_city\": {\"type\": \"string\", \"description\": \"The city that the customer wants to travel to\"},\n",
"    },\n",
")\n",
"\n",
"set_price_function = _function_spec(\n",
"    \"set_ticket_price\",\n",
"    \"Set the price of a return ticket to a destination city.\",\n",
"    {\n",
"        \"destination_city\": {\"type\": \"string\", \"description\": \"The city to set the price for\"},\n",
"        \"price\": {\"type\": \"number\", \"description\": \"The new price for the ticket\"},\n",
"    },\n",
")\n",
"\n",
"# Expose both specs to the model in the Chat Completions `tools` format.\n",
"tools = [{\"type\": \"function\", \"function\": spec} for spec in (get_price_function, set_price_function)]\n",
"\n",
"print(\"🔧 Tools configured:\")\n",
"for _spec in tools:\n",
"    print(f\" - {_spec['function']['name']}: {_spec['function']['description']}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Tool call handler configured!\n"
]
}
],
"source": [
"# Tool call handler\n",
"def handle_tool_calls(message):\n",
"    \"\"\"Run every tool call requested in an assistant message.\n",
"\n",
"    The Chat Completions API rejects the follow-up request unless each\n",
"    tool_call_id in the assistant message gets a matching \"tool\" message, so a\n",
"    response is produced for every call -- including unknown tool names,\n",
"    malformed JSON arguments and tools that raise -- instead of skipping it.\n",
"\n",
"    Args:\n",
"        message: Assistant message (``response.choices[0].message``); each entry\n",
"            in ``message.tool_calls`` has ``id``, ``function.name`` and\n",
"            ``function.arguments`` (a JSON-encoded string).\n",
"\n",
"    Returns:\n",
"        List of ``{\"role\": \"tool\", \"content\": str, \"tool_call_id\": str}`` dicts,\n",
"        one per tool call, in request order.\n",
"    \"\"\"\n",
"    # Dispatch table: tool name -> callable taking the decoded arguments dict.\n",
"    dispatch = {\n",
"        \"get_ticket_price\": lambda args: get_ticket_price(args.get('destination_city')),\n",
"        \"set_ticket_price\": lambda args: set_ticket_price(args.get('destination_city'), args.get('price')),\n",
"    }\n",
"    responses = []\n",
"    for tool_call in message.tool_calls or []:\n",
"        name = tool_call.function.name\n",
"        handler = dispatch.get(name)\n",
"        if handler is None:\n",
"            content = f\"Error: unknown tool '{name}'\"\n",
"        else:\n",
"            try:\n",
"                arguments = json.loads(tool_call.function.arguments or \"{}\")\n",
"                content = handler(arguments)\n",
"            except json.JSONDecodeError as exc:\n",
"                content = f\"Error: could not parse arguments for {name}: {exc}\"\n",
"            except Exception as exc:  # report tool failures to the model instead of crashing the chat\n",
"                print(f\"TOOL ERROR in {name}: {exc!r}\", flush=True)\n",
"                content = f\"Error while running {name}: {exc}\"\n",
"        responses.append({\n",
"            \"role\": \"tool\",\n",
"            \"content\": str(content),\n",
"            \"tool_call_id\": tool_call.id,\n",
"        })\n",
"    return responses\n",
"\n",
"print(\"✅ Tool call handler configured!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Chat function configured!\n"
]
}
],
"source": [
"# Main chat function\n",
"MAX_TOOL_ROUNDS = 5  # cap on consecutive tool-call round trips per user message\n",
"\n",
"def chat(message, history):\n",
"    \"\"\"Answer one user message, executing any tool calls the model requests.\n",
"\n",
"    Args:\n",
"        message: The new user message text.\n",
"        history: Prior turns from the `type=\"messages\"` Gradio ChatInterface, each a\n",
"            dict with at least \"role\" and \"content\" keys.\n",
"\n",
"    Returns:\n",
"        The assistant's final text reply (never None).\n",
"    \"\"\"\n",
"    # Keep only role/content; Gradio message dicts may carry extra keys.\n",
"    history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n",
"    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
"    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
"    assistant_message = response.choices[0].message\n",
"\n",
"    # Resolve tool calls until the model replies in plain text, bounded so a model\n",
"    # that keeps requesting tools cannot loop (and bill) forever.\n",
"    for _ in range(MAX_TOOL_ROUNDS):\n",
"        if not assistant_message.tool_calls:\n",
"            break\n",
"        messages.append(assistant_message)\n",
"        messages.extend(handle_tool_calls(assistant_message))\n",
"        response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
"        assistant_message = response.choices[0].message\n",
"\n",
"    # content is None if the model was still asking for tools when the cap was hit.\n",
"    return assistant_message.content or \"Sorry, I couldn't complete that request. Please try again.\"\n",
"\n",
"print(\"✅ Chat function configured!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DATABASE TOOL CALLED: Setting price for london to $799\n",
"DATABASE TOOL CALLED: Setting price for paris to $899\n",
"DATABASE TOOL CALLED: Setting price for tokyo to $1420\n",
"DATABASE TOOL CALLED: Setting price for sydney to $2999\n",
"DATABASE TOOL CALLED: Setting price for new york to $1099\n",
"DATABASE TOOL CALLED: Setting price for los angeles to $1299\n",
"DATABASE TOOL CALLED: Setting price for san francisco to $1199\n",
"DATABASE TOOL CALLED: Setting price for chicago to $999\n",
"DATABASE TOOL CALLED: Setting price for houston to $1399\n",
"DATABASE TOOL CALLED: Setting price for miami to $1499\n",
"DATABASE TOOL CALLED: Setting price for washington to $1199\n",
"DATABASE TOOL CALLED: Setting price for boston to $1299\n",
"DATABASE TOOL CALLED: Setting price for philadelphia to $1099\n",
"DATABASE TOOL CALLED: Setting price for seattle to $1399\n",
"DATABASE TOOL CALLED: Setting price for san diego to $1299\n",
"DATABASE TOOL CALLED: Setting price for san jose to $1199\n",
"DATABASE TOOL CALLED: Setting price for austin to $1099\n",
"DATABASE TOOL CALLED: Setting price for san antonio to $1399\n",
"DATABASE TOOL CALLED: Setting price for nairobi to $1099\n",
"DATABASE TOOL CALLED: Setting price for cape town to $1299\n",
"DATABASE TOOL CALLED: Setting price for durban to $1199\n",
"DATABASE TOOL CALLED: Setting price for johannesburg to $1399\n",
"DATABASE TOOL CALLED: Setting price for pretoria to $1099\n",
"DATABASE TOOL CALLED: Setting price for bloemfontein to $1299\n",
"DATABASE TOOL CALLED: Setting price for polokwane to $1199\n",
"DATABASE TOOL CALLED: Setting price for port elizabeth to $1199\n",
"DATABASE TOOL CALLED: Setting price for port shepstone to $1399\n",
"DATABASE TOOL CALLED: Setting price for port saint john to $1099\n",
"✅ Sample data initialized!\n",
"\n",
"🧪 Testing the setup:\n",
"DATABASE TOOL CALLED: Getting price for London\n",
"Ticket price to London is $799.0\n",
"DATABASE TOOL CALLED: Getting price for Tokyo\n",
"Ticket price to Tokyo is $1420.0\n"
]
}
],
"source": [
"# Initialize sample data\n",
"# Keys must be unique: a repeated key in a dict literal silently overrides the\n",
"# earlier value, so each city appears exactly once here.\n",
"SAMPLE_TICKET_PRICES = {\n",
"    \"london\": 799, \"paris\": 899, \"tokyo\": 1420, \"sydney\": 2999,\n",
"    \"new york\": 1099, \"los angeles\": 1299, \"san francisco\": 1199, \"chicago\": 999,\n",
"    \"houston\": 1399, \"miami\": 1499, \"washington\": 1199, \"boston\": 1299,\n",
"    \"philadelphia\": 1099, \"seattle\": 1399, \"san diego\": 1299, \"san jose\": 1199,\n",
"    \"austin\": 1099, \"san antonio\": 1399,\n",
"    \"nairobi\": 1099, \"cape town\": 1299, \"durban\": 1199, \"johannesburg\": 1399,\n",
"    \"pretoria\": 1099, \"bloemfontein\": 1299, \"polokwane\": 1199,\n",
"    \"port elizabeth\": 1199, \"port shepstone\": 1399, \"port saint john\": 1099,\n",
"}\n",
"\n",
"def initialize_sample_data(ticket_prices=None):\n",
"    \"\"\"Seed the prices table by calling set_ticket_price for each entry.\n",
"\n",
"    Args:\n",
"        ticket_prices: Optional mapping of city -> price; defaults to\n",
"            SAMPLE_TICKET_PRICES. Existing rows for these cities are overwritten.\n",
"    \"\"\"\n",
"    if ticket_prices is None:\n",
"        ticket_prices = SAMPLE_TICKET_PRICES\n",
"    for city, price in ticket_prices.items():\n",
"        set_ticket_price(city, price)\n",
"    print(\"✅ Sample data initialized!\")\n",
"\n",
"# Initialize sample data\n",
"initialize_sample_data()\n",
"\n",
"# Test the setup\n",
"print(\"\\n🧪 Testing the setup:\")\n",
"print(get_ticket_price(\"London\"))\n",
"print(get_ticket_price(\"Tokyo\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch the Enhanced Airline Assistant\n",
"\n",
"The assistant now supports both getting and setting ticket prices!\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Launching FlightAI Assistant with enhanced capabilities...\n",
"📋 Available commands:\n",
" - 'What's the price to London?' (get price)\n",
" - 'Set the price to New York to $1200' (set price)\n",
" - 'Update Tokyo price to $1500' (set price)\n",
" - 'How much does it cost to go to Paris?' (get price)\n",
"* Running on local URL: http://127.0.0.1:7882\n",
"* To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7882/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DATABASE TOOL CALLED: Getting price for Paris\n",
"DATABASE TOOL CALLED: Setting price for Berlin to $9023\n"
]
}
],
"source": [
"# Launch the Gradio interface\n",
"# Sample prompts paired with the tool each one is expected to trigger.\n",
"_command_hints = [\n",
"    (\"What's the price to London?\", \"get price\"),\n",
"    (\"Set the price to New York to $1200\", \"set price\"),\n",
"    (\"Update Tokyo price to $1500\", \"set price\"),\n",
"    (\"How much does it cost to go to Paris?\", \"get price\"),\n",
"]\n",
"\n",
"print(\"🚀 Launching FlightAI Assistant with enhanced capabilities...\")\n",
"print(\"📋 Available commands:\")\n",
"for _prompt, _tool in _command_hints:\n",
"    print(f\" - '{_prompt}' ({_tool})\")\n",
"\n",
"interface = gr.ChatInterface(\n",
"    fn=chat,\n",
"    type=\"messages\",\n",
"    title=\"FlightAI Assistant - Enhanced\",\n",
"    description=\"Ask me about ticket prices or set new prices for destinations!\",\n",
"    # Example prompts offered in the chat UI.\n",
"    examples=[\n",
"        \"What's the price to London?\",\n",
"        \"Set the price to New York to $1200\",\n",
"        \"How much does it cost to go to Paris?\",\n",
"        \"Update Tokyo price to $1500\",\n",
"    ],\n",
")\n",
"\n",
"interface.launch()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Key Implementation Features\n",
"\n",
"### 🔧 **Enhanced Tool Support**\n",
"- **Get Ticket Price**: Query current prices from database\n",
"- **Set Ticket Price**: Update prices in database\n",
"- **Multiple Tool Calls**: Handles several tool calls in one model response, as well as repeated rounds of tool calls before the final answer\n",
"- **Database Integration**: Persistent SQLite storage\n",
"\n",
"### 🎯 **Tool Function Definitions**\n",
"```python\n",
"# Get Price Tool\n",
"get_price_function = {\n",
" \"name\": \"get_ticket_price\",\n",
" \"description\": \"Get the price of a return ticket to the destination city.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"destination_city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city that the customer wants to travel to\",\n",
" },\n",
" },\n",
" \"required\": [\"destination_city\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}\n",
"\n",
"# Set Price Tool \n",
"set_price_function = {\n",
" \"name\": \"set_ticket_price\", \n",
" \"description\": \"Set the price of a return ticket to a destination city.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"destination_city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to set the price for\",\n",
" },\n",
" \"price\": {\n",
" \"type\": \"number\", \n",
" \"description\": \"The new price for the ticket\",\n",
" },\n",
" },\n",
" \"required\": [\"destination_city\", \"price\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}\n",
"```\n",
"\n",
"### 🚀 **Usage Examples**\n",
"- **Get Price**: \"What's the price to London?\"\n",
"- **Set Price**: \"Set the price to New York to $1200\"\n",
"- **Update Price**: \"Update Tokyo price to $1500\"\n",
"- **Query Multiple**: \"What are the prices to London and Paris?\"\n",
"\n",
"### 💾 **Database Schema**\n",
"```sql\n",
"CREATE TABLE prices (\n",
" city TEXT PRIMARY KEY,\n",
" price REAL\n",
")\n",
"```\n",
"\n",
"This implementation demonstrates advanced tool integration with OpenAI's function calling capabilities!\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}