618 lines
24 KiB
Plaintext
618 lines
24 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a390d675",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import json\n",
|
|
"import ollama\n",
|
|
"from google import genai\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from openai import OpenAI\n",
|
|
"import gradio as gr\n",
|
|
"from IPython.display import Markdown"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "55c9c2a2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Initialization: read credentials from .env and create one client per provider.

load_dotenv(override=True)

openai_api_key = os.getenv('OPENAI_API_KEY')
gemini_api_key = os.getenv('GEMINI_API_KEY')

# Default model name for each provider.
OPENAI_MODEL, GEMINI_MODEL, OLLAMA_MODEL = 'gpt-4o-mini', 'gemini-2.5-flash', 'llama3.2'

# NOTE: openai_api_key is read above but not passed here; OpenAI() is built without an explicit key.
openai = OpenAI()
gemini = genai.Client(api_key=gemini_api_key)

# Tool schemas for each provider; later cells fill these in.
tools, gemini_tools = [], []

# Search cache keyed by (destination, departure); seeded with a Delhi -> Delhi entry.
cached_search = {('delhi', 'delhi'): "INR 0"}

# Fallback exchange rates: INR value of one unit of each currency.
convertion_rate_to_inr = dict(
    USD=85.81,
    EUR=100.25,
    GBP=115.90,
    AUD=56.43,
    CAD=62.70,
    SGD=67.05,
    CHF=107.79,
    JPY=0.5825,
    CNY=11.97,
    AED=23.37,
    NZD=51.56,
    SAR=22.88,
    QAR=23.58,
    OMR=222.89,
    BHD=227.62,
    KWD=280.90,
    MYR=20.18,
    THB=2.655,
    HKD=10.93,
    ZAR=4.79,
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "68ec7079",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import requests\n",
|
|
"from bs4 import BeautifulSoup\n",
|
|
"\n",
|
|
"\n",
|
|
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
}


class Website:
    """
    A scraped web page reduced to its title, visible text and outgoing links.

    The page is fetched in ``__init__``. A failed request is reported with a
    print and leaves the instance with empty text/links instead of raising.

    Attributes:
        url: The URL that was requested.
        body: Raw response bytes (``b""`` when the request failed).
        title: Page title; "No title found" when the page has none (or it is
            empty), "Failed to load" when the request failed.
        text: Body text with script/style/img/input elements removed.
        links: Non-empty ``href`` values of all ``<a>`` tags.
    """

    def __init__(self, url):
        self.url = url
        try:
            response = requests.get(url=self.url, headers=headers, timeout=10)
            response.raise_for_status()
            self.body = response.content
        except requests.RequestException as e:
            print(f"Failed to fetch {self.url}: {e}")
            self.body = b""
            self.title = "Failed to load"
            self.text = ""
            self.links = []
            return

        soup = BeautifulSoup(self.body, 'html.parser')
        self.title = self._extract_title(soup)
        if soup.body:
            for irrelevant in soup.body(['script', 'style', 'img', 'input']):
                irrelevant.decompose()
            self.text = soup.body.get_text(separator="\n", strip=True)
        else:
            self.text = ""
        links = [link.get('href') for link in soup.find_all('a')]
        self.links = [link for link in links if link]

    @staticmethod
    def _extract_title(soup):
        """Return the page's title text, or "No title found" when it is missing or empty."""
        if soup.title is None:
            return "No title found"
        # ``Tag.string`` is None when <title> lacks a single text child (empty or
        # nested markup), which used to put the literal "None" into get_content();
        # get_text() collects whatever text the tag holds.
        title = soup.title.get_text(strip=True)
        return title or "No title found"

    def get_content(self):
        """Return the title and text as a prompt-ready block."""
        return f"Webpage Title:\n{self.title}\nWebpage Contents:\n{self.text}\n\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "812ddcef",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from googleapiclient.discovery import build\n",
|
|
"from googleapiclient.errors import HttpError\n",
|
|
"\n",
|
|
class GoogleSearch:
    """Thin wrapper around the Google Custom Search JSON API (``customsearch`` v1)."""

    def __init__(self, api_key=None, cse_id=None):
        """
        Initialize the Google Search Tool

        Args:
            api_key: Your Google API key (or set GOOGLE_SEARCH_KEY env var)
            cse_id: Your Custom Search Engine ID (or set GOOGLE_CSE_ID env var)

        Raises:
            ValueError: If either credential is missing.
        """
        self.api_key = api_key or os.getenv('GOOGLE_SEARCH_KEY')
        self.cse_id = cse_id or os.getenv('GOOGLE_CSE_ID')

        # The messages must name the variable that is actually read above.
        if not self.api_key:
            raise ValueError("API key is required. Set GOOGLE_SEARCH_KEY env var or pass api_key parameter")
        if not self.cse_id:
            raise ValueError("CSE ID is required. Set GOOGLE_CSE_ID env var or pass cse_id parameter")

        self.service = build("customsearch", "v1", developerKey=self.api_key)

    def search(self, query: str, num_result: int = 10, start_index: int = 1):
        """
        Perform a Google Custom Search

        Args:
            query: Search query string
            num_result: Number of results to return; clamped to 1-10
            start_index: 1-based index of the first result (for pagination)

        Returns:
            dict: Parsed results (see ``_parse_results``), or None if the
            request failed (the error is printed).
        """
        try:
            res = self.service.cse().list(
                q=query,
                cx=self.cse_id,
                num=max(1, min(num_result, 10)),
                start=start_index,
            ).execute()
            return self._parse_results(res)
        except HttpError as e:
            print(f"HTTP Error: {e}")
            return None
        except Exception as e:
            print(f"Unexpected error: {e}")
            return None

    def _parse_results(self, raw_res):
        """
        Parse a raw API response into a compact dict.

        Returns:
            dict: ``total_results`` (int), ``results`` (list of dicts with
            title/link/snippet/display_link/formatted_url) and ``search_info``
            (the raw ``searchInformation`` block, or {}).
        """
        search_info = raw_res.get('searchInformation', {})
        if "items" not in raw_res:
            return {
                'total_results': 0,
                'results': [],
                'search_info': search_info,
            }

        parsed_items = [
            {
                "title": item.get("title", ''),
                "link": item.get("link", ''),
                "snippet": item.get("snippet", ''),
                # The API field is camelCase; "display_link" never exists and always gave ''.
                "display_link": item.get("displayLink", ''),
                'formatted_url': item.get('formattedUrl', ''),
            }
            for item in raw_res["items"]
        ]

        return {
            'total_results': int(search_info.get('totalResults', '0')),
            'results': parsed_items,
            'search_info': search_info,
        }

    def compile_search_pages(self, query: str, num_result: int = 10, start_index: int = 1):
        """
        Search for ``query`` and concatenate the scraped content of every hit.

        Args:
            query: Search query string
            num_result: Number of search hits to fetch and scrape (1-10)
            start_index: 1-based index of the first result (for pagination)

        Returns:
            str: For each hit, a "Title: ..." header followed by the page's
            ``Website.get_content()``; an empty string when the search failed
            or returned no hits.
        """
        search_res = self.search(query=query, num_result=num_result, start_index=start_index)
        print(search_res)

        # search() returns None on API errors; subscripting it raised TypeError.
        if not search_res:
            return ""

        chunks = []
        for item in search_res['results']:
            print(item.get('title'))
            chunks.append(f"\n\nTitle: {item.get('title', '')}\n")
            chunks.append(Website(item.get('link', '')).get_content())

        result = "".join(chunks)
        print(result)
        return result
|
|
"\n",
|
|
# Module-wide search client used by get_ticket_price; with no arguments,
# GoogleSearch falls back to credentials from the environment.
google_search = GoogleSearch(api_key=None, cse_id=None)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "857e77f8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# google_search.compile_search_pages('flight ticket price from delhi to chandigarh', num_result=4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ec5cf817",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# System prompt for the FlightAI chat assistant. Each sentence ends with a
# space so the pieces do not run together (previously "say so.Always ask").
system_message = (
    "You are a helpful assistant for an Airline called FlightAI. "
    "Give short, courteous answers, no more than 1 sentence. "
    "Always be accurate. If you don't know the answer, say so. "
    "Always ask the user about the departure point in case it asks about the price and departure is not mentioned."
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b829f882",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def analyze_result_for_price(result: str, source: str, model: str):\n",
|
|
" print(\"Analyze web results: \", source, model)\n",
|
|
"\n",
|
|
" system_prompt = \"You are an assistant that analyzes the contents of several relevant pages from a search query.\"\n",
|
|
" system_prompt = \"Provide the lowest price, highest price and average price for one way and round trips.\"\n",
|
|
" system_prompt += \"Always return the price in INR. If you are not sure about the conversion rate, only then use the following conversion rates:\"\n",
|
|
" system_prompt += f\"{convertion_rate_to_inr} for conversion rates. Interpret the given conversion rate as for example:\"\n",
|
|
" system_prompt += \"1 USD to INR = 85.81. Return result in Markdown\"\n",
|
|
" \n",
|
|
" if source == 'ollama':\n",
|
|
" model_to_use = model if model else OLLAMA_MODEL\n",
|
|
"\n",
|
|
" print(f\"Using model: {model_to_use}\\n\\n\")\n",
|
|
"\n",
|
|
" try:\n",
|
|
" response = ollama.chat(\n",
|
|
" model=model_to_use, \n",
|
|
" messages=[\n",
|
|
" {\"role\":\"system\", \"content\": system_prompt},\n",
|
|
" {\"role\": \"user\", \"content\": result}\n",
|
|
" ],\n",
|
|
" )\n",
|
|
" \n",
|
|
" result = response['message']['content']\n",
|
|
" return result\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"An error occurred during the API call: {e}\")\n",
|
|
" return None\n",
|
|
" elif source == 'openai':\n",
|
|
" try:\n",
|
|
" response = openai.chat.completions.create(\n",
|
|
" model=OPENAI_MODEL,\n",
|
|
" messages=[\n",
|
|
" {\"role\":\"system\", \"content\": system_prompt},\n",
|
|
" {\"role\":\"user\", \"content\": result}\n",
|
|
" ],\n",
|
|
" \n",
|
|
" )\n",
|
|
"\n",
|
|
" result = response.choices[0].message.content\n",
|
|
" return result\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"An error occurred during the API call: {e}\")\n",
|
|
" return None\n",
|
|
" elif source == 'gemini':\n",
|
|
" try:\n",
|
|
" response = gemini.models.generate_content(\n",
|
|
" model=GEMINI_MODEL,\n",
|
|
" contents=f\"{system_prompt}\\n\\n{result}\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" result = response.text\n",
|
|
" return result\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"An error occurred during the API call: {e}\")\n",
|
|
" return None\n",
|
|
" else:\n",
|
|
" print(\"Source not supported\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "eb2ba65a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def get_ticket_price(destination_city, departure_city, source="openai", model=""):
    """
    Look up flight prices between two cities via web search plus LLM analysis.

    Search results are cached in ``cached_search`` under
    ``(destination, departure)`` (stripped and lowercased). A repeated route
    therefore skips the Google search and only re-runs the analysis.

    Args:
        destination_city: City the customer is flying to.
        departure_city: City the customer is flying from.
        source: LLM provider for the analysis ('openai', 'gemini' or 'ollama').
        model: Optional model name passed through to analyze_result_for_price.

    Returns:
        str: The price summary, or a message starting with "Error:" on failure.
    """
    # Normalise first so whitespace-only names are rejected and cache keys match.
    dest = (destination_city or "").strip().lower()
    dept = (departure_city or "").strip().lower()
    if not dest or not dept:
        return "Error: Both destination and departure cities are required"

    print(f"Tool get_ticket_price called for {destination_city} from {departure_city}")
    print("get_ticket_price: ", model)

    cache_key = (dest, dept)
    results = cached_search.get(cache_key)
    if results is None:
        try:
            query = f'flight ticket price from {dept} to {dest}'
            results = google_search.compile_search_pages(query=query, num_result=10)
        except Exception as e:
            print(f"Error during search: {e}")
            return f"Error: Unable to fetch flight prices - {str(e)}"
        if not results:
            return "Error: No search results found"
        cached_search[cache_key] = results

    try:
        analysis = analyze_result_for_price(results, source, model)
    except Exception as e:
        print(f"Error analyzing results: {e}")
        return f"Error: Unable to analyze price data - {str(e)}"

    # analyze_result_for_price signals provider failure / unknown source with None.
    if analysis is None:
        return "Error: Unable to analyze price data"
    return analysis
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b9c64fac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Markdown(get_ticket_price('New York', 'London', \"gemini\", \"\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "32440830",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# OpenAI function-calling schema for get_ticket_price.
price_function = {
    "name": "get_ticket_price",
    "description": "Get the current flight ticket price between two cities. Call this whenever you need to know flight prices, for example when a customer asks 'How much is a ticket from Delhi to Mumbai?', 'What's the flight cost to Chandigarh?', or 'Show me ticket prices for travel between these cities'. This function searches for real-time flight pricing information from multiple sources.",
    "parameters": {
        "type": "object",
        "properties": {
            "destination_city": {
                "type": "string",
                "description": "The city that the customer wants to travel to (e.g., 'Mumbai', 'Delhi', 'Chandigarh')",
            },
            "departure_city": {
                "type": "string",
                "description": "The city that the customer wants to travel from (e.g., 'Delhi', 'Mumbai', 'Bangalore')",
            },
            "source": {
                "type": "string",
                "description": "The AI model source to use for price analysis (optional, defaults to 'openai')",
                "default": "openai"
            },
            "model": {
                "type": "string",
                "description": "The specific AI model to use for analysis (optional, defaults to empty string)",
                "default": ""
            }
        },
        "required": ["destination_city", "departure_city"],
        "additionalProperties": False
    }
}

# Drop any copy registered by an earlier run of this cell, so re-running it
# replaces the tool instead of appending a duplicate definition.
tools[:] = [
    tool for tool in tools
    if tool.get("function", {}).get("name") != price_function["name"]
]
tools.append({"type": "function", "function": price_function})
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c670e697",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Gemini declaration of the same get_ticket_price tool. Unlike price_function,
# it carries no "default" or "additionalProperties" keys.
gemini_tools = [
    dict(
        function_declarations=[
            dict(
                name="get_ticket_price",
                description="Get the current flight ticket price between two cities. Call this whenever you need to know flight prices, for example when a customer asks 'How much is a ticket from Delhi to Mumbai?', 'What's the flight cost to Chandigarh?', or 'Show me ticket prices for travel between these cities'. This function searches for real-time flight pricing information from multiple sources.",
                parameters=dict(
                    type="object",
                    properties=dict(
                        destination_city=dict(
                            type="string",
                            description="The city that the customer wants to travel to (e.g., 'Mumbai', 'Delhi', 'Chandigarh')",
                        ),
                        departure_city=dict(
                            type="string",
                            description="The city that the customer wants to travel from (e.g., 'Delhi', 'Mumbai', 'Bangalore')",
                        ),
                        source=dict(
                            type="string",
                            description="The AI model source to use for price analysis (optional, defaults to 'openai')",
                        ),
                        model=dict(
                            type="string",
                            description="The specific AI model to use for analysis (optional, defaults to empty string)",
                        ),
                    ),
                    required=["destination_city", "departure_city"],
                ),
            )
        ]
    )
]
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2c608a19",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def handle_tool_call(message, model):
    """
    Execute the first tool call in an OpenAI assistant message.

    Args:
        message: Assistant message whose ``tool_calls`` list is non-empty.
        model: Provider name forwarded to get_ticket_price as its ``source``.

    Returns:
        dict: A ``role: "tool"`` message answering ``tool_calls[0]``. An unknown
        tool, or arguments that are not a JSON object, yield an ``{"error": ...}``
        payload instead of None, so the conversation can continue.
    """
    tool_call = message.tool_calls[0]
    print(tool_call)

    def _reply(payload):
        # Every tool message must echo the id of the call it answers.
        return {
            "role": "tool",
            "content": json.dumps(payload),
            "tool_call_id": tool_call.id,
        }

    if tool_call.function.name != "get_ticket_price":
        return _reply({"error": f"Unknown tool: {tool_call.function.name}"})

    try:
        arguments = json.loads(tool_call.function.arguments or "{}")
    except json.JSONDecodeError as e:
        return _reply({"error": f"Invalid tool arguments: {e}"})
    if not isinstance(arguments, dict):
        return _reply({"error": "Invalid tool arguments: expected a JSON object"})

    dest_city = arguments.get("destination_city", '')
    dept_city = arguments.get("departure_city", '')
    price = get_ticket_price(dest_city, dept_city, model, "")
    return _reply({"destination_city": dest_city, "departure_city": dept_city, "price": price})
|
|
"\n",
|
|
def handle_tool_call_gemini(response, model):
    """
    Execute the function call requested in a Gemini ``generate_content`` response.

    Args:
        response: Gemini response whose first candidate contains a part with a
            ``function_call``.
        model: Provider name forwarded to get_ticket_price as its ``source``.

    Returns:
        dict | None: ``{"tool_response": {"name": ..., "response": {"content": <json str>}}}``.
        An unknown function gets an ``{"error": ...}`` JSON payload. Returns None
        only when the candidate contains no function call at all.
    """
    parts = response.candidates[0].content.parts or []
    # The call is not guaranteed to be the first part (e.g. text may precede it).
    tool_call = next(
        (part.function_call for part in parts if getattr(part, "function_call", None)),
        None,
    )
    if tool_call is None:
        return None

    function_name = tool_call.name
    arguments = tool_call.args or {}  # args can be absent for argument-less calls

    if function_name == "get_ticket_price":
        dest_city = arguments.get("destination_city", "")
        dept_city = arguments.get("departure_city", "")
        price = get_ticket_price(dest_city, dept_city, model, "")
        payload = {
            "destination_city": dest_city,
            "departure_city": dept_city,
            "price": price,
        }
    else:
        payload = {"error": f"Unknown tool: {function_name}"}

    return {
        "tool_response": {
            "name": function_name,
            "response": {
                "content": json.dumps(payload)
            }
        }
    }
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "81c56b0d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Upper bound on tool-call round trips per user turn, so a model that keeps
# requesting tools cannot loop forever.
_MAX_TOOL_ROUNDS = 3
_FALLBACK_REPLY = "Sorry, I couldn't get an answer for that. Please try again."


def _chat_openai(turns, provider):
    """Answer with OpenAI, resolving get_ticket_price calls until a text reply arrives."""
    messages = [{"role": "system", "content": system_message}] + turns
    options = {"model": OPENAI_MODEL}
    if tools:
        # handle_tool_call answers only tool_calls[0]; the API needs a reply for
        # every tool_call_id, so restrict the model to one call per response.
        options.update(tools=tools, parallel_tool_calls=False)

    response = openai.chat.completions.create(messages=messages, **options)
    for _ in range(_MAX_TOOL_ROUNDS):
        choice = response.choices[0]
        if choice.finish_reason != "tool_calls":
            break
        tool_reply = handle_tool_call(choice.message, provider)
        if tool_reply is None:
            break
        messages.append(choice.message)
        messages.append(tool_reply)
        response = openai.chat.completions.create(messages=messages, **options)
    return response.choices[0].message.content or _FALLBACK_REPLY


def _chat_gemini(turns, provider):
    """Answer with Gemini, resolving get_ticket_price calls until a text reply arrives."""
    # Gemini names the assistant role "model" and takes the system prompt via config.
    contents = [
        {"role": "model" if t["role"] == "assistant" else "user", "parts": [{"text": t["content"]}]}
        for t in turns
    ]
    config = {"system_instruction": system_message}
    if gemini_tools:
        config["tools"] = gemini_tools

    response = gemini.models.generate_content(model=GEMINI_MODEL, contents=contents, config=config)
    for _ in range(_MAX_TOOL_ROUNDS):
        # Detect calls from the returned parts rather than from finish_reason.
        if not response.function_calls:
            break
        tool_reply = handle_tool_call_gemini(response, provider)
        if tool_reply is None:
            break
        contents.append(response.candidates[0].content)
        contents.append({"role": "user", "parts": [{"function_response": tool_reply["tool_response"]}]})
        response = gemini.models.generate_content(model=GEMINI_MODEL, contents=contents, config=config)
    return response.text or _FALLBACK_REPLY


def chat(history, model):
    """
    Answer the latest user turn in ``history`` with the selected provider.

    Args:
        history: Gradio "messages"-style list of dicts with "role" and "content".
        model: Dropdown label ("OpenAI", "Gemini", "Ollama"); None falls back to OpenAI.

    Returns:
        list: ``history`` with one ``role: "assistant"`` message appended (mutated in place).
    """
    provider = (model or "openai").lower()
    # Forward only role/content; Gradio entries may carry extra keys the LLM APIs don't expect.
    turns = [{"role": m["role"], "content": m["content"]} for m in history]

    if provider == "openai":
        reply = _chat_openai(turns, provider)
    elif provider == "gemini":
        reply = _chat_gemini(turns, provider)
    else:
        reply = f"Sorry, chatting via {model} is not supported yet. Please choose OpenAI or Gemini."

    # Gradio's messages format expects "assistant" (not "model") for bot turns.
    history += [{"role": "assistant", "content": reply}]
    return history
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3b2dac94",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
with gr.Blocks() as ui:
    with gr.Row():
        chatbot = gr.Chatbot(height=500, type="messages")
    with gr.Row():
        entry = gr.Textbox(label="Chat with our AI Assistant:")
        # Preselect a provider; with nothing selected, chat() would receive no model name.
        model = gr.Dropdown(["OpenAI", "Gemini", "Ollama"], value="OpenAI", label="Choose a model")
    with gr.Row():
        clear = gr.Button("Clear")

    def do_entry(message, history):
        """Append the user's message to the history and clear the textbox."""
        history = list(history or [])  # the chatbot value is None right after "Clear"
        history.append({"role": "user", "content": message})
        return "", history

    entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(
        chat, inputs=[chatbot, model], outputs=[chatbot]
    )
    clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)

ui.launch(inbrowser=True)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d50b03d4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Inspect the search cache filled by get_ticket_price; keys are lowercased (destination, departure) pairs.
dict(cached_search)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8a7f06bf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "llms",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|