# Week 2 Exercise - with Booking, Translation and Speech-To-Text

# Note: The speech-to-text functionality requires FFmpeg to be installed.
# Go to the FFmpeg website and download the corresponding OS installer.
# !pip install openai-whisper sounddevice scipy numpy

# imports — consolidated in one block so the notebook survives Restart & Run All

import os
import json
import random
import time
import glob
import base64
import tempfile
from io import BytesIO

from dotenv import load_dotenv
from openai import OpenAI
from anthropic import Anthropic
import gradio as gr
import numpy as np
import sounddevice as sd
import scipy.io.wavfile as wav
import whisper
from PIL import Image
from IPython.display import Audio, display

# Initialization
load_dotenv(override=True)
openai_api_key = os.getenv('OPENAI_API_KEY')
anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')

MODEL = "gpt-4o-mini"   # OpenAI chat model used for the assistant
STT_DURATION = 3        # seconds of audio captured per speech command
openai = OpenAI()
anthropic = Anthropic(api_key=anthropic_api_key)

# Persona for the airline assistant; deliberately terse answers.
system_message = "You are a helpful assistant for an Airline called FlightAI. "
system_message += "Give short, courteous answers, no more than 1 sentence. "
system_message += "Always be accurate. If you don't know the answer, say so."


# ---------------------------------------------------------------------------
# Tool implementations
# ---------------------------------------------------------------------------

# Hard-coded fare table; keys must be lowercase city names.
ticket_prices = {"london": "$799", "paris": "$899", "tokyo": "$1400", "berlin": "$499", "rome": "$699", "bucharest": "$949", "moscow": "$1199"}


def get_ticket_price(destination_city):
    """Return the return-ticket price string for a city, or "Unknown".

    The lookup is case-insensitive (the city is lowercased first).
    """
    print(f"Tool get_ticket_price called for {destination_city}")
    city = destination_city.lower()
    return ticket_prices.get(city, "Unknown")


def create_booking(destination_city):
    """Create a demo booking and return its reference, e.g. "AI123456".

    The reference is "AI" plus 6 random digits; it is NOT guaranteed unique
    (this is a teaching demo, there is no backing store).
    """
    digits = ''.join(str(random.randint(0, 9)) for _ in range(6))
    booking_number = f"AI{digits}"

    # Print the booking confirmation message
    print(f"Booking {booking_number} created for the flight to {destination_city}")

    return booking_number


# ---------------------------------------------------------------------------
# Tool schemas handed to the OpenAI API
# ---------------------------------------------------------------------------

# price function structure:
price_function = {
    "name": "get_ticket_price",
    "description": "Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'",
    "parameters": {
        "type": "object",
        "properties": {
            "destination_city": {
                "type": "string",
                "description": "The city that the customer wants to travel to",
            },
        },
        "required": ["destination_city"],
        "additionalProperties": False
    }
}

# booking function structure:
# NOTE: the tool is advertised to the model as "make_booking" but implemented
# by create_booking() above; handle_tool_call() maps between the two names.
booking_function = {
    "name": "make_booking",
    "description": "Make a flight booking for the customer. Call this whenever a customer wants to book a flight to a destination.",
    "parameters": {
        "type": "object",
        "properties": {
            "destination_city": {
                "type": "string",
                "description": "The city that the customer wants to travel to",
            },
        },
        "required": ["destination_city"],
        "additionalProperties": False
    }
}

# List of tools passed on every chat completion request
tools = [
    {"type": "function", "function": price_function},
    {"type": "function", "function": booking_function}
]


def handle_tool_call(message):
    """Execute the first tool call requested by the model.

    Args:
        message: an assistant message whose .tool_calls list is non-empty.

    Returns:
        (tool_response_message, destination_city) — the response dict is
        ready to append to the conversation; city may be None if the tool
        arguments lacked one.
    """
    tool_call = message.tool_calls[0]  # only the first call is handled
    function_name = tool_call.function.name
    arguments = json.loads(tool_call.function.arguments)
    city = arguments.get('destination_city')

    if function_name == "get_ticket_price":
        price = get_ticket_price(city)
        content = {"destination_city": city, "price": price}
    elif function_name == "make_booking":
        booking_number = create_booking(city)
        content = {"destination_city": city, "booking_number": booking_number}
    else:
        # Fix: the original implicitly returned None for unrecognized tool
        # names, which crashed the caller's tuple unpack. Report the problem
        # back to the model instead.
        content = {"error": f"Unknown tool: {function_name}"}

    response = {
        "role": "tool",
        "content": json.dumps(content),
        "tool_call_id": tool_call.id
    }
    return response, city


# ---------------------------------------------------------------------------
# Multimedia helpers
# ---------------------------------------------------------------------------

def artist(city, testing_mode=False):
    """Generate a DALL-E 3 vacation image for a city.

    Returns a PIL Image, or None in testing mode (API call skipped).
    """
    if testing_mode:
        print(f"Image generation skipped for {city} - in testing mode")
        return None

    image_response = openai.images.generate(
        model="dall-e-3",
        prompt=f"An image representing a vacation in {city}, showing tourist spots and everything unique about {city}, in a realistic style",
        size="1024x1024",
        n=1,
        response_format="b64_json",
    )
    image_base64 = image_response.data[0].b64_json
    image_data = base64.b64decode(image_base64)
    return Image.open(BytesIO(image_data))


def talker(message, testing_mode=False):
    """Generate speech from text and return the path to the audio file for Gradio to play"""
    if testing_mode:
        print(f"Text-to-speech skipped - in testing mode")
        return None

    try:
        response = openai.audio.speech.create(
            model="tts-1",
            voice="onyx",
            input=message)

        # Save to a unique filename based on timestamp to avoid caching issues
        timestamp = int(time.time())
        output_filename = f"output_audio_{timestamp}.mp3"

        with open(output_filename, "wb") as f:
            f.write(response.content)

        print(f"Audio saved to {output_filename}")
        return output_filename
    except Exception as e:
        print(f"Error generating speech: {e}")
        return None


# Lazily-loaded Whisper model, cached so repeated recordings don't reload it.
_whisper_model = None


def recorder_and_transcriber(duration=STT_DURATION, samplerate=16000, testing_mode=False):
    """Record audio for the specified duration and transcribe it using Whisper"""
    global _whisper_model

    if testing_mode:
        print("Speech-to-text skipped - in testing mode")
        return "This is a test speech input"

    print(f"Recording for {duration} seconds...")

    # Record mono float32 audio using sounddevice
    recording = sd.rec(int(duration * samplerate), samplerate=samplerate, channels=1, dtype='float32')
    sd.wait()  # Wait until recording is finished

    # Fix: write the WAV only after the temp-file handle is closed — the
    # original wrote while the NamedTemporaryFile was still open, which fails
    # on Windows where an open temp file is locked.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        temp_filename = temp_audio.name
    wav.write(temp_filename, samplerate, recording)

    try:
        # Fix: the original called whisper.load_model("base") on every
        # invocation; the model is now loaded once and cached.
        if _whisper_model is None:
            _whisper_model = whisper.load_model("base")  # "tiny"/"small"/"medium"/"large" also work
        result = _whisper_model.transcribe(temp_filename)
    finally:
        # Clean up the temporary file even if transcription fails
        os.unlink(temp_filename)

    return result["text"].strip()


def cleanup_audio_files():
    """Delete all MP3 files in the current directory that match our output pattern"""
    # Get all mp3 files that match our naming pattern
    mp3_files = glob.glob("output_audio_*.mp3")

    # Delete each file, tolerating individual failures (e.g. file in use)
    count = 0
    for file in mp3_files:
        try:
            os.remove(file)
            count += 1
        except Exception as e:
            print(f"Error deleting {file}: {e}")

    print(f"Cleaned up {count} audio files")
    return None


# ---------------------------------------------------------------------------
# Translation (Claude)
# ---------------------------------------------------------------------------

def translate_text(text, target_language):
    """Translate text to target_language via Claude.

    Returns "" for empty input; on API failure the error is printed and a
    "[Translation failed: ...]" marker string is returned instead of raising.
    """
    if not text or not target_language:
        return ""

    # Map the language dropdown values to language names for Claude
    # (currently an identity map, kept as an extension point).
    language_map = {
        "French": "French",
        "Spanish": "Spanish",
        "German": "German",
        "Italian": "Italian",
        "Russian": "Russian",
        "Romanian": "Romanian"
    }

    full_language_name = language_map.get(target_language, "French")

    try:
        response = anthropic.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1024,
            messages=[
                {
                    "role": "user",
                    "content": f"Translate the following text to {full_language_name}. Provide only the translation, no explanations: \n\n{text}"
                }
            ]
        )
        return response.content[0].text
    except Exception as e:
        print(f"Translation error: {e}")
        return f"[Translation failed: {str(e)}]"


def chat(history, image, testing_mode=False):
    """Run one assistant turn, executing at most one tool call.

    Args:
        history: chat history in OpenAI "messages" format (mutated in place).
        image: current destination image; regenerated only when None.
        testing_mode: when True, image generation is skipped.

    Returns:
        (history, image, reply) — reply is the assistant's text, handed back
        to the UI so text-to-speech can be run separately.
    """
    messages = [{"role": "system", "content": system_message}] + history
    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)

    if response.choices[0].finish_reason == "tool_calls":
        message = response.choices[0].message
        response, city = handle_tool_call(message)
        messages.append(message)
        messages.append(response)

        # Only generate an image outside testing mode, when no image exists
        # yet and the tool call actually produced a city (fix: the original
        # could call artist(None)).
        if not testing_mode and image is None and city:
            image = artist(city, testing_mode)

        # Second round-trip so the model can phrase the tool result
        response = openai.chat.completions.create(model=MODEL, messages=messages)

    reply = response.choices[0].message.content
    history += [{"role": "assistant", "content": reply}]

    # Return the reply directly - we'll handle TTS separately
    return history, image, reply
# Function to translate conversation history
def translate_history(history, target_language):
    """Translate every message in the chat history via translate_text().

    Roles are preserved; only message contents are translated. Returns a new
    list (the input history is not mutated). Note: this re-translates the
    entire history on every call — one Claude request per message.
    """
    if not history or not target_language:
        return []

    translated_history = []

    for msg in history:
        role = msg["role"]
        content = msg["content"]

        translated_content = translate_text(content, target_language)
        translated_history.append({"role": role, "content": translated_content})

    return translated_history


# Update the Gradio interface to handle audio output properly
def update_gradio_interface():
    """Build and return the Gradio Blocks UI for the FlightAI assistant.

    Layout: chat + translation panel on the left, destination image and
    (hidden, autoplaying) audio player on the right, a text entry row, and a
    button row with speech-input and clear buttons. Event chains run:
    user entry -> chat turn -> text-to-speech -> translation refresh.
    """
    with gr.Blocks() as ui:
        # Store chat history and audio output in state
        # (state is currently unused; audio_state carries the reply text
        # from chat() to do_tts() between chained events)
        state = gr.State([])
        audio_state = gr.State(None)

        with gr.Row():
            # When checked, image/TTS/STT API calls are skipped by the handlers
            testing_checkbox = gr.Checkbox(label="Testing", info="Turn off multimedia features when checked", value=False)

        with gr.Row():
            with gr.Column(scale=2):
                # Main panel with original chat and image
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Row():
                            chatbot = gr.Chatbot(height=300, type="messages")
                        with gr.Row():
                            language_dropdown = gr.Dropdown(
                                choices=["French", "Spanish", "German", "Italian", "Russian", "Romanian"],
                                value="French",
                                label="Translation to"
                            )
                        with gr.Row():
                            translation_output = gr.Chatbot(height=200, type="messages", label="Translated chat")
                    with gr.Column(scale=1):
                        with gr.Row():
                            image_output = gr.Image(height=620)
                        with gr.Row():
                            # Hidden player; autoplay=True makes replies audible
                            # as soon as do_tts() returns a file path
                            audio_output = gr.Audio(label="Assistant's Voice", visible=False, autoplay=True, type="filepath")

        with gr.Row():
            entry = gr.Textbox(label="Chat with our AI Assistant:")

        # Button row: empty Markdown columns act as spacers around the buttons
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    md = gr.Markdown()
            with gr.Column(scale=1):
                speak_button = gr.Button(value="🎤 Speak Command", variant="primary")
            with gr.Column(scale=1):
                with gr.Row():
                    md = gr.Markdown()
            with gr.Column(scale=1):
                with gr.Row():
                    clear = gr.Button(value="Clear", variant="secondary")
            with gr.Column(scale=1):
                with gr.Row():
                    md = gr.Markdown()

        # Function to handle speech input
        def do_speech_input(testing_mode):
            # Record and transcribe speech; the result lands in the entry box
            speech_text = recorder_and_transcriber(duration=STT_DURATION, testing_mode=testing_mode)
            return speech_text

        # Function to handle user input
        # (testing_mode is accepted to match the wiring but is unused here)
        def do_entry(message, history, testing_mode):
            history += [{"role":"user", "content":message}]
            return "", history

        # Function to handle translation updates
        def do_translation(history, language):
            translated = translate_history(history, language)
            return translated

        # Function to handle text-to-speech
        def do_tts(reply, testing_mode):
            if not reply or testing_mode:
                return None

            audio_file = talker(reply, testing_mode)
            return audio_file

        # Handle user message submission: append message, run the chat turn,
        # then TTS, then refresh the translated view (order matters)
        entry.submit(do_entry, inputs=[entry, chatbot, testing_checkbox], outputs=[entry, chatbot]).then(
            chat, inputs=[chatbot, image_output, testing_checkbox], outputs=[chatbot, image_output, audio_state]
        ).then(
            do_tts, inputs=[audio_state, testing_checkbox], outputs=[audio_output]
        ).then(
            do_translation, inputs=[chatbot, language_dropdown], outputs=[translation_output]
        )

        # Add speech button handling: transcribe into the entry box, then run
        # the same chain as a typed submission
        speak_button.click(
            do_speech_input,
            inputs=[testing_checkbox],
            outputs=[entry]
        ).then(
            do_entry,
            inputs=[entry, chatbot, testing_checkbox],
            outputs=[entry, chatbot]
        ).then(
            chat,
            inputs=[chatbot, image_output, testing_checkbox],
            outputs=[chatbot, image_output, audio_state]
        ).then(
            do_tts, inputs=[audio_state, testing_checkbox], outputs=[audio_output]
        ).then(
            do_translation,
            inputs=[chatbot, language_dropdown],
            outputs=[translation_output]
        )

        # Update translation when language is changed
        language_dropdown.change(do_translation, inputs=[chatbot, language_dropdown], outputs=[translation_output])

        # Handle clear button
        def clear_all():
            # Clean up audio files
            cleanup_audio_files()
            # Return None for all outputs to clear the UI
            return None, None, None, None

        clear.click(clear_all, inputs=None, outputs=[chatbot, translation_output, image_output, audio_output], queue=False)

    return ui

# Replace the original ui code with this:
ui = update_gradio_interface()
ui.launch(inbrowser=True)