Merge pull request #716 from sevriugin/feat/booking-agent
feat: booking AI Agent with Amadeus API and Google Map Generator
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
KEY_CONFIGS = {
|
||||
"gpt": {
|
||||
"id": "gpt",
|
||||
"api_key_env": "OPENAI_API_KEY",
|
||||
},
|
||||
"claude": {
|
||||
"id": "claude",
|
||||
"api_key_env": "ANTHROPIC_API_KEY",
|
||||
},
|
||||
"gemini": {
|
||||
"id": "gemini",
|
||||
"api_key_env": "GOOGLE_API_KEY",
|
||||
},
|
||||
"openai": {
|
||||
"id": "openai",
|
||||
"api_key_env": "OPENAI_API_KEY",
|
||||
},
|
||||
"deepseek": {
|
||||
"id": "deepseek",
|
||||
"api_key_env": "DEEPSEEK_API_KEY"
|
||||
},
|
||||
"amadeus_client_id": {
|
||||
"id": "client_id",
|
||||
"api_key_env": "AMADEUS_CLIENT_ID"
|
||||
},
|
||||
"amadeus_client_secret": {
|
||||
"id": "client_secret",
|
||||
"api_key_env": "AMADEUS_CLIENT_SECRET"
|
||||
},
|
||||
"google_map": {
|
||||
"id": "google_map_api_key",
|
||||
"api_key_env": "GOOGLE_MAP_API_KEY"
|
||||
}
|
||||
}
|
||||
|
||||
class ApiKeyLoader:
|
||||
def __init__(self):
|
||||
load_dotenv(override=False)
|
||||
|
||||
required_env_vars = {cfg["api_key_env"] for cfg in KEY_CONFIGS.values() if "api_key_env" in cfg}
|
||||
|
||||
self.missing = [var for var in sorted(required_env_vars) if not os.getenv(var)]
|
||||
|
||||
if self.missing:
|
||||
raise RuntimeError(
|
||||
"Missing required API key environment variables: "
|
||||
+ ", ".join(self.missing)
|
||||
+ ". Please add them to your .env file or export them in your environment."
|
||||
)
|
||||
|
||||
self.keys = {
|
||||
cfg["id"]: os.getenv(cfg["api_key_env"])
|
||||
for cfg in KEY_CONFIGS.values()
|
||||
if os.getenv(cfg["api_key_env"])
|
||||
}
|
||||
|
||||
def get(self, key):
|
||||
return self.keys.get(key)
|
||||
@@ -0,0 +1,265 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Additional End of week Exercise - week 2\n",
|
||||
"\n",
|
||||
"Now use everything you've learned from Week 2 to build a full prototype for the technical question/answerer you built in Week 1 Exercise.\n",
|
||||
"\n",
|
||||
"This should include a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models. Bonus points if you can demonstrate use of a tool!\n",
|
||||
"\n",
|
||||
"If you feel bold, see if you can add audio input so you can talk to it, and have it respond with audio. ChatGPT or Claude can help you, or email me if you have questions.\n",
|
||||
"\n",
|
||||
"I will publish a full solution here soon - unless someone beats me to it...\n",
|
||||
"\n",
|
||||
"There are so many commercial applications for this, from a language tutor, to a company onboarding solution, to a companion AI to a course (like this one!) I can't wait to see your results."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:12.637325Z",
|
||||
"start_time": "2025-10-06T16:52:10.174609Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from openai import OpenAI\n",
|
||||
"from api_key_loader import ApiKeyLoader\n",
|
||||
"import gradio as gr"
|
||||
],
|
||||
"id": "40fad776c1390c95",
|
||||
"outputs": [],
|
||||
"execution_count": 1
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:14.602846Z",
|
||||
"start_time": "2025-10-06T16:52:14.599016Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Initialization\n",
|
||||
"keys = ApiKeyLoader()"
|
||||
],
|
||||
"id": "df481efd444c3042",
|
||||
"outputs": [],
|
||||
"execution_count": 2
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:16.325113Z",
|
||||
"start_time": "2025-10-06T16:52:16.317253Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Create LLM client\n",
|
||||
"MODEL = \"gpt-4o-mini\"\n",
|
||||
"openai = OpenAI()"
|
||||
],
|
||||
"id": "50c88632f5e3e8ca",
|
||||
"outputs": [],
|
||||
"execution_count": 3
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:18.224200Z",
|
||||
"start_time": "2025-10-06T16:52:18.198959Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": "from tool_box import ToolBox",
|
||||
"id": "5dde6f905c143779",
|
||||
"outputs": [],
|
||||
"execution_count": 4
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:19.913686Z",
|
||||
"start_time": "2025-10-06T16:52:19.906576Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": "tool_box = ToolBox(keys)",
|
||||
"id": "f999cfc5d533bf4e",
|
||||
"outputs": [],
|
||||
"execution_count": 5
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:21.515122Z",
|
||||
"start_time": "2025-10-06T16:52:21.513463Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"system_message = \"You are a helpful assistant for an Travel Agency called TravelAI. \"\n",
|
||||
"system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
|
||||
"system_message += \"Always be accurate. If you don't know the answer, say so.\""
|
||||
],
|
||||
"id": "1891e095dc08da95",
|
||||
"outputs": [],
|
||||
"execution_count": 6
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:26.569153Z",
|
||||
"start_time": "2025-10-06T16:52:26.566763Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def chat(history):\n",
|
||||
" messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
|
||||
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tool_box.tools)\n",
|
||||
" image = None\n",
|
||||
"\n",
|
||||
" if response.choices[0].finish_reason==\"tool_calls\":\n",
|
||||
" message = response.choices[0].message\n",
|
||||
" tools_response, image = tool_box.apply(message)\n",
|
||||
" messages.append(message)\n",
|
||||
" messages.extend(tools_response)\n",
|
||||
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
|
||||
"\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" history += [{\"role\":\"assistant\", \"content\":reply}]\n",
|
||||
"\n",
|
||||
" return history, image"
|
||||
],
|
||||
"id": "b74e96dbf19dceef",
|
||||
"outputs": [],
|
||||
"execution_count": 7
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"source": "resp, image = chat([{\"role\": \"user\", \"content\": \"Show me Milan Airports on the map\"}])",
|
||||
"id": "8c4a6f62ed5079f7",
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:59.264912Z",
|
||||
"start_time": "2025-10-06T16:52:58.789327Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"with gr.Blocks() as ui:\n",
|
||||
" with gr.Row():\n",
|
||||
" chatbot = gr.Chatbot(height=500, type=\"messages\")\n",
|
||||
" image_output = gr.Image(height=500)\n",
|
||||
" with gr.Row():\n",
|
||||
" entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
|
||||
" with gr.Row():\n",
|
||||
" clear = gr.Button(\"Clear\")\n",
|
||||
"\n",
|
||||
" def do_entry(message, history):\n",
|
||||
" history += [{\"role\":\"user\", \"content\":message}]\n",
|
||||
" return \"\", history\n",
|
||||
"\n",
|
||||
" entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(\n",
|
||||
" chat, inputs=chatbot, outputs=[chatbot, image_output]\n",
|
||||
" )\n",
|
||||
" clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n",
|
||||
"\n",
|
||||
"ui.launch(inbrowser=True)"
|
||||
],
|
||||
"id": "71809ab63b2973b0",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"* Running on local URL: http://127.0.0.1:7860\n",
|
||||
"* To create a public link, set `share=True` in `launch()`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
],
|
||||
"text/html": [
|
||||
"<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data",
|
||||
"jetTransient": {
|
||||
"display_id": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": []
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 9
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from travel_api import TravelAPI\n",
|
||||
"trave_agent = TravelAPI(keys.get(\"client_id\"), keys.get(\"client_secret\"))\n",
|
||||
"airports = trave_agent.get_airport('Milan')\n",
|
||||
"print(airports)"
|
||||
],
|
||||
"id": "3e63c59c69cbdae9",
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-06T16:52:55.269319Z",
|
||||
"start_time": "2025-10-06T16:52:34.020932Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": "resp, image = chat([{\"role\": \"user\", \"content\": \"Give me a boarding pass from MXP to LHR for 1 Nov 2025 for Sergei Sevriugin\"}])",
|
||||
"id": "caa12b84a1863bc1",
|
||||
"outputs": [],
|
||||
"execution_count": 8
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
100
week2/community-contributions/book_ticket_agent/map_generator.py
Normal file
100
week2/community-contributions/book_ticket_agent/map_generator.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
from http.client import IncompleteRead
|
||||
from googlemaps.maps import StaticMapMarker
|
||||
|
||||
import googlemaps
|
||||
import time
|
||||
|
||||
def get_center(points):
|
||||
if not points:
|
||||
raise ValueError("points must be a non-empty list of coordinate objects")
|
||||
|
||||
locations = []
|
||||
lats, lngs = [], []
|
||||
for p in points:
|
||||
g = p.get("geoCode")
|
||||
lat = p.get("latitude", g.get("latitude"))
|
||||
lng = p.get("longitude", g.get("longitude"))
|
||||
if lat is None or lng is None:
|
||||
raise ValueError("Each point must include 'latitude' and 'longitude' (or 'geoCode.latitude'/'geoCode.longitude').")
|
||||
lat_f = float(lat)
|
||||
lng_f = float(lng)
|
||||
locations.append({"lat": lat_f, "lng": lng_f})
|
||||
lats.append(lat_f)
|
||||
lngs.append(lng_f)
|
||||
|
||||
# Center at the centroid of provided points
|
||||
center = (sum(lats) / len(lats), sum(lngs) / len(lngs))
|
||||
return center, locations
|
||||
|
||||
|
||||
class MapGenerator:
|
||||
def __init__(self, google_map_api_key):
|
||||
self.client = googlemaps.Client(google_map_api_key)
|
||||
|
||||
def fetch_static_map_bytes(
|
||||
self,
|
||||
center,
|
||||
markers,
|
||||
size=(400, 400),
|
||||
zoom=6,
|
||||
map_type="hybrid",
|
||||
img_format="png",
|
||||
scale=2,
|
||||
visible=None,
|
||||
max_retries=3,
|
||||
backoff_base=0.6,
|
||||
):
|
||||
last_err = None
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
iterator = self.client.static_map(
|
||||
size=size,
|
||||
zoom=zoom,
|
||||
center=center,
|
||||
maptype=map_type,
|
||||
format=img_format,
|
||||
scale=scale,
|
||||
visible=visible,
|
||||
markers=markers,
|
||||
)
|
||||
return b"".join(chunk for chunk in iterator if chunk)
|
||||
except (ChunkedEncodingError, IncompleteRead) as e:
|
||||
last_err = e
|
||||
if attempt == max_retries:
|
||||
break
|
||||
# An exponential backoff before retrying
|
||||
time.sleep(backoff_base * attempt)
|
||||
# If we got here, all retries failed; re-raise the last error, so the user sees the cause.
|
||||
raise last_err
|
||||
|
||||
def generate(
|
||||
self,
|
||||
points,
|
||||
zoom=6,
|
||||
size=(600, 600),
|
||||
map_type="roadmap",
|
||||
color="blue",
|
||||
label=None,
|
||||
marker_size="mid"
|
||||
):
|
||||
center, locations = get_center(points)
|
||||
|
||||
sm_marker = StaticMapMarker(
|
||||
locations=locations,
|
||||
size=marker_size,
|
||||
color=color,
|
||||
label=label,
|
||||
)
|
||||
|
||||
img_bytes = self.fetch_static_map_bytes(
|
||||
center=center,
|
||||
markers=[sm_marker],
|
||||
size=size,
|
||||
zoom=zoom,
|
||||
map_type=map_type,
|
||||
img_format="png",
|
||||
scale=2,
|
||||
)
|
||||
|
||||
return img_bytes
|
||||
272
week2/community-contributions/book_ticket_agent/tool_box.py
Normal file
272
week2/community-contributions/book_ticket_agent/tool_box.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import base64
|
||||
import json
|
||||
from travel_api import TravelAPI
|
||||
from api_key_loader import ApiKeyLoader
|
||||
from map_generator import MapGenerator
|
||||
from openai import OpenAI
|
||||
from typing import Any, Dict, List, Optional
|
||||
from PIL import Image
|
||||
import io
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
# Internal function specs (simple) used to build OpenAI-compatible tools list
|
||||
_FUNCTION_SPECS: Dict[str, Dict[str, Any]] = {
|
||||
"get_flight": {
|
||||
"name": "get_flight",
|
||||
"description": (
|
||||
"Get flight options from the departure airport (origin), destination airport, date and number of adults."
|
||||
"Before calling this function, you should have called 'get_airports' to get airport codes for the origin and destination - 2 calls in total. "
|
||||
"Call this when client ask to book a flight, for example when client asks 'Book ticket to Paris on 2023-01-01'. If the origin or destination city is missing ask client first to provide it."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"origin_location_code": {
|
||||
"type": "string",
|
||||
"description": "IATA code of the origin airport, e.g. 'MAD'",
|
||||
},
|
||||
"destination_location_code": {
|
||||
"type": "string",
|
||||
"description": "IATA code of the destination airport, e.g. 'ATH'",
|
||||
},
|
||||
"departure_date": {
|
||||
"type": "string",
|
||||
"description": "Date of departure in 'YYYY-MM-DD'",
|
||||
},
|
||||
"adults": {
|
||||
"type": "integer",
|
||||
"description": "Number of adult passengers (default 1)",
|
||||
}
|
||||
},
|
||||
"required": ["origin_location_code", "destination_location_code", "departure_date"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
},
|
||||
"get_airports": {
|
||||
"name": "get_airports",
|
||||
"description": (
|
||||
"Get airports for a city name using 'city'. Call this to resolve a city to airports."
|
||||
"The response contains a list of airport objects. Use the selected airport's 'iataCode' for get_flight."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name to search airports for",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
},
|
||||
"get_map": {
|
||||
"name": "get_map",
|
||||
"description": "Generate a Google Static Map PNG for a list of airport/location for given `city`. Call this function when user ask to show city airports on map",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name to search airports for and than show on map",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
},
|
||||
"get_boarding_pass": {
|
||||
"name": "get_boarding_pass",
|
||||
"description": "Generate a boarding pass for a flight. Call this when client asks for boarding pass.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"origin_location_code": {
|
||||
"type": "string",
|
||||
"description": "IATA code of the origin airport, e.g. 'MAD'",
|
||||
},
|
||||
"destination_location_code": {
|
||||
"type": "string",
|
||||
"description": "IATA code of the destination airport, e.g. 'ATH'",
|
||||
},
|
||||
"departure_date": {
|
||||
"type": "string",
|
||||
"description": "Date of departure in 'YYYY-MM-DD'",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Passenger name",
|
||||
}
|
||||
},
|
||||
"required": ["origin_location_code", "destination_location_code", "departure_date", "name"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _to_openai_tools(specs: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert simple specs into OpenAI "tools" list schema."""
|
||||
tools: List[Dict[str, Any]] = []
|
||||
for spec in specs.values():
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": spec["name"],
|
||||
"description": spec.get("description", ""),
|
||||
"parameters": spec.get("parameters", {"type": "object"}),
|
||||
}
|
||||
})
|
||||
return tools
|
||||
|
||||
|
||||
def _tool_response(tool_call_id: Optional[str], payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": json.dumps(payload),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
|
||||
def _parse_args(raw_args: Any) -> Dict[str, Any]:
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
return json.loads(raw_args) if raw_args else {}
|
||||
except Exception:
|
||||
return {}
|
||||
if isinstance(raw_args, dict):
|
||||
return raw_args
|
||||
return {}
|
||||
|
||||
def _extract_tool_call(tool_call: Any):
|
||||
function = getattr(tool_call, "function", None) or (
|
||||
tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
)
|
||||
name = getattr(function, "name", None) or (
|
||||
function.get("name") if isinstance(function, dict) else None
|
||||
)
|
||||
raw_args = getattr(function, "arguments", None) or (
|
||||
function.get("arguments") if isinstance(function, dict) else None
|
||||
)
|
||||
call_id = getattr(tool_call, "id", None) or (
|
||||
tool_call.get("id") if isinstance(tool_call, dict) else None
|
||||
)
|
||||
return name, _parse_args(raw_args), call_id
|
||||
|
||||
|
||||
class ToolBox:
|
||||
def __init__(self, keys: ApiKeyLoader):
|
||||
self.travel_api = TravelAPI(keys.get("client_id"), keys.get("client_secret"))
|
||||
self.map_generator = MapGenerator(keys.get("google_map_api_key"))
|
||||
self.openai = OpenAI(api_key=keys.get("openai_api_key"))
|
||||
self.tools = _to_openai_tools(_FUNCTION_SPECS)
|
||||
self._fn_dispatch = {
|
||||
"get_flight": self.get_flight,
|
||||
"get_airports": self.get_airports,
|
||||
"get_map": self.get_map,
|
||||
}
|
||||
|
||||
def get_flight(self, origin_location_code, destination_location_code, departure_date, adults=1):
|
||||
return self.travel_api.get_flight(origin_location_code, destination_location_code, departure_date,
|
||||
adults=adults)
|
||||
|
||||
def get_airports(self, city):
|
||||
return self.travel_api.get_airport(city)
|
||||
|
||||
def get_map(self, city):
|
||||
airports = self.travel_api.get_airport(city)
|
||||
return airports, self.map_generator.generate(airports)
|
||||
|
||||
def get_toolset(self):
|
||||
return self.tools
|
||||
|
||||
def get_boarding_pass(self, origin_location_code, destination_location_code, departure_date, name):
|
||||
image_response = self.openai.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=f"An image representing a boarding pass from {origin_location_code} to {destination_location_code} for {name} and departure date {departure_date}",
|
||||
size="1024x1024",
|
||||
n=1,
|
||||
response_format="b64_json",
|
||||
)
|
||||
image_base64 = image_response.data[0].b64_json
|
||||
image_data = base64.b64decode(image_base64)
|
||||
return Image.open(BytesIO(image_data))
|
||||
|
||||
|
||||
def apply(self, message):
|
||||
"""Apply tool calls contained in an assistant message and return a list of tool messages."""
|
||||
results: List[Dict[str, Any]] = []
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
if not tool_calls:
|
||||
return results
|
||||
|
||||
generated_user_message: Optional[str] = None
|
||||
image = None
|
||||
|
||||
for tool_call in tool_calls:
|
||||
function_name, arguments, call_id = _extract_tool_call(tool_call)
|
||||
|
||||
if function_name == "get_flight":
|
||||
origin_location_code = arguments.get("origin_location_code")
|
||||
destination_location_code = arguments.get("destination_location_code")
|
||||
departure_date = arguments.get("departure_date")
|
||||
adults = arguments.get("adults") or 1
|
||||
|
||||
options = self.get_flight(
|
||||
origin_location_code,
|
||||
destination_location_code,
|
||||
departure_date,
|
||||
adults=adults,
|
||||
)
|
||||
results.append(_tool_response(call_id, {"flight_options": options}))
|
||||
|
||||
elif function_name == "get_boarding_pass":
|
||||
origin_location_code = arguments.get("origin_location_code")
|
||||
destination_location_code = arguments.get("destination_location_code")
|
||||
departure_date = arguments.get("departure_date")
|
||||
name = arguments.get("name")
|
||||
image = self.get_boarding_pass(origin_location_code, destination_location_code, departure_date, name)
|
||||
results.append(_tool_response(call_id, {"boarding_pass": f"boarding pass for {name} from {origin_location_code} to {destination_location_code} on {departure_date}."}))
|
||||
if generated_user_message is None:
|
||||
generated_user_message = (
|
||||
f"Here is my boarding pass for {name} from {origin_location_code} to {destination_location_code} on {departure_date}."
|
||||
)
|
||||
|
||||
elif function_name == "get_airports":
|
||||
city = arguments.get("city")
|
||||
airports = self.get_airports(city)
|
||||
results.append(_tool_response(call_id, {"airports": airports}))
|
||||
if generated_user_message is None:
|
||||
generated_user_message = (
|
||||
f"Here are the airports in {city}: {airports} Please help me to select one."
|
||||
)
|
||||
|
||||
elif function_name == "get_map":
|
||||
city = arguments.get("city")
|
||||
try:
|
||||
airports, img_bytes = self.get_map(city)
|
||||
if img_bytes:
|
||||
try:
|
||||
pil_img = Image.open(io.BytesIO(img_bytes))
|
||||
pil_img.load()
|
||||
if pil_img.mode not in ("RGB", "RGBA"):
|
||||
pil_img = pil_img.convert("RGB")
|
||||
image = pil_img
|
||||
except Exception:
|
||||
image = None
|
||||
results.append(_tool_response(call_id, {"airports": airports}))
|
||||
if generated_user_message is None:
|
||||
generated_user_message = (
|
||||
f"Here are the airports in {city}: {airports} Please help me to select one."
|
||||
)
|
||||
except Exception as e:
|
||||
results.append(_tool_response(call_id, {"error": f"get_map failed: {str(e)}"}))
|
||||
|
||||
else:
|
||||
# Unknown tool: respond so the model can recover gracefully.
|
||||
results.append(_tool_response(call_id, {"error": f"Unknown tool: {function_name}"}))
|
||||
|
||||
if generated_user_message:
|
||||
results.append({"role": "user", "content": generated_user_message})
|
||||
|
||||
return results, image
|
||||
@@ -0,0 +1,58 @@
|
||||
from amadeus import Client, Location, ResponseError
|
||||
|
||||
def filter_other_countries(airports):
|
||||
country_codes_wights = {}
|
||||
for airport in airports:
|
||||
country_code = airport["address"]["countryCode"]
|
||||
country_codes_wights[country_code] = country_codes_wights.get(country_code, 0) + 1
|
||||
country_code = max(country_codes_wights, key=country_codes_wights.get)
|
||||
return [airport for airport in airports if airport["address"]["countryCode"] == country_code]
|
||||
|
||||
|
||||
class TravelAPI:
|
||||
def __init__(self, client_id, client_secret):
|
||||
self.client = Client(client_id=client_id, client_secret=client_secret)
|
||||
|
||||
def get_airport(self, search):
|
||||
try:
|
||||
airport_locations = self.client.reference_data.locations.get(
|
||||
keyword=search,
|
||||
subType=Location.AIRPORT,
|
||||
)
|
||||
return filter_other_countries(airport_locations.data)
|
||||
except ResponseError as e:
|
||||
print(f"Amadeus API ResponseError in get_airport: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in get_airport: {e}")
|
||||
return []
|
||||
|
||||
def get_city(self, search, country_code="IT"):
|
||||
try:
|
||||
city_locations = self.client.reference_data.locations.get(
|
||||
keyword=search,
|
||||
subType=Location.CITY,
|
||||
countryCode=country_code
|
||||
)
|
||||
return city_locations.data
|
||||
except ResponseError as e:
|
||||
print(f"Amadeus API ResponseError in get_city: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in get_city: {e}")
|
||||
return []
|
||||
|
||||
def get_flight(self, origin_location_code, destination_location_code, departure_date, adults=1):
|
||||
try:
|
||||
offers = self.client.shopping.flight_offers_search.get(
|
||||
originLocationCode=origin_location_code,
|
||||
destinationLocationCode=destination_location_code,
|
||||
departureDate=departure_date,
|
||||
adults=adults)
|
||||
return offers.data
|
||||
except ResponseError as e:
|
||||
print(f"Amadeus API ResponseError in get_flight: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in get_flight: {e}")
|
||||
return []
|
||||
Reference in New Issue
Block a user