{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "a390d675", "metadata": {}, "outputs": [], "source": [ "# imports\n", "\n", "import os\n", "import json\n", "import ollama\n", "from google import genai\n", "from dotenv import load_dotenv\n", "from openai import OpenAI\n", "import gradio as gr\n", "from IPython.display import Markdown" ] }, { "cell_type": "code", "execution_count": null, "id": "55c9c2a2", "metadata": {}, "outputs": [], "source": [ "# Initialization\n", "\n", "load_dotenv(override=True)\n", "\n", "openai_api_key = os.getenv('OPENAI_API_KEY')\n", "gemini_api_key = os.getenv('GEMINI_API_KEY')\n", " \n", "OPENAI_MODEL = 'gpt-4o-mini'\n", "GEMINI_MODEL = 'gemini-2.5-flash' \n", "OLLAMA_MODEL = 'llama3.2'\n", "\n", "openai = OpenAI()\n", "gemini = genai.Client(api_key = gemini_api_key)\n", "\n", "tools = []\n", "gemini_tools = []\n", "\n", "cached_search = {\n", " ('delhi', 'delhi'): \"INR 0\",\n", "}\n", "\n", "convertion_rate_to_inr = {\n", " \"USD\": 85.81,\n", " \"EUR\": 100.25,\n", " \"GBP\": 115.90,\n", " \"AUD\": 56.43,\n", " \"CAD\": 62.70,\n", " \"SGD\": 67.05,\n", " \"CHF\": 107.79,\n", " \"JPY\": 0.5825,\n", " \"CNY\": 11.97,\n", " \"AED\": 23.37,\n", " \"NZD\": 51.56,\n", " \"SAR\": 22.88,\n", " \"QAR\": 23.58,\n", " \"OMR\": 222.89,\n", " \"BHD\": 227.62,\n", " \"KWD\": 280.90,\n", " \"MYR\": 20.18,\n", " \"THB\": 2.655,\n", " \"HKD\": 10.93,\n", " \"ZAR\": 4.79\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "68ec7079", "metadata": {}, "outputs": [], "source": [ "import requests\n", "from bs4 import BeautifulSoup\n", "\n", "\n", "headers = {\n", " \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36\"\n", "}\n", "\n", "\n", "class Website:\n", " \"\"\"\n", " A utility class to represent a Website that we have scraped, now with links\n", " \"\"\"\n", "\n", " def __init__(self, url):\n", " self.url = url \n", " try:\n", " response = requests.get(url=self.url, headers=headers, timeout=10)\n", " response.raise_for_status()\n", " self.body = response.content \n", " except requests.RequestException as e:\n", " print(f\"Failed to fetch {self.url}: {e}\")\n", " self.body = b\"\"\n", " self.title = \"Failed to load\"\n", " self.text = \"\"\n", " self.links = []\n", " return\n", " soup = BeautifulSoup(self.body, 'html.parser')\n", " self.title = soup.title.string if soup.title else \"No title found\"\n", " if soup.body:\n", " for irrelevant in soup.body(['script', 'style', 'img', 'input']):\n", " irrelevant.decompose()\n", " self.text = soup.body.get_text(separator=\"\\n\", strip=True)\n", " else:\n", " self.text = \"\" \n", " links = [link.get('href') for link in soup.find_all('a')]\n", " self.links = [link for link in links if link]\n", "\n", " def get_content(self):\n", " return f\"Webpage Title:\\n{self.title}\\nWebpage Contents:\\n{self.text}\\n\\n\"\n", " \n" ] }, { "cell_type": "code", "execution_count": null, "id": "812ddcef", "metadata": {}, "outputs": [], "source": [ "from googleapiclient.discovery import build\n", "from googleapiclient.errors import HttpError\n", "\n", "class GoogleSearch:\n", " def __init__(self, api_key=None, cse_id=None):\n", " \"\"\"\n", " Initialize the Google Search Tool\n", " \n", " Args:\n", " api_key: Your Google API key (or set GOOGLE_API_KEY env var)\n", " cse_id: Your Custom Search Engine ID (or set GOOGLE_CSE_ID env var)\n", " \"\"\"\n", " self.api_key = api_key or os.getenv('GOOGLE_SEARCH_KEY')\n", " self.cse_id = cse_id or os.getenv('GOOGLE_CSE_ID')\n", "\n", " if not self.api_key:\n", " raise ValueError(\"API key is required. Set GOOGLE_API_KEY env var or pass api_key parameter\")\n", " if not self.cse_id:\n", " raise ValueError(\"CSE ID is required. Set GOOGLE_CSE_ID env var or pass cse_id parameter\")\n", " \n", " self.service = build(\"customsearch\", \"v1\", developerKey=self.api_key)\n", " \n", " def search(self, query: str, num_result: int=10, start_index: int=1):\n", " \"\"\"\n", " Perform a Google Custom Search\n", " \n", " Args:\n", " query: Search query string\n", " num_results: Number of results to return (1-10)\n", " start_index: Starting index for results (for pagination)\n", " \n", " Returns:\n", " dict: Search results or None if error\n", " \"\"\"\n", " try:\n", " res = self.service.cse().list(\n", " q=query,\n", " cx=self.cse_id,\n", " num=min(num_result, 10),\n", " start=start_index\n", " ).execute()\n", "\n", " return self._parse_results(res)\n", " except HttpError as e:\n", " print(f\"HTTP Error: {e}\")\n", " return None\n", " except Exception as e:\n", " print(f\"Unexpected error: {e}\")\n", " return None\n", " \n", " def _parse_results(self, raw_res):\n", " \"\"\"Parse raw API response into clean format\"\"\"\n", " if \"items\" not in raw_res:\n", " return {\n", " 'total_results': 0,\n", " 'results': [],\n", " 'search_info': raw_res.get('searchInformation', {})\n", " }\n", " \n", " parsed_items = []\n", " for item in raw_res[\"items\"]:\n", " parsed_item = {\n", " \"title\": item.get(\"title\", ''),\n", " \"link\": item.get(\"link\", ''),\n", " \"snippet\": item.get(\"snippet\", ''),\n", " \"display_link\": item.get(\"display_link\", ''),\n", " 'formatted_url': item.get('formattedUrl', '')\n", " }\n", "\n", " parsed_items.append(parsed_item)\n", " \n", " return {\n", " 'total_results': int(raw_res.get('searchInformation', {}).get('totalResults', '0')),\n", " 'results': parsed_items,\n", " 'search_info': raw_res.get('searchInformation', {})\n", " }\n", " \n", " def compile_search_pages(self, query: str, num_result: int = 10, start_index: int=1):\n", " \"\"\"\n", " Compiles a list of results from multiple search pages for a given query\n", "\n", " Args:\n", " query: Search query string\n", " num_results: Number of results to return (1-10)\n", " start_index: Starting index for results (for pagination)\n", " \n", " Returns:\n", " str: Concatenated results from all search pages for the given query\n", " \"\"\"\n", "\n", " result = \"\"\n", "\n", " search_res = self.search(query=query, num_result=num_result, start_index=start_index)\n", "\n", " print(search_res)\n", "\n", " for item in search_res['results']:\n", " print(item.get('title'))\n", " result += f\"\\n\\nTitle: {item.get('title', '')}\\n\"\n", " result += Website(item.get('link', '')).get_content()\n", "\n", " print(result)\n", "\n", " return result\n", "\n", "google_search = GoogleSearch()" ] }, { "cell_type": "code", "execution_count": null, "id": "857e77f8", "metadata": {}, "outputs": [], "source": [ "# google_search.compile_search_pages('flight ticket price from delhi to chandigarh', num_result=4)" ] }, { "cell_type": "code", "execution_count": null, "id": "ec5cf817", "metadata": {}, "outputs": [], "source": [ "system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n", "system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n", "system_message += \"Always be accurate. If you don't know the answer, say so.\"\n", "system_message += \"Always ask the user about the departure point in case it asks about the price and departure is not mentioned.\"" ] }, { "cell_type": "code", "execution_count": null, "id": "b829f882", "metadata": {}, "outputs": [], "source": [ "def analyze_result_for_price(result: str, source: str, model: str):\n", " print(\"Analyze web results: \", source, model)\n", "\n", " system_prompt = \"You are an assistant that analyzes the contents of several relevant pages from a search query.\"\n", " system_prompt = \"Provide the lowest price, highest price and average price for one way and round trips.\"\n", " system_prompt += \"Always return the price in INR. If you are not sure about the conversion rate, only then use the following conversion rates:\"\n", " system_prompt += f\"{convertion_rate_to_inr} for conversion rates. Interpret the given conversion rate as for example:\"\n", " system_prompt += \"1 USD to INR = 85.81. Return result in Markdown\"\n", " \n", " if source == 'ollama':\n", " model_to_use = model if model else OLLAMA_MODEL\n", "\n", " print(f\"Using model: {model_to_use}\\n\\n\")\n", "\n", " try:\n", " response = ollama.chat(\n", " model=model_to_use, \n", " messages=[\n", " {\"role\":\"system\", \"content\": system_prompt},\n", " {\"role\": \"user\", \"content\": result}\n", " ],\n", " )\n", " \n", " result = response['message']['content']\n", " return result\n", " except Exception as e:\n", " print(f\"An error occurred during the API call: {e}\")\n", " return None\n", " elif source == 'openai':\n", " try:\n", " response = openai.chat.completions.create(\n", " model=OPENAI_MODEL,\n", " messages=[\n", " {\"role\":\"system\", \"content\": system_prompt},\n", " {\"role\":\"user\", \"content\": result}\n", " ],\n", " \n", " )\n", "\n", " result = response.choices[0].message.content\n", " return result\n", " except Exception as e:\n", " print(f\"An error occurred during the API call: {e}\")\n", " return None\n", " elif source == 'gemini':\n", " try:\n", " response = gemini.models.generate_content(\n", " model=GEMINI_MODEL,\n", " contents=f\"{system_prompt}\\n\\n{result}\"\n", " )\n", "\n", " result = response.text\n", " return result\n", " except Exception as e:\n", " print(f\"An error occurred during the API call: {e}\")\n", " return None\n", " else:\n", " print(\"Source not supported\")" ] }, { "cell_type": "code", "execution_count": null, "id": "eb2ba65a", "metadata": {}, "outputs": [], "source": [ "def get_ticket_price(destination_city, departure_city, source=\"openai\", model=\"\"):\n", " if not destination_city or not departure_city:\n", " return \"Error: Both destination and departure cities are required\"\n", " \n", " print(f\"Tool get_ticket_price called for {destination_city} from {departure_city}\")\n", " print(\"get_ticket_price: \", model)\n", "\n", " dest = destination_city.lower()\n", " dept = departure_city.lower()\n", "\n", " cache_key = (dest, dept)\n", "\n", " if cache_key not in cached_search:\n", " try:\n", " query = f'flight ticket price from {dept} to {dest}' \n", " results = google_search.compile_search_pages(query=query, num_result=10) \n", " \n", " if results: # Check if results is not empty\n", " cached_search[cache_key] = results\n", " else:\n", " return \"Error: No search results found\"\n", " except Exception as e:\n", " print(f\"Error during search: {e}\")\n", " return f\"Error: Unable to fetch flight prices - {str(e)}\"\n", " else:\n", " results = cached_search[cache_key]\n", "\n", " try:\n", " return analyze_result_for_price(results, source, model)\n", " except Exception as e:\n", " print(f\"Error analyzing results: {e}\")\n", " return f\"Error: Unable to analyze price data - {str(e)}\"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b9c64fac", "metadata": {}, "outputs": [], "source": [ "# Markdown(get_ticket_price('New York', 'London', \"gemini\", \"\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "32440830", "metadata": {}, "outputs": [], "source": [ "price_function = {\n", " \"name\": \"get_ticket_price\",\n", " \"description\": \"Get the current flight ticket price between two cities. Call this whenever you need to know flight prices, for example when a customer asks 'How much is a ticket from Delhi to Mumbai?', 'What's the flight cost to Chandigarh?', or 'Show me ticket prices for travel between these cities'. This function searches for real-time flight pricing information from multiple sources.\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"destination_city\": {\n", " \"type\": \"string\",\n", " \"description\": \"The city that the customer wants to travel to (e.g., 'Mumbai', 'Delhi', 'Chandigarh')\",\n", " },\n", " \"departure_city\": {\n", " \"type\": \"string\",\n", " \"description\": \"The city that the customer wants to travel from (e.g., 'Delhi', 'Mumbai', 'Bangalore')\",\n", " },\n", " \"source\": {\n", " \"type\": \"string\",\n", " \"description\": \"The AI model source to use for price analysis (optional, defaults to 'openai')\",\n", " \"default\": \"openai\"\n", " },\n", " \"model\": {\n", " \"type\": \"string\", \n", " \"description\": \"The specific AI model to use for analysis (optional, defaults to empty string)\",\n", " \"default\": \"\"\n", " }\n", " },\n", " \"required\": [\"destination_city\", \"departure_city\"],\n", " \"additionalProperties\": False\n", " }\n", "}\n", "\n", "tools.append({\"type\": \"function\", \"function\": price_function})" ] }, { "cell_type": "code", "execution_count": null, "id": "c670e697", "metadata": {}, "outputs": [], "source": [ "gemini_tools = [\n", " {\n", " \"function_declarations\": [\n", " {\n", " \"name\": \"get_ticket_price\",\n", " \"description\": \"Get the current flight ticket price between two cities. Call this whenever you need to know flight prices, for example when a customer asks 'How much is a ticket from Delhi to Mumbai?', 'What's the flight cost to Chandigarh?', or 'Show me ticket prices for travel between these cities'. This function searches for real-time flight pricing information from multiple sources.\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"destination_city\": {\n", " \"type\": \"string\",\n", " \"description\": \"The city that the customer wants to travel to (e.g., 'Mumbai', 'Delhi', 'Chandigarh')\"\n", " },\n", " \"departure_city\": {\n", " \"type\": \"string\",\n", " \"description\": \"The city that the customer wants to travel from (e.g., 'Delhi', 'Mumbai', 'Bangalore')\"\n", " },\n", " \"source\": {\n", " \"type\": \"string\",\n", " \"description\": \"The AI model source to use for price analysis (optional, defaults to 'openai')\"\n", " },\n", " \"model\": {\n", " \"type\": \"string\",\n", " \"description\": \"The specific AI model to use for analysis (optional, defaults to empty string)\"\n", " }\n", " },\n", " \"required\": [\"destination_city\", \"departure_city\"]\n", " }\n", " }\n", " ]\n", " }\n", "]" ] }, { "cell_type": "code", "execution_count": null, "id": "2c608a19", "metadata": {}, "outputs": [], "source": [ "def handle_tool_call(message, model):\n", "\n", " tool_call = message.tool_calls[0]\n", " arguments = json.loads(tool_call.function.arguments)\n", " print(tool_call)\n", " if tool_call.function.name == \"get_ticket_price\":\n", " dest_city = arguments.get(\"destination_city\", '')\n", " dept_city = arguments.get(\"departure_city\",'')\n", " price = get_ticket_price(dest_city, dept_city, model, \"\")\n", " return {\n", " \"role\": \"tool\",\n", " \"content\": json.dumps({\"destination_city\": dest_city,\"departure_city\": dept_city,\"price\": price}),\n", " \"tool_call_id\": tool_call.id\n", " }\n", " return None\n", "\n", "def handle_tool_call_gemini(response, model):\n", " tool_call = response.candidates[0].content.parts[0].function_call\n", " function_name = tool_call.name\n", " arguments = tool_call.args\n", " \n", " if function_name == \"get_ticket_price\":\n", " dest_city = arguments.get(\"destination_city\", \"\")\n", " dept_city = arguments.get(\"departure_city\", \"\")\n", " price = get_ticket_price(dest_city, dept_city, model, \"\")\n", " \n", " return {\n", " \"tool_response\": {\n", " \"name\": function_name,\n", " \"response\": {\n", " \"content\": json.dumps({\n", " \"destination_city\": dest_city,\n", " \"departure_city\": dept_city,\n", " \"price\": price\n", " })\n", " }\n", " }\n", " }\n", " \n", " return None" ] }, { "cell_type": "code", "execution_count": null, "id": "81c56b0d", "metadata": {}, "outputs": [], "source": [ "def chat(history, model):\n", " MODEL_TO_USE = \"\"\n", " if model.lower() == 'openai':\n", " MODEL_TO_USE = OPENAI_MODEL\n", "\n", " messages = [{\"role\": \"system\", \"content\": system_message}] + history\n", " response = openai.chat.completions.create(model=MODEL_TO_USE, messages=messages, tools=tools)\n", "\n", " if response.choices[0].finish_reason==\"tool_calls\":\n", " message = response.choices[0].message\n", " response = handle_tool_call(message, model.lower())\n", " messages.append(message)\n", " messages.append(response)\n", " response = openai.chat.completions.create(model=MODEL_TO_USE, messages=messages, tools=tools)\n", " \n", " history += [{\"role\": \"assistant\", \"content\": response.choices[0].message.content}]\n", " elif model.lower() == 'gemini':\n", " MODEL_TO_USE = GEMINI_MODEL\n", " messages = [{\"role\": \"system\", \"content\": system_message}] + history\n", " response = gemini.models.generate_content(messages, tools=gemini_tools) \n", " candidate = response.candidates[0]\n", " \n", " if candidate.finish_reason == 'TOOL_CALL':\n", " messages.append(candidate.content)\n", " tool_response = handle_tool_call_gemini(response, model.lower())\n", " messages.append(tool_response)\n", " response = gemini.models.generate_content(messages, tools=gemini_tools)\n", " \n", " history += [{\"role\": \"model\", \"content\": response.text}]\n", " return history" ] }, { "cell_type": "code", "execution_count": null, "id": "3b2dac94", "metadata": {}, "outputs": [], "source": [ "with gr.Blocks() as ui:\n", " with gr.Row():\n", " chatbot = gr.Chatbot(height=500, type=\"messages\")\n", " with gr.Row():\n", " entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n", " model = gr.Dropdown([\"OpenAI\", \"Gemini\", \"Ollama\"], label=\"Choose a model\")\n", " with gr.Row():\n", " clear = gr.Button(\"Clear\")\n", "\n", " def do_entry(message, history):\n", " history += [{\"role\":\"user\", \"content\":message}]\n", " return \"\", history\n", "\n", " entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(\n", " chat, inputs=[chatbot, model], outputs=[chatbot]\n", " )\n", " clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n", "\n", "ui.launch(inbrowser=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "d50b03d4", "metadata": {}, "outputs": [], "source": [ "cached_search" ] }, { "cell_type": "code", "execution_count": null, "id": "8a7f06bf", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "llms", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }