Merge pull request #179 from udomai/w2d5_exercise_with_STT
week 2 exercise solution with STT
This commit is contained in:
448
week2/community-contributions/d5_TravelAgent_google_STT.ipynb
Normal file
448
week2/community-contributions/d5_TravelAgent_google_STT.ipynb
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "36e0cd9c-6622-4fa9-a4f8-b3da1b9b836e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import json\n",
|
||||||
|
"from dotenv import load_dotenv\n",
|
||||||
|
"from openai import OpenAI\n",
|
||||||
|
"import gradio as gr\n",
|
||||||
|
"import random\n",
|
||||||
|
"import re\n",
|
||||||
|
"import base64\n",
|
||||||
|
"from io import BytesIO\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"from IPython.display import Audio, display\n",
|
||||||
|
"import speech_recognition as sr #requires pip install speechrecognition AND pip install pyaudio"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "57fc95b9-043c-4a38-83aa-365cc3b285ba",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"load_dotenv()\n",
|
||||||
|
"\n",
|
||||||
|
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
||||||
|
"if openai_api_key:\n",
|
||||||
|
" print(f\"OpenAI API Key exists and begins with {openai_api_key[:8]}\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"OpenAI API Key? As if!\")\n",
|
||||||
|
" \n",
|
||||||
|
"MODEL = \"gpt-4o-mini\"\n",
|
||||||
|
"openai = OpenAI()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e633ee2a-bbaa-47a4-95ef-b1d8773866aa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
|
||||||
|
"system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
|
||||||
|
"system_message += \"Always be accurate. If you don't know the answer, say so. \"\n",
|
||||||
|
"system_message += \"You can book flights directly. \"\n",
|
||||||
|
"system_message += \"You can generate beautiful artistic renditions of the cities we fly to.\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c123af78-b5d6-4cc9-8f18-c492b1f30c85",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# ticket price function\n",
|
||||||
|
"\n",
|
||||||
|
"# spelled-out currency notation for better TTS rendition\n",
|
||||||
|
"ticket_prices = {\"valletta\": \"799 Dollars\", \"turin\": \"899 Dollars\", \"sacramento\": \"1400 Dollars\", \"montreal\": \"499 Dollars\"}\n",
|
||||||
|
"\n",
|
||||||
|
"def get_ticket_price(destination_city):\n",
|
||||||
|
" print(f\"Tool get_ticket_price called for {destination_city}\")\n",
|
||||||
|
" city = destination_city.lower()\n",
|
||||||
|
" return ticket_prices.get(city, \"Unknown\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "00e486fb-709e-4b8e-a029-9e2b225ddc25",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# travel booking function\n",
|
||||||
|
"\n",
|
||||||
|
"def book_flight(destination_city):\n",
|
||||||
|
" booking_code = ''.join(random.choice('0123456789BCDFXYZ') for i in range(2)) + ''.join(random.choice('012346789HIJKLMNOPQRS') for i in range(2)) + ''.join(random.choice('0123456789GHIJKLMNUOP') for i in range(2))\n",
|
||||||
|
" print(f\"Booking code {booking_code} generated for flight to {destination_city}.\")\n",
|
||||||
|
" \n",
|
||||||
|
" return booking_code"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c0600b4e-fa4e-4c34-b317-fac1e60b5f95",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# verify if booking code is valid (i.e. follows the pattern)\n",
|
||||||
|
"\n",
|
||||||
|
"def check_code(code):\n",
|
||||||
|
" valid = \"valid\" if re.match(\"^[0123456789BCDFXYZ]{2}[012346789HIJKLMNOPQRS]{2}[0123456789GHIJKLMNUOP]{2}$\", code) != None else \"not valid\"\n",
|
||||||
|
" print(f\"Code checker called for code {code}, which is {valid}.\")\n",
|
||||||
|
" return re.match(\"^[0123456789BCDFXYZ]{2}[012346789HIJKLMNOPQRS]{2}[0123456789GHIJKLMNUOP]{2}$\", code) != None"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e1d1b1c2-089c-41e5-b1bd-900632271093",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# make a nice preview of the travel destination\n",
|
||||||
|
"\n",
|
||||||
|
"def artist(city):\n",
|
||||||
|
" image_response = openai.images.generate(\n",
|
||||||
|
" model=\"dall-e-3\",\n",
|
||||||
|
" prompt=f\"Make an image in the style of a vibrant, artistically filtered photo that is a collage of the best sights and views in {city}.\",\n",
|
||||||
|
" size=\"1024x1024\",\n",
|
||||||
|
" n=1,\n",
|
||||||
|
" response_format=\"b64_json\",\n",
|
||||||
|
" )\n",
|
||||||
|
" image_base64 = image_response.data[0].b64_json\n",
|
||||||
|
" image_data = base64.b64decode(image_base64)\n",
|
||||||
|
" img = Image.open(BytesIO(image_data))\n",
|
||||||
|
"\n",
|
||||||
|
" img.save(\"img001.png\") # keep a copy on disk so the ~4 cents spent on the image are not wasted; .save() comes from the PIL library\n",
|
||||||
|
" \n",
|
||||||
|
" return img"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "626d99af-90de-4594-9ffd-b87a8b6ef4fd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"price_function = {\n",
|
||||||
|
" \"name\": \"get_ticket_price\",\n",
|
||||||
|
" \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"destination_city\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city that the customer wants to travel to\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"destination_city\"],\n",
|
||||||
|
" \"additionalProperties\": False\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6e7bc09c-665b-4885-823c-f145cefe8c23",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"booking_function = {\n",
|
||||||
|
" \"name\": \"book_flight\",\n",
|
||||||
|
" \"description\": \"Call this whenever you have to book a flight. Give it the destination city and you will get a booking code. Tell the customer \\\n",
|
||||||
|
"that the flight is booked and give them the booking code obtained through this function. Never give any other codes to the customer.\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"destination_city\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city that the customer wants to book their flight to\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"destination_city\"],\n",
|
||||||
|
" \"additionalProperties\": False\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "cc365d87-fed2-41ff-9232-850fdce1cff2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"artist_function = {\n",
|
||||||
|
" \"name\": \"artist\",\n",
|
||||||
|
" \"description\": \"Call this whenever you need to generate a picture, photo, or graphic impression of a city.\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"city\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city of which an image is to be generated\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"city\"],\n",
|
||||||
|
" \"additionalProperties\": False\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "99b0a0e3-db44-49f9-8d27-349b9f04c680",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"codecheck_function = {\n",
|
||||||
|
" \"name\": \"check_code\",\n",
|
||||||
|
" \"description\": \"Call this whenever you need to verify if a booking code for a flight (also called 'flight code', 'booking reference', \\\n",
|
||||||
|
"or variations thereof) is valid.\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"code\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The code that you or the user needs to verify\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"code\"],\n",
|
||||||
|
" \"additionalProperties\": False\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3fa371c4-91ff-41ae-9b10-23fe617022d1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# List of tools:\n",
|
||||||
|
"\n",
|
||||||
|
"tools = [{\"type\": \"function\", \"function\": price_function}, {\"type\": \"function\", \"function\": booking_function}, {\"type\": \"function\", \"function\": codecheck_function}, {\"type\": \"function\", \"function\": artist_function}]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c00fb465-e448-4d68-9f18-88220fbaff76",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# speech recognition (STT) by Google\n",
|
||||||
|
"\n",
|
||||||
|
"r = sr.Recognizer()\n",
|
||||||
|
"\n",
|
||||||
|
"def speech_to_text():\n",
|
||||||
|
" try:\n",
|
||||||
|
" with sr.Microphone() as source:\n",
|
||||||
|
" r.adjust_for_ambient_noise(source, duration=0.2)\n",
|
||||||
|
" speech = r.listen(source, 10, 5) # wait up to 10 seconds for speech to start, then record a phrase of at most 5 seconds\n",
|
||||||
|
" text = r.recognize_google(speech)\n",
|
||||||
|
" print(f\"STT heard: \\\"{text}\\\"\")\n",
|
||||||
|
" return text\n",
|
||||||
|
"\n",
|
||||||
|
" # Sometimes this STT fails, and you'll see \"…\" as your input. Just try again; there is no need to restart Gradio.\n",
|
||||||
|
" except sr.RequestError as e:\n",
|
||||||
|
" print(f\"Could not request results; {e}\")\n",
|
||||||
|
" return \"…\"\n",
|
||||||
|
" except sr.UnknownValueError:\n",
|
||||||
|
" print(\"Could not understand the audio\")\n",
|
||||||
|
" return \"…\"\n",
|
||||||
|
" except sr.WaitTimeoutError:\n",
|
||||||
|
" print(\"Wait timed out\")\n",
|
||||||
|
" return \"…\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "505b585e-e9f9-4326-8455-184398bc82d1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# TTS by OpenAI\n",
|
||||||
|
"\n",
|
||||||
|
"def talker(message):\n",
|
||||||
|
" response = openai.audio.speech.create(\n",
|
||||||
|
" model=\"tts-1\",\n",
|
||||||
|
" voice=\"onyx\",\n",
|
||||||
|
" input=message)\n",
|
||||||
|
"\n",
|
||||||
|
" audio_stream = BytesIO(response.content)\n",
|
||||||
|
" output_filename = \"output_audio.mp3\"\n",
|
||||||
|
" with open(output_filename, \"wb\") as f:\n",
|
||||||
|
" f.write(audio_stream.read())\n",
|
||||||
|
"\n",
|
||||||
|
" # Play the generated audio\n",
|
||||||
|
" display(Audio(output_filename, autoplay=True))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4d34942a-f0c7-4835-ba07-746104a8c524",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def chat(history):\n",
|
||||||
|
" messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
|
||||||
|
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
|
||||||
|
" image = None\n",
|
||||||
|
" \n",
|
||||||
|
" if response.choices[0].finish_reason==\"tool_calls\":\n",
|
||||||
|
" message = response.choices[0].message\n",
|
||||||
|
" responses = handle_tool_call(message)[0]\n",
|
||||||
|
" image = handle_tool_call(message)[1]\n",
|
||||||
|
" messages.append(message)\n",
|
||||||
|
" for response in responses:\n",
|
||||||
|
" messages.append(response)\n",
|
||||||
|
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
|
||||||
|
" \n",
|
||||||
|
" reply = response.choices[0].message.content\n",
|
||||||
|
"\n",
|
||||||
|
" # Uncomment the next line if you want the replies read out to you. Mind the price!\n",
|
||||||
|
" #talker(reply) #current cost: $0.015 per 1000 characters (not tokens!)\n",
|
||||||
|
" \n",
|
||||||
|
" history += [{\"role\": \"assistant\", \"content\": reply}]\n",
|
||||||
|
" \n",
|
||||||
|
" return history, image"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5413f7fb-c5f7-44c4-a63d-3d0465eb0af4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def handle_tool_call(message):\n",
|
||||||
|
" responses = []\n",
|
||||||
|
" image = None\n",
|
||||||
|
" \n",
|
||||||
|
" for tool_call in message.tool_calls:\n",
|
||||||
|
" arguments = json.loads(tool_call.function.arguments)\n",
|
||||||
|
" indata = arguments[list(arguments.keys())[0]] # works for now because we only have one argument in each of our functions\n",
|
||||||
|
" function_name = tool_call.function.name\n",
|
||||||
|
" if function_name == 'get_ticket_price':\n",
|
||||||
|
" outdata = get_ticket_price(indata)\n",
|
||||||
|
" input_name = \"destination city\"\n",
|
||||||
|
" output_name = \"price\"\n",
|
||||||
|
" elif function_name == 'book_flight':\n",
|
||||||
|
" outdata = book_flight(indata)\n",
|
||||||
|
" input_name = \"destination city\"\n",
|
||||||
|
" output_name = \"booking code\"\n",
|
||||||
|
" elif function_name == \"check_code\":\n",
|
||||||
|
" outdata = check_code(indata)\n",
|
||||||
|
" input_name = \"booking code\"\n",
|
||||||
|
" output_name = \"validity\"\n",
|
||||||
|
" elif function_name == \"artist\":\n",
|
||||||
|
" image = artist(indata)\n",
|
||||||
|
" outdata = f\"artistic rendition of {indata}\"\n",
|
||||||
|
" input_name = \"city\"\n",
|
||||||
|
" output_name = \"image\"\n",
|
||||||
|
"\n",
|
||||||
|
" responses.append({\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"content\": json.dumps({input_name: indata, output_name: outdata}),\n",
|
||||||
|
" \"tool_call_id\": tool_call.id\n",
|
||||||
|
" })\n",
|
||||||
|
"\n",
|
||||||
|
" return responses, image"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a5a31bcf-71d5-4537-a7bf-92385dc6e26e",
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"## Gradio with 'fancy' buttons. Claude and GeeksforGeeks explained this CSS business to me.\n",
|
||||||
|
"## see week2/community-contributions/day5_Careerhelper.ipynb for a much more competent version of this.\n",
|
||||||
|
"\n",
|
||||||
|
"with gr.Blocks(\n",
|
||||||
|
" css=\"\"\"\n",
|
||||||
|
" .red-button {\n",
|
||||||
|
" background-color: darkred !important;\n",
|
||||||
|
" border-color: red !important;\n",
|
||||||
|
" }\n",
|
||||||
|
" .blue-button {\n",
|
||||||
|
" background-color: darkblue !important;\n",
|
||||||
|
" border-color: blue !important;\n",
|
||||||
|
" }\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
") as ui:\n",
|
||||||
|
" with gr.Row():\n",
|
||||||
|
" chatbot = gr.Chatbot(height=500, type=\"messages\")\n",
|
||||||
|
" image_output = gr.Image(height=500)\n",
|
||||||
|
" with gr.Row():\n",
|
||||||
|
" entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
|
||||||
|
" with gr.Row():\n",
|
||||||
|
" speak = gr.Button(value=\"Speak to our AI Assistant\", elem_classes=\"blue-button\")\n",
|
||||||
|
" clear = gr.Button(value=\"Clear Chat\", elem_classes=\"red-button\")\n",
|
||||||
|
"\n",
|
||||||
|
" def do_entry(message, history):\n",
|
||||||
|
" history += [{\"role\":\"user\", \"content\":message}]\n",
|
||||||
|
" return \"\", history\n",
|
||||||
|
"\n",
|
||||||
|
" def listen(history):\n",
|
||||||
|
" message = speech_to_text()\n",
|
||||||
|
" history += [{\"role\":\"user\", \"content\":message}]\n",
|
||||||
|
" return history\n",
|
||||||
|
"\n",
|
||||||
|
" entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(\n",
|
||||||
|
" chat, inputs=chatbot, outputs=[chatbot, image_output]\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n",
|
||||||
|
" \n",
|
||||||
|
" speak.click(listen, inputs=chatbot, outputs=chatbot, queue=False).then(\n",
|
||||||
|
" chat, inputs=chatbot, outputs=[chatbot, image_output]\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"ui.launch(inbrowser=True)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.11"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user