362 lines
12 KiB
Plaintext
362 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5d799d2a-6e58-4a83-b17a-dbbc40efdc39",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Project - Course Booking AI Assistant\n",
|
|
"AI Customer Support Bot that \n",
|
|
"- Returns Prices\n",
|
|
"- Enrolls Users in Courses\n",
|
|
"- Records Enrollments in a Text File"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "b1ad9acd-a702-48a3-8ff5-d536bcac8030",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import json\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from openai import OpenAI\n",
|
|
"import gradio as gr"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "74adab0c-99b3-46cd-a79f-320a3e74138a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"OpenAI API Key exists and begins sk-proj-\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Initialization\n",
|
|
"\n",
|
|
"load_dotenv(override=True)\n",
|
|
"\n",
|
|
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
|
"if openai_api_key:\n",
|
|
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
|
|
"else:\n",
|
|
" print(\"OpenAI API Key not set\")\n",
|
|
" \n",
|
|
"MODEL = \"gpt-4o-mini\"\n",
|
|
"openai = OpenAI()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "8d3240a4-99c1-4c07-acaa-ecbb69ffd2e4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"system_message = \"You are a helpful assistant for an Online Course Platform called StudyAI. \"\n",
|
|
"system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
|
|
"system_message += \"Always be accurate. If you don't know the answer, say so.\"\n",
|
|
"system_message += \"If you are given a partial name, for example 'discrete' instead of 'discrete structures' \\\n",
|
|
"ask the user if they meant to say 'discrete structures', and then display the price. The user may also use \\\n",
|
|
"acronyms like 'PF' instead of programming fundamentals or 'OOP' to mean 'Object oriented programming'. \\\n",
|
|
"Clarify wh\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "9a1b8d5f-f893-477b-8396-ff7d697eb0c3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"course_prices = {\"programming fundamentals\": \"$19\", \"discrete structures\": \"$39\", \"operating systems\": \"$24\", \"object oriented programming\": \"$39\"}\n",
|
|
"\n",
|
|
"def get_course_price(course):\n",
|
|
" print(f\"Tool get_course_price called for {course}\")\n",
|
|
" course = course.lower()\n",
|
|
" return course_prices.get(course, \"Unknown\")\n",
|
|
"\n",
|
|
"def enroll_in_course(course):\n",
|
|
" print(f'Tool enroll_in_course_ called for {course}')\n",
|
|
" course_price = get_course_price(course)\n",
|
|
" if course_price != 'Unknown':\n",
|
|
" with open('enrolled_courses.txt', 'a') as file: \n",
|
|
" file.write(course + \"\\n\")\n",
|
|
" return 'Successfully enrolled in course'\n",
|
|
" else:\n",
|
|
" return 'Enrollment failed, no such course available'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "330d2b94-a8c5-4967-ace7-15d2cd52d7ae",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Tool get_course_price called for graph theory\n",
|
|
"Tool get_course_price called for discrete structures\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'$39'"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"get_course_price('graph theory')\n",
|
|
"get_course_price('discrete structures')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "5bb65830-fab8-45a7-bf43-7e52186915a0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"price_function = {\n",
|
|
" \"name\": \"get_course_price\",\n",
|
|
" \"description\": \"Get the price of a course. Call this whenever you need to know the course price, for example when a customer asks 'How much is a ticket for this course?'\",\n",
|
|
" \"parameters\": {\n",
|
|
" \"type\": \"object\",\n",
|
|
" \"properties\": {\n",
|
|
" \"course\": {\n",
|
|
" \"type\": \"string\",\n",
|
|
" \"description\": \"The course that the customer wants to purchase\",\n",
|
|
" },\n",
|
|
" },\n",
|
|
" \"required\": [\"course\"],\n",
|
|
" \"additionalProperties\": False\n",
|
|
" }\n",
|
|
"}\n",
|
|
"\n",
|
|
"enroll_function = {\n",
|
|
" \"name\": \"enroll_in_course\",\n",
|
|
" \"description\":\"Get the success status of course enrollment. Call whenever a customer wants to enroll in a course\\\n",
|
|
" for example, if they say 'I want to purchase this course' or 'I want to enroll in this course'\",\n",
|
|
" \"parameters\":{\n",
|
|
" \"type\":\"object\",\n",
|
|
" \"properties\":{\n",
|
|
" \"course\":{\n",
|
|
" \"type\":\"string\",\n",
|
|
" \"description\": \"The course that the customer wants to purchase\",\n",
|
|
" },\n",
|
|
" },\n",
|
|
" \"required\": [\"course\"],\n",
|
|
" \"additionalProperties\": False\n",
|
|
" } \n",
|
|
"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "08af86b9-3aaa-4b6b-bf7c-ee668ba1cbfe",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tools = [\n",
|
|
" {\"type\":\"function\",\"function\":price_function},\n",
|
|
" {\"type\":\"function\",\"function\":enroll_function}\n",
|
|
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "482efc34-ff1f-4146-9570-58b4d59c3b2f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def chat(message,history):\n",
|
|
" messages = [{\"role\":\"system\",\"content\":system_message}] + history + [{\"role\":\"user\",\"content\":message}]\n",
|
|
" response = openai.chat.completions.create(model=MODEL,messages=messages,tools=tools)\n",
|
|
"\n",
|
|
" if response.choices[0].finish_reason == \"tool_calls\":\n",
|
|
" message = response.choices[0].message\n",
|
|
" messages.append(message)\n",
|
|
" for tool_call in message.tool_calls:\n",
|
|
" messages.append(handle_tool_call(tool_call))\n",
|
|
" response = openai.chat.completions.create(model=MODEL,messages=messages)\n",
|
|
"\n",
|
|
" return response.choices[0].message.content"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "f725b4fb-d477-4d7d-80b5-5d70e1b25a86",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# handle_tool_call runs the tool the model requested and returns the tool-role message that chat() appends:\n",
|
|
"\n",
|
|
"def handle_tool_call(tool_call):\n",
|
|
" function = tool_call.function.name\n",
|
|
" arguments = json.loads(tool_call.function.arguments)\n",
|
|
" match function:\n",
|
|
" case 'get_course_price':\n",
|
|
" course = arguments.get('course')\n",
|
|
" price = get_course_price(course)\n",
|
|
" return {\n",
|
|
" \"role\": \"tool\",\n",
|
|
" \"content\": json.dumps({\"course\": course,\"price\": price}),\n",
|
|
" \"tool_call_id\": tool_call.id\n",
|
|
" }\n",
|
|
" case 'enroll_in_course':\n",
|
|
" course = arguments.get('course')\n",
|
|
" status = enroll_in_course(course)\n",
|
|
" return {\n",
|
|
" \"role\": \"tool\",\n",
|
|
" \"content\": json.dumps({\"course\": course, \"status\": status}),\n",
|
|
" \"tool_call_id\": tool_call.id\n",
|
|
" }\n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "c446272a-9ce1-4ffd-9bc8-483d782810b4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"* Running on local URL: http://127.0.0.1:7864\n",
|
|
"\n",
|
|
"To create a public link, set `share=True` in `launch()`.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": []
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Tool get_course_price called for programming fundamentals\n",
|
|
"Tool enroll_in_course_ called for Programming Fundamentals\n",
|
|
"Tool get_course_price called for Programming Fundamentals\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Traceback (most recent call last):\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\gradio\\queueing.py\", line 625, in process_events\n",
|
|
" response = await route_utils.call_process_api(\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\gradio\\route_utils.py\", line 322, in call_process_api\n",
|
|
" output = await app.get_blocks().process_api(\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\gradio\\blocks.py\", line 2096, in process_api\n",
|
|
" result = await self.call_function(\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\gradio\\blocks.py\", line 1641, in call_function\n",
|
|
" prediction = await fn(*processed_input)\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\gradio\\utils.py\", line 857, in async_wrapper\n",
|
|
" response = await f(*args, **kwargs)\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\gradio\\chat_interface.py\", line 862, in _submit_fn\n",
|
|
" response = await anyio.to_thread.run_sync(\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\anyio\\to_thread.py\", line 56, in run_sync\n",
|
|
" return await get_async_backend().run_sync_in_worker_thread(\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 2461, in run_sync_in_worker_thread\n",
|
|
" return await future\n",
|
|
" ^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\anaconda3\\envs\\llms\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 962, in run\n",
|
|
" result = context.run(func, *args)\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\AppData\\Local\\Temp\\ipykernel_3348\\1161680098.py\", line 9, in chat\n",
|
|
" messages.append(handle_tool_call(tool_call))\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\AppData\\Local\\Temp\\ipykernel_3348\\1187326431.py\", line 17, in handle_tool_call\n",
|
|
" status = enroll_in_course(course)\n",
|
|
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
|
" File \"C:\\Users\\92310\\AppData\\Local\\Temp\\ipykernel_3348\\2541918318.py\", line 13, in enroll_in_course\n",
|
|
" file.write(course_name + \"\\n\")\n",
|
|
" ^^^^^^^^^^^\n",
|
|
"NameError: name 'course_name' is not defined\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"gr.ChatInterface(fn=chat,type=\"messages\").launch(inbrowser=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1fe714a3-f793-4c3b-b5aa-6c81b82aea1b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|