458 lines
14 KiB
Plaintext
458 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ddfa9ae6-69fe-444a-b994-8c4c5970a7ec",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Project - Airline AI Assistant\n",
|
|
"\n",
|
|
"We'll now bring together what we've learned to make an AI Customer Support assistant for an Airline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8b50bbe2-c0b1-49c3-9a5c-1ba7efa2bcb4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import json\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from openai import OpenAI\n",
|
|
"import gradio as gr\n",
|
|
"import sqlite3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "747e8786-9da8-4342-b6c9-f5f69c2e22ae",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Initialization\n",
|
|
"\n",
|
|
"load_dotenv(override=True)\n",
|
|
"\n",
|
|
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
|
"if openai_api_key:\n",
|
|
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
|
|
"else:\n",
|
|
" print(\"OpenAI API Key not set\")\n",
|
|
" \n",
|
|
"MODEL = \"gpt-4.1-mini\"\n",
|
|
"openai = OpenAI()\n",
|
|
"\n",
|
|
"DB = \"prices.db\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0a521d84-d07c-49ab-a0df-d6451499ed97",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"system_message = \"\"\"\n",
|
|
"You are a helpful assistant for an Airline called FlightAI.\n",
|
|
"Give short, courteous answers, no more than 1 sentence.\n",
|
|
"Always be accurate. If you don't know the answer, say so.\n",
|
|
"\"\"\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c3e8173c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_ticket_price(city):\n",
|
|
" print(f\"DATABASE TOOL CALLED: Getting price for {city}\", flush=True)\n",
|
|
" with sqlite3.connect(DB) as conn:\n",
|
|
" cursor = conn.cursor()\n",
|
|
" cursor.execute('SELECT price FROM prices WHERE city = ?', (city.lower(),))\n",
|
|
" result = cursor.fetchone()\n",
|
|
" return f\"Ticket price to {city} is ${result[0]}\" if result else \"No price data available for this city\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "03f19289",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"get_ticket_price(\"Paris\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "bcfb6523",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"price_function = {\n",
|
|
" \"name\": \"get_ticket_price\",\n",
|
|
" \"description\": \"Get the price of a return ticket to the destination city.\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"destination_city\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The city that the customer wants to travel to\",\n",
|
|
" },\n",
|
|
" },\n",
|
|
" \"required\": [\"destination_city\"],\n",
|
|
" \"additionalProperties\": False\n",
|
|
" }\n",
|
|
"}\n",
|
|
"tools = [{\"type\": \"function\", \"function\": price_function}]\n",
|
|
"tools"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "61a2a15d-b559-4844-b377-6bd5cb4949f6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def chat(message, history):\n",
|
|
" history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n",
|
|
" messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
|
|
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
|
|
" return response.choices[0].message.content\n",
|
|
"\n",
|
|
"gr.ChatInterface(fn=chat, type=\"messages\").launch()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c91d012e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def chat(message, history):\n",
|
|
" history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n",
|
|
" messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
|
|
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
|
|
"\n",
|
|
" while response.choices[0].finish_reason==\"tool_calls\":\n",
|
|
" message = response.choices[0].message\n",
|
|
" responses = handle_tool_calls(message)\n",
|
|
" messages.append(message)\n",
|
|
" messages.extend(responses)\n",
|
|
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
|
|
" \n",
|
|
" return response.choices[0].message.content"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "956c3b61",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def handle_tool_calls(message):\n",
|
|
" responses = []\n",
|
|
" for tool_call in message.tool_calls:\n",
|
|
" if tool_call.function.name == \"get_ticket_price\":\n",
|
|
" arguments = json.loads(tool_call.function.arguments)\n",
|
|
" city = arguments.get('destination_city')\n",
|
|
" price_details = get_ticket_price(city)\n",
|
|
" responses.append({\n",
|
|
" \"role\": \"tool\",\n",
|
|
" \"content\": price_details,\n",
|
|
" \"tool_call_id\": tool_call.id\n",
|
|
" })\n",
|
|
" return responses"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8eca803e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"gr.ChatInterface(fn=chat, type=\"messages\").launch()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b369bf10",
|
|
"metadata": {},
|
|
"source": [
|
|
"## A bit more about what Gradio actually does:\n",
|
|
"\n",
|
|
"1. Gradio constructs a frontend Svelte app based on our Python description of the UI\n",
|
|
"2. Gradio starts a server built upon the Starlette web framework listening on a free port that serves this React app\n",
|
|
"3. Gradio creates backend routes for our callbacks, like chat(), which calls our functions\n",
|
|
"\n",
|
|
"And of course when Gradio generates the frontend app, it ensures that the the Submit button calls the right backend route.\n",
|
|
"\n",
|
|
"That's it!\n",
|
|
"\n",
|
|
"It's simple, and it has a resut that feels magical."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "863aac34",
|
|
"metadata": {},
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "473e5b39-da8f-4db1-83ae-dbaca2e9531e",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Let's go multi-modal!!\n",
|
|
"\n",
|
|
"We can use DALL-E-3, the image generation model behind GPT-4o, to make us some images\n",
|
|
"\n",
|
|
"Let's put this in a function called artist.\n",
|
|
"\n",
|
|
"### Price alert: each time I generate an image it costs about 4 cents - don't go crazy with images!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2c27c4ba-8ed5-492f-add1-02ce9c81d34c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Some imports for handling images\n",
|
|
"\n",
|
|
"import base64\n",
|
|
"from io import BytesIO\n",
|
|
"from PIL import Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "773a9f11-557e-43c9-ad50-56cbec3a0f8f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def artist(city):\n",
|
|
" image_response = openai.images.generate(\n",
|
|
" model=\"dall-e-3\",\n",
|
|
" prompt=f\"An image representing a vacation in {city}, showing tourist spots and everything unique about {city}, in a vibrant pop-art style\",\n",
|
|
" size=\"1024x1024\",\n",
|
|
" n=1,\n",
|
|
" response_format=\"b64_json\",\n",
|
|
" )\n",
|
|
" image_base64 = image_response.data[0].b64_json\n",
|
|
" image_data = base64.b64decode(image_base64)\n",
|
|
" return Image.open(BytesIO(image_data))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d877c453-e7fb-482a-88aa-1a03f976b9e9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image = artist(\"New York City\")\n",
|
|
"display(image)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "728a12c5-adc3-415d-bb05-82beb73b079b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def talker(message):\n",
|
|
" response = openai.audio.speech.create(\n",
|
|
" model=\"gpt-4o-mini-tts\",\n",
|
|
" voice=\"onyx\", # Also, try replacing onyx with alloy or coral\n",
|
|
" input=message\n",
|
|
" )\n",
|
|
" return response.content"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3bc7580b",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Let's bring this home:\n",
|
|
"\n",
|
|
"1. A multi-modal AI assistant with image and audio generation\n",
|
|
"2. Tool callling with database lookup\n",
|
|
"3. A step towards an Agentic workflow\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b119ed1b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def chat(history):\n",
|
|
" history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n",
|
|
" messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
|
|
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
|
|
" cities = []\n",
|
|
" image = None\n",
|
|
"\n",
|
|
" while response.choices[0].finish_reason==\"tool_calls\":\n",
|
|
" message = response.choices[0].message\n",
|
|
" responses, cities = handle_tool_calls_and_return_cities(message)\n",
|
|
" messages.append(message)\n",
|
|
" messages.extend(responses)\n",
|
|
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
|
|
"\n",
|
|
" reply = response.choices[0].message.content\n",
|
|
" history += [{\"role\":\"assistant\", \"content\":reply}]\n",
|
|
"\n",
|
|
" voice = talker(reply)\n",
|
|
"\n",
|
|
" if cities:\n",
|
|
" image = artist(cities[0])\n",
|
|
" \n",
|
|
" return history, voice, image\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5846bc77",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def handle_tool_calls_and_return_cities(message):\n",
|
|
" responses = []\n",
|
|
" cities = []\n",
|
|
" for tool_call in message.tool_calls:\n",
|
|
" if tool_call.function.name == \"get_ticket_price\":\n",
|
|
" arguments = json.loads(tool_call.function.arguments)\n",
|
|
" city = arguments.get('destination_city')\n",
|
|
" cities.append(city)\n",
|
|
" price_details = get_ticket_price(city)\n",
|
|
" responses.append({\n",
|
|
" \"role\": \"tool\",\n",
|
|
" \"content\": price_details,\n",
|
|
" \"tool_call_id\": tool_call.id\n",
|
|
" })\n",
|
|
" return responses, cities"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6e520161",
|
|
"metadata": {},
|
|
"source": [
|
|
"## The 3 types of Gradio UI\n",
|
|
"\n",
|
|
"`gr.Interface` is for standard, simple UIs\n",
|
|
"\n",
|
|
"`gr.ChatInterface` is for standard ChatBot UIs\n",
|
|
"\n",
|
|
"`gr.Blocks` is for custom UIs where you control the components and the callbacks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9f250915",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Callbacks (along with the chat() function above)\n",
|
|
"\n",
|
|
"def put_message_in_chatbot(message, history):\n",
|
|
" return \"\", history + [{\"role\":\"user\", \"content\":message}]\n",
|
|
"\n",
|
|
"# UI definition\n",
|
|
"\n",
|
|
"with gr.Blocks() as ui:\n",
|
|
" with gr.Row():\n",
|
|
" chatbot = gr.Chatbot(height=500, type=\"messages\")\n",
|
|
" image_output = gr.Image(height=500, interactive=False)\n",
|
|
" with gr.Row():\n",
|
|
" audio_output = gr.Audio(autoplay=True)\n",
|
|
" with gr.Row():\n",
|
|
" message = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
|
|
"\n",
|
|
"# Hooking up events to callbacks\n",
|
|
"\n",
|
|
" message.submit(put_message_in_chatbot, inputs=[message, chatbot], outputs=[message, chatbot]).then(\n",
|
|
" chat, inputs=chatbot, outputs=[chatbot, audio_output, image_output]\n",
|
|
" )\n",
|
|
"\n",
|
|
"ui.launch(inbrowser=True, auth=(\"ed\", \"bananas\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "226643d2-73e4-4252-935d-86b8019e278a",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Exercises and Business Applications\n",
|
|
"\n",
|
|
"Add in more tools - perhaps to simulate actually booking a flight. A student has done this and provided their example in the community contributions folder.\n",
|
|
"\n",
|
|
"Next: take this and apply it to your business. Make a multi-modal AI assistant with tools that could carry out an activity for your work. A customer support assistant? New employee onboarding assistant? So many possibilities! Also, see the week2 end of week Exercise in the separate Notebook."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7e795560-1867-42db-a256-a23b844e6fbe",
|
|
"metadata": {},
|
|
"source": [
|
|
"<table style=\"margin: 0; text-align: left;\">\n",
|
|
" <tr>\n",
|
|
" <td style=\"width: 150px; height: 150px; vertical-align: middle;\">\n",
|
|
" <img src=\"../assets/thankyou.jpg\" width=\"150\" height=\"150\" style=\"display: block;\" />\n",
|
|
" </td>\n",
|
|
" <td>\n",
|
|
" <h2 style=\"color:#090;\">I have a special request for you</h2>\n",
|
|
" <span style=\"color:#090;\">\n",
|
|
" My editor tells me that it makes a HUGE difference when students rate this course on Udemy - it's one of the main ways that Udemy decides whether to show it to others. If you're able to take a minute to rate this, I'd be so very grateful! And regardless - always please reach out to me at ed@edwarddonner.com if I can help at any point.\n",
|
|
" </span>\n",
|
|
" </td>\n",
|
|
" </tr>\n",
|
|
"</table>"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|