Flight conversation between 2 AI bots (Week 2)

This commit is contained in:
SABEEH Shaikh
2025-05-10 02:35:33 +02:00
parent 1fb53d70de
commit 22b5915d7e

@@ -0,0 +1,631 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d8e24125-28e2-4d58-9684-ca2a5ce3d4ac",
"metadata": {},
"source": [
"# Automated Conversation between 2 bots\n",
"\n",
"## About Bots\n",
"This project accomplishes a back and forth conversation between a flight assistant bot and customer bot. The flight assistant bot is responsible for handling queries related to booking return flights in any European country, while the customer bot is given the task to find the cheapest ticket (with return) to any randomly chosen 5 European countries for a vacation holiday coming soon. You can read the the first 2 system prompts below to get a better overview. \n",
"\n",
"## Selecting LLMs\n",
"After doing a lot of trials, I found out that Anthropic's Claude model performance was not even close to the way Gemini and ChatGPT gave responses, with the same system prompt. Claude's response were empty (None) most of the time, even by swapping the role. If anyone figures out why please let me know (sabeehmehtab@gmail.com), thanks!\n",
"\n",
"## Tool Issues\n",
"I did implement the use of tools but for some reason ChatGPT model does not consider using it. Though my implementation of tools is a bit tricky, I have used a separate model (Claude because it failed above) for handling tool calls from a GPT chatting model when it has the role of a flight assistant. This tool handling Claude model receives a query/task input generated from the GPT and is given a further set of tools (3 in this case) to help it answer the query/task. The issue is it never gets till this point. The GPT model never uses it since it can figure out the answer to any query from the customer bot on its own. Just to mention, I did a few tries by changing the system prompt to kind of force it to use tools but did not get any success. "
]
},
{
"cell_type": "markdown",
"id": "9bf8e3d8-bfde-4a0e-b133-fa8cda87030e",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9dda7606-a5bf-490d-84ea-f1fb7e0116db",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"import os\n",
"import json\n",
"import time\n",
"import random\n",
"import anthropic\n",
"import gradio as gr\n",
"import google.generativeai\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"from datetime import date, timedelta"
]
},
{
"cell_type": "markdown",
"id": "24267c14-4025-48cf-af0b-1f8082d037f5",
"metadata": {},
"source": [
"## Setup keys from environment file"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0321895a-eee7-4d5e-98f0-0983178331f4",
"metadata": {},
"outputs": [],
"source": [
"#Load available keys from environment file \n",
"#Print the keys first 6 characters \n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"ant_api_key = os.getenv(\"ANTHROPIC_API_KEY\")\n",
"goo_api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if openai_api_key:\n",
" print(f\"OpenAI API key exists and begins {openai_api_key[:6]}\")\n",
"else:\n",
" print(\"OpenAI API key does not exist\")\n",
"\n",
"if ant_api_key:\n",
" print(f\"Anthropic API key exists and begins {ant_api_key[:6]}\")\n",
"else:\n",
" print(\"Anthropic API key API key does not exist\")\n",
"\n",
"if goo_api_key:\n",
" print(f\"Google API key exists and begins {goo_api_key[:6]}\")\n",
"else:\n",
" print(\"Google API key does not exist\")"
]
},
{
"cell_type": "markdown",
"id": "2cb778fd-7f45-4271-b984-9349b32abe1b",
"metadata": {},
"source": [
"## Model(s) Initialization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f050192d-9cd4-45c1-9d26-a720bdaaf7ca",
"metadata": {},
"outputs": [],
"source": [
"# Setup code for OpenAI, Anthropic and Google\n",
"\n",
"openai = OpenAI()\n",
"gpt_model = \"gpt-4o-mini\"\n",
"\n",
"claude_sonnet = anthropic.Anthropic()\n",
"claude_model = \"claude-3-7-sonnet-latest\"\n",
"\n",
"google.generativeai.configure()\n",
"gemini_model = \"gemini-2.0-flash\""
]
},
{
"cell_type": "markdown",
"id": "55589a8e-3ca7-4218-a59d-20d51a1235e1",
"metadata": {},
"source": [
"## Define System Prompts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea9534e0-8277-404d-aa69-f9aff87fca75",
"metadata": {},
"outputs": [],
"source": [
"system_prompt1 = \"You are a helpful assistant chatbot for an airline called 'Edge Air'. \\\n",
"You are to respond to any queires related to booking of flights in European countries. \\\n",
"You should offer a discount of 10% to European Nationals and a 5% discount on debit/credit card payments, when asked. \\\n",
"You are provided with a tool that you can use when customer query is related to return ticket price or flight duration or available dates. \\\n",
"Responses must be in a polite and courteous way, while encouraging the customer to buy a ticket as early as possible.\"\n",
"\n",
"system_prompt2 = \"You are a customer who wants to book a flight at 'Edge Air' airline, via a chatbot assistant. \\\n",
"You reside in Dubai and will be flying to Europe after 90 days from today on a vacation. \\\n",
"You are to choose any 5 countries in the European region and find the cheapest return ticket available. \\\n",
"You should ask for discounts and act smart to get the best available discount.\\\n",
"Remember to ask questions related to the return flight ticket price, available dates and duration to and from destination city. \\\n",
"Keep your responses short and precise.\"\n",
"\n",
"system_prompt3 = \"You are an airline flight booking manager who has access to multiple tools required \\\n",
"in the process of a booking. You will be given a query or task from a chabot assistant that should be responded \\\n",
"with the help of the tools provided. If no such tool exists to resolve the query/task at hand, \\\n",
"you must guess the solution and respond back with a high level of confidence. When taking a guess, \\\n",
"make sure that your solution is relevant to the query/task given by giving a second-thought to it.\"\n",
"\n",
"starting_prompt = \"Start of an autonomous conversation between two AI bots. They take turns for flight booking process discussion.\""
]
},
{
"cell_type": "markdown",
"id": "c00a45a4-bf50-4770-8599-29d082b80c65",
"metadata": {},
"source": [
"## Define Flight Assistant tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "785373a2-1bd6-4c7f-8eee-6765a45c7eba",
"metadata": {},
"outputs": [],
"source": [
"# Flight Assistant Tool\n",
"\n",
"def call_manager(task):\n",
" prompt = [\n",
" {\"role\" : \"system\", \"content\" : system_prompt3},\n",
" {\"role\" : \"user\", \"content\" : task}\n",
" ]\n",
" model = \"gemini-2.0-flash\"\n",
" gemini_via_openai_client = OpenAI(\n",
" api_key=goo_api_key, \n",
" base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
" )\n",
" response = gemini_via_openai_client.chat.completions.create(model=model,messages=prompt)\n",
" return response.choices[0].message.content\n",
"\n",
"\n",
"# There's a particular dictionary structure that's required to describe our function:\n",
"manager_function = {\n",
" \"name\": \"call_manager\",\n",
" \"description\": \"Use this tool only when you are unsure about the answer to the clients query, like when you want to know the ticket price \\\n",
" of a country, available traveling dates, duration of the flight journey from one country to another or any other flight booking information \",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"task\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The query or task you want to resolve in simple words\",\n",
" },\n",
" },\n",
" \"required\": [\"task\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}\n",
"\n",
"assistant_tools = [{\"type\":\"function\",\"function\":manager_function}]"
]
},
{
"cell_type": "markdown",
"id": "df32fd9f-c890-455a-91c2-8a661b18163b",
"metadata": {},
"source": [
"## Define Flight Manager Tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33e7c5c4-3124-4f47-b075-e8916a4a368b",
"metadata": {},
"outputs": [],
"source": [
"# Flight Manager Tools\n",
"\n",
"fixed_city_durations = {\"france\":\"6 Hours\",\"berlin\":\"6.5 Hours\",\"germany\":\"7 Hours\",\"netherlands\":\"7.5 Hours\",\"spain\":\"5 Hours\"}\n",
"\n",
"def get_ticket_price():\n",
" price = random.randint(800, 2000)\n",
" return price\n",
"\n",
"def get_available_dates():\n",
" available_dates = []\n",
" no_of_dates = random.randint(15,30)\n",
" \n",
" start_date = date.today()\n",
" end_date = start_date + timedelta(180)\n",
" diff = end_date-current_date\n",
"\n",
" for day in range(no_of_dates):\n",
" random.seed(a=None)\n",
" rand_day = random.randrange(diff.days)\n",
" available_dates.append(current + timedelta(rand_day))\n",
"\n",
" return available_dates\n",
"\n",
"def get_duration(city):\n",
" city = city.lower()\n",
" if (city in fixed_city_durations.keys()):\n",
" return fixed_city_durations[city]\n",
" else:\n",
" return [f\"{random.randint(4,10) + random.random()} Hours\", f\"{random.randint(4,10) + random.random()} Hours\"]\n",
" "
]
},
{
"cell_type": "markdown",
"id": "46e77f2a-f5a1-467e-86bf-997fe86a30e4",
"metadata": {},
"source": [
"### Anthropic tool usage format "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b55d3db-ff9e-4706-b7ff-a29d28832eed",
"metadata": {},
"outputs": [],
"source": [
"# There's a particular Antrhopic Tool Object structure that's required to describe our tool function for Claude:\n",
"price_function = {\n",
" \"name\":\"get_ticket_price\",\n",
" \"description\":\"Use this tool to get the price of a return ticket to the destination city. It will return the price in the dollar currency.\",\n",
" \"input_schema\":{\n",
" \"type\": \"object\",\n",
" \"properties\": {},\n",
" \"required\": []\n",
" }\n",
"}\n",
"dates_function = {\n",
" \"name\":\"get_available_dates\",\n",
" \"description\":\"Use this tool for fetching the available dates of a flight to the destination city. It will return a list of dates that are avilable for travelling.\",\n",
" \"input_schema\":{\n",
" \"type\": \"object\",\n",
" \"properties\": {},\n",
" \"required\": []\n",
" }\n",
"}\n",
"duration_function = {\n",
" \"name\":\"get_duration\",\n",
" \"description\":\"Use this tool to get the flight durations to and from the destination city. It will return the two flight durations in hours in a string format in a list.\",\n",
" \"input_schema\":{\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\" : { \"type\":\"String\", \"description\":\"Name of the destination city\"}\n",
" },\n",
" \"required\": [\"city\"]\n",
" }\n",
"}\n",
"\n",
"anthropic_manager_tools = [price_function,dates_function,duration_function]\n",
"\n",
"openai_manager_tools = [\n",
" {\"type\":\"function\",\"function\":price_function},\n",
" {\"type\":\"function\",\"function\":dates_function},\n",
" {\"type\":\"function\",\"function\":duration_function}\n",
"]\n"
]
},
{
"cell_type": "markdown",
"id": "9fb43fdf-6eb5-44b3-841d-5aae05523ad2",
"metadata": {},
"source": [
"## Gradio Chatbot Conversation Structure"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e01b8c5e-a455-4d51-8683-ce07146b8a89",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
" Commented Claudes conversation chat funtion as it produces a lot of empty responses\n",
"\"\"\"\n",
"def get_structured_messages(history, system_prompt):\n",
" return [{\"role\" : \"system\", \"content\" : system_prompt}] + history\n",
"\n",
"def chat_gpt(system_prompt, history):\n",
" messages = get_structured_messages(history=history, system_prompt=system_prompt)\n",
"\n",
" response = openai.chat.completions.create(model=gpt_model, messages=messages)#, tools=assistant_tools)\n",
"\n",
" if (response.choices[0].finish_reason==\"tool_calls\"):\n",
" message = response.choices[0].message\n",
" response = handle_assistant_tool_call(message)\n",
" messages.append({\"role\" : \"assistant\", \"content\" : messages.content})\n",
" messages.append(response)\n",
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
"\n",
" return response.choices[0].message.content\n",
"\n",
"# def chat_claude(system_prompt, history): \n",
"# response = claude_sonnet.messages.create(\n",
"# model=claude_model,\n",
"# max_tokens=200,\n",
"# temperature=0.7,\n",
"# system=system_prompt,\n",
"# messages=history,\n",
"# )\n",
"# try:\n",
"# text = response.content[0].text\n",
"# except:\n",
"# print(\"No response from claude\")\n",
"# text = \"\"\n",
"# return text\n",
"\n",
"def chat_gemini(system_prompt, history):\n",
" gemini = google.generativeai.GenerativeModel(\n",
" model_name=gemini_model,\n",
" system_instruction=system_prompt\n",
" )\n",
" response = gemini.generate_content(json.dumps(history))\n",
" # print(f\"Gemini Response: \\n{response}\")\n",
" return response.text"
]
},
{
"cell_type": "markdown",
"id": "b6d05d4b-0d4f-4bee-82d5-d1a3b6b36551",
"metadata": {},
"source": [
"## Handling Tool Calls"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "208b7b23-ae83-4bf4-a31f-5309a747ea86",
"metadata": {},
"outputs": [],
"source": [
"def handle_assistant_tool_call(message):\n",
" content_list = []\n",
" tool_calls = message.tool_calls\n",
" print(f\"List of tool call: \\n{tool_calls}\")\n",
" for tool in tool_calls:\n",
" try:\n",
" arguments = json.loads(tool_call.function.arguments)\n",
" except:\n",
" print(\"Error loading arguments from tool call\")\n",
" print(f\"Arguments in json format: \\n{arguments}\")\n",
" task = arguments.get('task')\n",
" content = run_manager_llm(task)\n",
" content_list.append(content)\n",
" response = {\n",
" \"role\": \"tool\",\n",
" \"content\": content_list,\n",
" \"tool_call_id\": tool_call.id\n",
" }\n",
" return response\n",
" \n",
"# Anthropic Claude-Sonnet\n",
"def run_manager_llm(task):\n",
" user_prompt = [\n",
" {\"role\":\"user\", \"content\": task}\n",
" ]\n",
"\n",
" response = claude_sonnet.messages.create(\n",
" model=claude_model,\n",
" max_tokens=1024,\n",
" tools=anthropic_manager_tools,\n",
" tool_choice='auto',\n",
" temperature=0.7,\n",
" system=system_prompt3,\n",
" messages=user_prompt,\n",
" )\n",
"\n",
" tool_use = response.content[0].tool_use\n",
" print(f\"Claude tool help: {tool_use}\")\n",
" \n",
" if tool_use.name==\"get_ticket_price\":\n",
" price = get_ticket_price()\n",
" response = manager_tool_response(user_prompt,tool_use,price)\n",
" \n",
" elif tool_use.name==\"get_available_dates\":\n",
" dates = get_available_dates()\n",
" response = manager_tool_response(user_prompt,tool_use,dates)\n",
" elif tool_use.name==\"get_duration\":\n",
" duration = get_duration(tool_use.input[\"city\"])\n",
" response = manager_tool_response(user_prompt,tool_use,duration)\n",
"\n",
" try:\n",
" text = response.content[0].text\n",
" except:\n",
" print(\"No response from claude\")\n",
" text = \"\"\n",
" return text\n",
"\n",
"# Function for generating response after tool usage\n",
"def manager_tool_response(user_prompt, tool_use, content):\n",
" user_prompt.append({\"role\":\"assistant\",\"content\": [\n",
" {\n",
" \"type\": \"tool_use\", \"tool_use_id\": tool_use.tool_use_id, \"name\": tool_use.name, \"input\": tool_use.input,\n",
" }\n",
" ]})\n",
" user_prompt.append({\"role\":\"user\",\"content\": [\n",
" {\n",
" \"type\": \"tool_result\", \"tool_use_id\": tool_use.tool_use_id, \"content\": content,\n",
" }\n",
" ]})\n",
" response = claude_sonnet.messages.create(\n",
" model=claude_model,\n",
" max_tokens=1024,\n",
" tools=anthropic_manager_tools,\n",
" tool_choice='auto',\n",
" temperature=0.7,\n",
" system=system_prompt3,\n",
" messages=user_prompt,\n",
" )\n",
" return response"
]
},
{
"cell_type": "markdown",
"id": "b9e12b32-4ac7-4825-bd5e-d531597ebc5c",
"metadata": {},
"source": [
"## Build UI using Gradio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbeb49d9-e4d3-4e16-92ea-2f1fbf9a610d",
"metadata": {},
"outputs": [],
"source": [
"chatbot_models = [\"ChatGPT\", \"Gemini\"]\n",
"\n",
"with gr.Blocks() as demo:\n",
" gr.Markdown(\"# 🤖 AI Chatbot Conversation\")\n",
" gr.Markdown(\"Watch two AI chatbots have a conversation with each other.\")\n",
" \n",
" is_conversation_active = gr.State(True)\n",
" turns_count = gr.State(0)\n",
" \n",
" with gr.Row():\n",
" with gr.Column(scale=3):\n",
" # Chat display\n",
" chatbot = gr.Chatbot(\n",
" type='messages',\n",
" label=\"Bot Conversation\",\n",
" height=500,\n",
" elem_id=\"chatbot\",\n",
" avatar_images=(\"🧑\", \"🤖\",)\n",
" )\n",
" \n",
" # Controls\n",
" with gr.Row(elem_classes=\"controls\"):\n",
" start_btn = gr.Button(\"Start Conversation\", variant=\"primary\")\n",
" stop_btn = gr.Button(\"Stop\", variant=\"stop\")\n",
" clear_btn = gr.Button(\"Clear Conversation\")\n",
" \n",
" # Conversation settings\n",
" with gr.Row():\n",
" max_turns = gr.Slider(\n",
" minimum=5,\n",
" maximum=20,\n",
" value=8,\n",
" step=1,\n",
" label=\"Maximum Conversation Turns\",\n",
" info=\"How many exchanges between the bots\"\n",
" )\n",
" delay = gr.Slider(\n",
" minimum=1,\n",
" maximum=5,\n",
" value=2,\n",
" step=0.5,\n",
" label=\"Delay Between Responses (seconds)\",\n",
" info=\"Simulates thinking time\"\n",
" )\n",
" \n",
" with gr.Column(scale=1):\n",
" gr.Markdown(\"### About\")\n",
" gr.Markdown(\"\"\"\n",
" This interface simulates a flight booking conversation between two AI chatbots.\n",
" \n",
" - Click \"Start Conversation\" to begin\n",
" - The bots will automatically exchange messages\n",
" - You can stop the conversation at any time\n",
" \n",
" \"\"\")\n",
" bot1 = gr.Dropdown(chatbot_models, show_label=True, label=\"Flight Assistant Model (left)\", multiselect=False)\n",
" bot2 = gr.Dropdown(chatbot_models, show_label=True, label=\"Customer Model (right)\", multiselect=False)\n",
"\n",
" def bot_response(model, system_prompt, history):\n",
" if model==chatbot_models[0]:\n",
" return chat_gpt(system_prompt=system_prompt,history=history)\n",
" else:\n",
" return chat_gemini(system_prompt=system_prompt,history=history)\n",
" \n",
" # Function to update the conversation display\n",
" def start_conversation(turns, max_turns, delay_time, bot1_model, bot2_model):\n",
" history = []\n",
" conversation = []\n",
" history.append({\"role\":\"user\",\"content\":starting_prompt})\n",
" global is_conversation_active\n",
" is_conversation_active=True\n",
" \n",
" while is_conversation_active and turns < max_turns:\n",
" # Airline Assistant Responds first. Change chat function to change bot model \n",
" message = bot_response(bot1_model,system_prompt1,history=history)\n",
" print(f\"(assistant): \\n{message}\")\n",
" conversation.append({\"role\":\"assistant\",\"content\":message})\n",
" history.append({\"role\":\"assistant\", \"content\": message})\n",
" yield conversation, turns \n",
" time.sleep(delay_time)\n",
" \n",
" # Customer responds next. Change chat function to change bot model \n",
" reply = bot_response(bot2_model,system_prompt2,history=history)\n",
" print(f\"(customer): \\n{reply}\")\n",
" conversation.append({\"role\":\"user\",\"content\":reply})\n",
" history.append({\"role\":\"assistant\", \"content\": reply})\n",
" turns+=1\n",
" yield conversation, turns\n",
" time.sleep(delay_time)\n",
" \n",
" \n",
" # Function to stop the conversation\n",
" def stop_conversation():\n",
" global is_conversation_active\n",
" is_conversation_active=False\n",
" \n",
" \n",
" # Function to clear the conversation\n",
" def clear_conversation():\n",
" global is_conversation_active\n",
" is_conversation_active=False\n",
" return [], 0\n",
" \n",
" # Set up the event handlers\n",
" start_btn.click(\n",
" start_conversation,\n",
" inputs=[turns_count, max_turns, delay, bot1, bot2],\n",
" outputs=[chatbot, turns_count]\n",
" )\n",
" \n",
" stop_btn.click(\n",
" stop_conversation,\n",
" outputs=[]\n",
" )\n",
" \n",
" clear_btn.click(\n",
" clear_conversation,\n",
" outputs=[chatbot, turns_count]\n",
" )\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5558451-bb90-4ec7-9063-716b60f07e19",
"metadata": {},
"outputs": [],
"source": [
"demo.launch(share=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}