From 2391aeb24729bef72e11b6395e48f220252152d6 Mon Sep 17 00:00:00 2001 From: Praveen M <32341624+Praveenm79@users.noreply.github.com> Date: Sun, 22 Jun 2025 12:10:53 +0530 Subject: [PATCH] Week 2 contribution: Airline Assistant using Gemini + Amadeus live ticket price --- ...ant_Gemini_Amadeus_live_ticket_price.ipynb | 808 ++++++++++++++++++ 1 file changed, 808 insertions(+) create mode 100644 week2/community-contributions/Week2_airline_assistant_Gemini_Amadeus_live_ticket_price.ipynb diff --git a/week2/community-contributions/Week2_airline_assistant_Gemini_Amadeus_live_ticket_price.ipynb b/week2/community-contributions/Week2_airline_assistant_Gemini_Amadeus_live_ticket_price.ipynb new file mode 100644 index 0000000..bc4f92a --- /dev/null +++ b/week2/community-contributions/Week2_airline_assistant_Gemini_Amadeus_live_ticket_price.ipynb @@ -0,0 +1,808 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d938fc6c-bcca-4572-b851-75370fe21c67", + "metadata": {}, + "source": [ + "# Airline Assistant using Gemini API for Image and Audio as well - Live ticket prices using Amadeus API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5eda470-07ee-4d01-bada-3390050ac9c2", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import random\n", + "import string\n", + "import base64\n", + "import gradio as gr\n", + "import pyaudio\n", + "import requests\n", + "from io import BytesIO\n", + "from PIL import Image\n", + "from dotenv import load_dotenv\n", + "from google import genai\n", + "from google.genai import types" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09aaf3b0-beb7-4b64-98a4-da16fc83dadb", + "metadata": {}, + "outputs": [], + "source": [ + "load_dotenv(override=True)\n", + "api_key = os.getenv(\"GOOGLE_API_KEY\")\n", + "\n", + "if not api_key:\n", + " print(\"API Key not found!\")\n", + "else:\n", + " print(\"API Key loaded in memory\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35881fb9-4d51-43dc-a5e6-d9517e22019a", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_GEMINI = 'gemini-2.5-flash'\n", + "MODEL_GEMINI_IMAGE = 'gemini-2.0-flash-preview-image-generation'\n", + "MODEL_GEMINI_SPEECH = 'gemini-2.5-flash-preview-tts'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5ed391c-8a67-4465-9c66-e915548a0d6a", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " client = genai.Client(api_key=api_key)\n", + " print(\"Google GenAI Client initialized successfully!\")\n", + "except Exception as e:\n", + " print(f\"Error initializing GenAI Client: {e}\")\n", + " print(\"Ensure your GOOGLE_API_KEY is correctly set as an environment variable.\")\n", + " exit() " + ] + }, + { + "cell_type": "markdown", + "id": "407ad581-9580-4dba-b236-abb6c6788933", + "metadata": {}, + "source": [ + "## Image Generation " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a21921f8-57b1-4665-8999-7f2a40645b59", + "metadata": {}, + "outputs": [], + "source": [ + "def fetch_image(city):\n", + " prompt = (\n", + " f\"A high-quality, photo-realistic image of a vacation in {city}, \"\n", + " f\"showing iconic landmarks, cultural attractions, authentic street life, and local cuisine. \"\n", + " f\"Capture natural lighting, real people enjoying travel experiences, and the unique vibe of {city}'s atmosphere. \"\n", + " f\"The composition should feel immersive, warm, and visually rich, as if taken by a travel photographer.\"\n", + ")\n", + "\n", + " response = client.models.generate_content(\n", + " model = MODEL_GEMINI_IMAGE,\n", + " contents = prompt,\n", + " config=types.GenerateContentConfig(\n", + " response_modalities=['TEXT', 'IMAGE']\n", + " )\n", + " )\n", + "\n", + " for part in response.candidates[0].content.parts:\n", + " if part.inline_data is not None:\n", + " image_data = BytesIO(part.inline_data.data)\n", + " return Image.open(image_data)\n", + "\n", + " raise ValueError(\"No image found in Gemini response.\")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcd4aed1-8b4d-4771-ba32-e729e82bab54", + "metadata": {}, + "outputs": [], + "source": [ + "fetch_image(\"london\")" + ] + }, + { + "cell_type": "markdown", + "id": "5f6baee6-e2e2-4cc4-941d-34a4c72cee67", + "metadata": {}, + "source": [ + "## Speech Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "825dfedc-0271-4191-a3d1-50872af4c8cf", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "Kore -- Firm\n", + "Puck -- Upbeat\n", + "Leda -- Youthful\n", + "Iapetus -- Clear\n", + "Erinome -- Clear\n", + "Sadachbia -- Lively\n", + "Sulafat -- Warm\n", + "Despina -- Smooth\n", + "\"\"\"\n", + "\n", + "def talk(message:str, voice_name:str=\"Leda\", mood:str=\"cheerfully\"):\n", + " prompt = f\"Say {mood}: {message}\"\n", + " response = client.models.generate_content(\n", + " model = MODEL_GEMINI_SPEECH,\n", + " contents = prompt,\n", + " config=types.GenerateContentConfig(\n", + " response_modalities=[\"AUDIO\"],\n", + " speech_config=types.SpeechConfig(\n", + " voice_config=types.VoiceConfig(\n", + " prebuilt_voice_config=types.PrebuiltVoiceConfig(\n", + " voice_name=voice_name,\n", + " )\n", + " )\n", + " ), \n", + " )\n", + " )\n", + "\n", + " # Fetch the audio bytes\n", + " pcm_data = response.candidates[0].content.parts[0].inline_data.data\n", + " # Play the audio using PyAudio\n", + " p = pyaudio.PyAudio()\n", + " stream = p.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)\n", + " stream.write(pcm_data)\n", + " stream.stop_stream()\n", + " stream.close()\n", + " p.terminate()\n", + "\n", + " # Play using simpleaudio (16-bit PCM, mono, 24kHz)\n", + " # play_obj = sa.play_buffer(pcm_data, num_channels=1, bytes_per_sample=2, sample_rate=24000)\n", + " # play_obj.wait_done() " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54967ebc-24a6-4bb2-9a19-20c3585f1d77", + "metadata": {}, + "outputs": [], + "source": [ + "talk(\"Hi, How are you? Welcome to FlyJumbo Airlines\",\"Kore\",\"helpful\")" + ] + }, + { + "cell_type": "markdown", + "id": "be9dc275-838e-4c54-b487-41d094dad96b", + "metadata": {}, + "source": [ + "## Ticket Price Tool Function - Using Amadeus API " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8613a080-d82c-4c1a-8db4-377614997ac2", + "metadata": {}, + "outputs": [], + "source": [ + "client_id = os.getenv(\"AMADEUS_CLIENT_ID\")\n", + "client_secret = os.getenv(\"AMADEUS_CLIENT_SECRET\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bf78f61-0de1-4552-a1d4-1a28380be6a5", + "metadata": {}, + "outputs": [], + "source": [ + "# Get the token first\n", + "def get_amadeus_token():\n", + " url = \"https://test.api.amadeus.com/v1/security/oauth2/token\"\n", + " headers = {\"Content-Type\": \"application/x-www-form-urlencoded\"}\n", + " data = {\n", + " \"grant_type\": \"client_credentials\",\n", + " \"client_id\": client_id,\n", + " \"client_secret\": client_secret,\n", + " }\n", + " \n", + " try:\n", + " response = requests.post(url, headers=headers, data=data, timeout=10)\n", + " response.raise_for_status()\n", + " return response.json()[\"access_token\"]\n", + " \n", + " except requests.exceptions.HTTPError as e:\n", + " print(f\"HTTP Error {response.status_code}: {response.text}\")\n", + " \n", + " except requests.exceptions.RequestException as e:\n", + " print(\"Network or connection error:\", e)\n", + " \n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c5261f6-6662-4e9d-8ff0-8e10171bb963", + "metadata": {}, + "outputs": [], + "source": [ + "def get_airline_name(code, token):\n", + " url = f\"https://test.api.amadeus.com/v1/reference-data/airlines\"\n", + " headers = {\"Authorization\": f\"Bearer {token}\"}\n", + " params = {\"airlineCodes\": code}\n", + "\n", + " response = requests.get(url, headers=headers, params=params)\n", + " response.raise_for_status()\n", + " data = response.json()\n", + "\n", + " if \"data\" in data and data[\"data\"]:\n", + " return data[\"data\"][0].get(\"businessName\", code)\n", + " return code" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42a55f06-880a-4c49-8560-2e7b97953c1a", + "metadata": {}, + "outputs": [], + "source": [ + "COMMON_CITY_CODES = {\n", + " \"delhi\": \"DEL\",\n", + " \"mumbai\": \"BOM\",\n", + " \"chennai\": \"MAA\",\n", + " \"kolkata\": \"CCU\",\n", + " \"bengaluru\": \"BLR\",\n", + " \"hyderabad\": \"HYD\",\n", + " \"patna\": \"PAT\",\n", + " \"raipur\": \"RPR\",\n", + " \"panaji\": \"GOI\",\n", + " \"chandigarh\": \"IXC\",\n", + " \"srinagar\": \"SXR\",\n", + " \"ranchi\": \"IXR\",\n", + " \"bengaluru\": \"BLR\",\n", + " \"thiruvananthapuram\": \"TRV\",\n", + " \"bhopal\": \"BHO\",\n", + " \"mumbai\": \"BOM\",\n", + " \"imphal\": \"IMF\",\n", + " \"aizawl\": \"AJL\",\n", + " \"bhubaneswar\": \"BBI\",\n", + " \"jaipur\": \"JAI\",\n", + " \"chennai\": \"MAA\",\n", + " \"hyderabad\": \"HYD\",\n", + " \"agartala\": \"IXA\",\n", + " \"lucknow\": \"LKO\",\n", + " \"dehradun\": \"DED\",\n", + " \"kolkata\": \"CCU\",\n", + "\n", + " # Union territories\n", + " \"port blair\": \"IXZ\",\n", + " \"leh\": \"IXL\",\n", + " \"puducherry\": \"PNY\",\n", + "\n", + " # Major metro cities (for redundancy)\n", + " \"ahmedabad\": \"AMD\",\n", + " \"surat\": \"STV\",\n", + " \"coimbatore\": \"CJB\",\n", + " \"vizag\": \"VTZ\",\n", + " \"vijayawada\": \"VGA\",\n", + " \"nagpur\": \"NAG\",\n", + " \"indore\": \"IDR\",\n", + " \"kanpur\": \"KNU\",\n", + " \"varanasi\": \"VNS\"\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b061ec2c-609b-4d77-bd41-c9bc5bf901f4", + "metadata": {}, + "outputs": [], + "source": [ + "city_code_cache = {}\n", + "\n", + "def get_city_code(city_name, token):\n", + " city_name = city_name.strip().lower()\n", + "\n", + " if city_name in city_code_cache:\n", + " return city_code_cache[city_name]\n", + "\n", + " if city_name in COMMON_CITY_CODES:\n", + " return COMMON_CITY_CODES[city_name]\n", + "\n", + " base_url = \"https://test.api.amadeus.com/v1/reference-data/locations\"\n", + " headers = {\"Authorization\": f\"Bearer {token}\"}\n", + "\n", + " for subtype in [\"CITY\", \"AIRPORT,CITY\"]:\n", + " params = {\"keyword\": city_name, \"subType\": subtype}\n", + " try:\n", + " response = requests.get(base_url, headers=headers, params=params, timeout=10)\n", + " response.raise_for_status()\n", + " data = response.json()\n", + "\n", + " if \"data\" in data and data[\"data\"]:\n", + " code = data[\"data\"][0][\"iataCode\"]\n", + " print(f\"[INFO] Found {subtype} match for '{city_name}': {code}\")\n", + " city_code_cache[city_name] = code\n", + " return code\n", + " except Exception as e:\n", + " print(f\"[ERROR] Location lookup failed for {subtype}: {e}\")\n", + "\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9816a9c-fd70-4dfc-a3c0-4d8709997371", + "metadata": {}, + "outputs": [], + "source": [ + "# Getting live ticket price \n", + "\n", + "def get_live_ticket_prices(origin, destination, departure_date, return_date=None):\n", + " token = get_amadeus_token()\n", + "\n", + " url = \"https://test.api.amadeus.com/v2/shopping/flight-offers\"\n", + " headers = {\"Authorization\": f\"Bearer {token}\"}\n", + "\n", + " origin_code = get_city_code(origin,token)\n", + " destination_code = get_city_code(destination,token)\n", + "\n", + " if not origin_code:\n", + " return f\"Sorry, I couldn't find the airport code for the city '{origin}'.\"\n", + " if not destination_code:\n", + " return f\"Sorry, I couldn't find the airport code for the city '{destination}'.\"\n", + "\n", + " params = {\n", + " \"originLocationCode\": origin_code.upper(),\n", + " \"destinationLocationCode\": destination_code.upper(),\n", + " \"departureDate\": departure_date,\n", + " \"adults\": 1,\n", + " \"currencyCode\": \"USD\",\n", + " \"max\": 1,\n", + " }\n", + "\n", + " if return_date:\n", + " params[\"returnDate\"] = return_date\n", + "\n", + " try:\n", + " response = requests.get(url, headers=headers, params=params, timeout=10)\n", + " response.raise_for_status()\n", + " data = response.json()\n", + " \n", + " if \"data\" in data and data[\"data\"]:\n", + " offer = data[\"data\"][0]\n", + " price = offer[\"price\"][\"total\"]\n", + " airline_codes = offer.get(\"validatingAirlineCodes\", [])\n", + " airline_code = airline_codes[0] if airline_codes else \"Unknown\"\n", + "\n", + " try:\n", + " airline_name = get_airline_name(airline_code, token) if airline_code != \"Unknown\" else \"Unknown Airline\"\n", + " if not airline_name: \n", + " airline_name = airline_code\n", + " except Exception:\n", + " airline_name = airline_code\n", + " \n", + " \n", + " if return_date:\n", + " return (\n", + " f\"Round-trip flight from {origin.capitalize()} to {destination.capitalize()}:\\n\"\n", + " f\"- Departing: {departure_date}\\n\"\n", + " f\"- Returning: {return_date}\\n\"\n", + " f\"- Airline: {airline_name}\\n\"\n", + " f\"- Price: ${price}\"\n", + " )\n", + " else:\n", + " return (\n", + " f\"One-way flight from {origin.capitalize()} to {destination.capitalize()} on {departure_date}:\\n\"\n", + " f\"- Airline: {airline_name}\\n\"\n", + " f\"- Price: ${price}\"\n", + " )\n", + " else:\n", + " return f\"No flights found from {origin.capitalize()} to {destination.capitalize()} on {departure_date}.\"\n", + " except requests.exceptions.RequestException as e:\n", + " return f\"❌ Error fetching flight data: {str(e)}\" \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bc7657e-e8b5-4647-9745-d7d403feb09a", + "metadata": {}, + "outputs": [], + "source": [ + "get_live_ticket_prices(\"london\", \"chennai\", \"2025-07-01\",\"2025-07-10\")" + ] + }, + { + "cell_type": "markdown", + "id": "e1153b94-90e7-4856-8c85-e456305a7817", + "metadata": {}, + "source": [ + "## Ticket Booking Tool Function - DUMMY" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dfc3b12-0a16-4861-a549-594f175ff956", + "metadata": {}, + "outputs": [], + "source": [ + "def book_flight(origin, destination, departure_date, return_date=None, airline=\"Selected Airline\", passenger_name=\"Guest\"):\n", + " # Generate a dummy ticket reference (PNR)\n", + " ticket_ref = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))\n", + "\n", + " # Build confirmation message\n", + " confirmation = (\n", + " f\"🎫 Booking confirmed for {passenger_name}!\\n\"\n", + " f\"From: {origin.capitalize()} → To: {destination.capitalize()}\\n\"\n", + " f\"Departure: {departure_date}\"\n", + " )\n", + "\n", + " if return_date:\n", + " confirmation += f\"\\nReturn: {return_date}\"\n", + "\n", + " confirmation += (\n", + " f\"\\nAirline: {airline}\\n\"\n", + " f\"PNR: {ticket_ref}\\n\"\n", + " f\"✅ Your ticket has been booked successfully. Safe travels!\"\n", + " )\n", + "\n", + " return confirmation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "122f655b-b7a4-45c6-aaec-afd2917a051b", + "metadata": {}, + "outputs": [], + "source": [ + "print(book_flight(\"chennai\", \"delhi\", \"2025-07-01\", \"2025-07-10\", \"Air India\", \"Ravi Kumar\"))" + ] + }, + { + "cell_type": "markdown", + "id": "e83d8e90-ae22-4728-83e5-d83fed7f2049", + "metadata": {}, + "source": [ + "## Gemini Chat Workings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a656f4e-914d-4f5e-b7fa-48457935181a", + "metadata": {}, + "outputs": [], + "source": [ + "ticket_price_function_declaration = {\n", + " \"name\":\"get_live_ticket_prices\",\n", + " \"description\": \"Get live flight ticket prices between two cities for a given date (round-trip or one-way).\\\n", + " The destination may be a city or country (e.g., 'China'). Call this function whenever a customer asks about ticket prices., such as 'How much is a ticket to Paris?'\",\n", + " \"parameters\":{\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"origin\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Name of the origin city. Example: 'Delhi'\",\n", + " },\n", + " \"destination\": {\n", + " \"type\": \"string\",\n", + " \"description\":\"Name of the destination city. Example: 'London'\",\n", + " },\n", + " \"departure_date\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Date of departure in YYYY-MM-DD format. Example: '2025-07-01'\",\n", + " },\n", + " \"return_date\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Optional return date for round-trip in YYYY-MM-DD format. Leave blank for one-way trips.\",\n", + " },\n", + " },\n", + " \"required\": [\"origin\", \"destination\", \"departure_date\"],\n", + " }\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05a835ab-a675-40ed-9cd8-65f4c6b22722", + "metadata": {}, + "outputs": [], + "source": [ + "book_flight_function_declaration = {\n", + " \"name\": \"book_flight\",\n", + " \"description\": \"Book a flight for the user after showing the ticket details and confirming the booking. \"\n", + " \"Call this function when the user says things like 'yes', 'book it', or 'I want to book this flight'.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"origin\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Name of the origin city. Example: 'Chennai'\",\n", + " },\n", + " \"destination\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Name of the destination city. Example: 'London'\",\n", + " },\n", + " \"departure_date\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Date of departure in YYYY-MM-DD format. Example: '2025-07-01'\",\n", + " },\n", + " \"return_date\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Optional return date for round-trip in YYYY-MM-DD format. Leave blank for one-way trips.\",\n", + " },\n", + " \"airline\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Airline name or code that the user wants to book with. Example: 'Air India'\",\n", + " },\n", + " \"passenger_name\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Full name of the passenger for the booking. Example: 'Ravi Kumar'\",\n", + " }\n", + " },\n", + " \"required\": [\"origin\", \"destination\", \"departure_date\", \"passenger_name\"],\n", + " }\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad0231cd-040f-416d-b150-0d8f90535718", + "metadata": {}, + "outputs": [], + "source": [ + "# System Definitions\n", + "\n", + "system_instruction_prompt = (\n", + " \"You are a helpful and courteous AI assistant for an airline company called FlyJumbo. \"\n", + " \"When a user starts a new conversation, greet them with: 'Hi there, welcome to FlyJumbo! How can I help you?'. \"\n", + " \"Do not repeat this greeting in follow-up messages. \"\n", + " \"Use the available tools if a user asks about ticket prices. \"\n", + " \"Ask follow-up questions to gather all necessary information before calling a function.\"\n", + " \"After calling a tool, always continue the conversation by summarizing the result and asking the user the next relevant question (e.g., if they want to proceed with a booking).\"\n", + " \"If you do not know the answer and no tool can help, respond politely that you are unable to help with the request. \"\n", + " \"Answer concisely in one sentence.\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff0b3de8-5674-4f08-9f9f-06f88ff959a1", + "metadata": {}, + "outputs": [], + "source": [ + "tools = types.Tool(function_declarations=[ticket_price_function_declaration,book_flight_function_declaration])\n", + "generate_content_config = types.GenerateContentConfig(system_instruction=system_instruction_prompt, tools=[tools])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00a56779-16eb-4f31-9941-2eb01d17ed87", + "metadata": {}, + "outputs": [], + "source": [ + "def handle_tool_call(function_call):\n", + " print(f\"🔧 Function Called - {function_call.name}\")\n", + " function_name = function_call.name\n", + " args = function_call.args\n", + "\n", + " if function_name == \"get_live_ticket_prices\":\n", + " origin = args.get(\"origin\")\n", + " destination = args.get(\"destination\")\n", + " departure_date = args.get(\"departure_date\")\n", + " return_date = args.get(\"return_date\") or None\n", + "\n", + " return get_live_ticket_prices(origin, destination, departure_date, return_date)\n", + "\n", + " elif function_name == \"book_flight\":\n", + " origin = args.get(\"origin\")\n", + " destination = args.get(\"destination\")\n", + " departure_date = args.get(\"departure_date\")\n", + " return_date = args.get(\"return_date\") or None\n", + " airline = args.get(\"airline\", \"Selected Airline\")\n", + " passenger_name = args.get(\"passenger_name\", \"Guest\")\n", + "\n", + " return book_flight(origin, destination, departure_date, return_date, airline, passenger_name)\n", + "\n", + " else:\n", + " return f\"❌ Unknown function: {function_name}\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0c334d2-9ab0-4f80-ac8c-c66897e0bd7c", + "metadata": {}, + "outputs": [], + "source": [ + "def chat(message, history):\n", + " full_message_history = []\n", + " city_name = None\n", + "\n", + " # Convert previous history to Gemini-compatible format\n", + " for h in history:\n", + " if h[\"role\"] == \"user\":\n", + " full_message_history.append(\n", + " types.Content(role=\"user\", parts=[types.Part.from_text(text=h[\"content\"])])\n", + " )\n", + " elif h[\"role\"] == \"assistant\":\n", + " full_message_history.append(\n", + " types.Content(role=\"model\", parts=[types.Part.from_text(text=h[\"content\"])])\n", + " )\n", + "\n", + " # Add current user message\n", + " full_message_history.append(\n", + " types.Content(role=\"user\", parts=[types.Part.from_text(text=message)])\n", + " )\n", + "\n", + " # Send to Gemini with tool config\n", + " response = client.models.generate_content(\n", + " model=MODEL_GEMINI,\n", + " contents=full_message_history,\n", + " config=generate_content_config\n", + " )\n", + "\n", + " candidate = response.candidates[0]\n", + " part = candidate.content.parts[0]\n", + " function_call = getattr(part, \"function_call\", None)\n", + "\n", + " # Case: Tool call required\n", + " if function_call:\n", + " # Append model message that triggered tool call\n", + " full_message_history.append(\n", + " types.Content(role=\"model\", parts=candidate.content.parts)\n", + " )\n", + "\n", + " # Execute the tool\n", + " tool_output = handle_tool_call(function_call)\n", + "\n", + " # Wrap and append tool output\n", + " tool_response_part = types.Part.from_function_response(\n", + " name=function_call.name,\n", + " response={\"result\": tool_output}\n", + " )\n", + " \n", + " full_message_history.append(\n", + " types.Content(role=\"function\", parts=[tool_response_part])\n", + " )\n", + "\n", + "\n", + " if function_call.name == \"book_flight\":\n", + " city_name = function_call.args.get(\"destination\").lower()\n", + " \n", + "\n", + " # Send follow-up message including tool result\n", + " followup_response = client.models.generate_content(\n", + " model=MODEL_GEMINI,\n", + " contents=full_message_history,\n", + " config=generate_content_config\n", + " )\n", + "\n", + " final_text = followup_response.text\n", + " \n", + " full_message_history.append(\n", + " types.Content(role=\"model\", parts=[types.Part.from_text(text=final_text)])\n", + " )\n", + "\n", + " return final_text,city_name, history + [{\"role\": \"assistant\", \"content\": final_text}]\n", + " else:\n", + " text = response.text\n", + " return text, city_name, history + [{\"role\": \"assistant\", \"content\": text}]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b245e6c-ef0b-4edf-b178-f14f2a75f285", + "metadata": {}, + "outputs": [], + "source": [ + "def user_submit(user_input, history):\n", + " history = history or []\n", + " history.append({\"role\": \"user\", \"content\": user_input})\n", + " \n", + " response_text, city_to_image, updated_history = chat(user_input, history)\n", + "\n", + " # Speak the response\n", + " try:\n", + " talk(response_text)\n", + " except Exception as e:\n", + " print(\"[Speech Error] Speech skipped due to quota limit.\")\n", + "\n", + " image = fetch_image(city_to_image) if city_to_image else None\n", + "\n", + " return \"\", updated_history, image, updated_history\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db25b86-9a71-417c-98f0-790e3f3531bf", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"## ✈️ FlyJumbo Airline Assistant\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column(scale=3):\n", + " chatbot = gr.Chatbot(label=\"Assistant\", height=500, type=\"messages\")\n", + " msg = gr.Textbox(placeholder=\"Ask about flights...\", show_label=False)\n", + " send_btn = gr.Button(\"Send\")\n", + "\n", + " with gr.Column(scale=2):\n", + " image_output = gr.Image(label=\"Trip Visual\", visible=True, height=500)\n", + "\n", + " state = gr.State([])\n", + " \n", + " send_btn.click(fn=user_submit, inputs=[msg, state], outputs=[msg, chatbot, image_output, state])\n", + " msg.submit(fn=user_submit, inputs=[msg, state], outputs=[msg, chatbot, image_output, state])\n", + "\n", + "demo.launch(inbrowser=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef31bf62-9034-4fa7-b803-8f5df5309b77", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}