Merge pull request #1 from ed-donner/main

Update
This commit is contained in:
Abhinav M
2025-07-10 04:56:29 -05:00
committed by GitHub
30 changed files with 8718 additions and 0 deletions

@@ -0,0 +1,3 @@
<!-- Use this file to provide workspace-specific custom instructions to Copilot. For more details, visit https://code.visualstudio.com/docs/copilot/copilot-customization#_use-a-githubcopilotinstructionsmd-file -->
This is a Streamlit web application for clinical trial protocol summarization. Use Streamlit best practices for UI and Python for backend logic. Integrate with ClinicalTrials.gov v2 API for study search and OpenAI for summarization.

@@ -0,0 +1,30 @@
updates.md
.env
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
venv/
ENV/
.streamlit/
.idea/
.vscode/
*.swp
*.swo
.DS_Store

@@ -0,0 +1,66 @@
# Protocol Summarizer Webapp

A Streamlit web application for searching and summarizing clinical trial protocols from ClinicalTrials.gov using Large Language Models. This tool enables researchers and clinical professionals to quickly extract key information from clinical trial protocols.

## Features

- Search for clinical trials by keyword
- Display a list of studies with title and NCT number
- Select a study to summarize
- Fetch the protocol's brief summary from the ClinicalTrials.gov API
- Automatically summarize the protocol using OpenAI's LLM
- Extract structured information like study design, population, interventions, and endpoints

## Installation

1. Clone this repository:
   ```sh
   git clone https://github.com/albertoclemente/protocol_summarizer.git
   cd protocol_summarizer/protocol_summarizer_webapp
   ```
2. Install dependencies:
   ```sh
   pip install -r requirements.txt
   ```
3. Create a `.env` file in the project root with your OpenAI API key:
   ```
   OPENAI_API_KEY=your_api_key_here
   ```

## Usage

1. Run the Streamlit app:
   ```sh
   streamlit run app.py
   ```
2. In your browser:
   - Enter a disease, condition, or keyword in the search box
   - Select the number of results to display
   - Click the "Search" button
   - Select a study from the results
   - Click "Summarize Protocol" to generate a structured summary

## Technical Details

- Uses ClinicalTrials.gov API v2 to retrieve study information (see the sketch below)
- Implements fallback methods to handle API changes or failures
- Extracts protocol brief summaries using reliable JSON parsing
- Generates structured summaries using OpenAI's GPT models
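
For illustration, here is a minimal sketch of the kind of v2 API call the app makes to fetch a study's brief summary (the NCT ID below is a placeholder):

```python
import requests

# Fetch a single study from the ClinicalTrials.gov v2 API (placeholder NCT ID)
nct_id = "NCT01234567"
url = f"https://clinicaltrials.gov/api/v2/studies/{nct_id}?format=json"
data = requests.get(url).json()

# The brief summary lives under protocolSection -> descriptionModule
desc = data.get("protocolSection", {}).get("descriptionModule", {})
print(desc.get("briefSummary") or desc.get("detailedDescription", ""))
```
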
## Requirements

- Python 3.7+
- Streamlit
- Requests
- OpenAI Python library
- python-dotenv

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

MIT License

@@ -0,0 +1,121 @@
import os
from dotenv import load_dotenv
import streamlit as st
import requests
from openai import OpenAI

load_dotenv()

st.title("Protocol Summarizer")
st.markdown("""
Search for clinical trials by keyword, select a study, and generate a protocol summary using an LLM.
""")

# Search input
# Show results only after user presses Enter
with st.form(key="search_form"):
    query = st.text_input("Enter a disease, study title, or keyword:")
    max_results = st.slider("Number of results", 1, 20, 5)
    submitted = st.form_submit_button("Search")

@st.cache_data(show_spinner=False)
def search_clinical_trials(query, max_results=5):
    if not query:
        return []
    url = f"https://clinicaltrials.gov/api/v2/studies?query.term={query}&pageSize={max_results}&format=json"
    resp = requests.get(url)
    studies = []
    if resp.status_code == 200:
        data = resp.json()
        for study in data.get('studies', []):
            nct = study.get('protocolSection', {}).get('identificationModule', {}).get('nctId', 'N/A')
            title = study.get('protocolSection', {}).get('identificationModule', {}).get('officialTitle', 'N/A')
            studies.append({'nct': nct, 'title': title})
    return studies

results = search_clinical_trials(query, max_results) if query else []

if results:
    st.subheader("Search Results")
    for i, study in enumerate(results):
        st.markdown(f"**{i+1}. {study['title']}** (NCT: {study['nct']})")
    selected = st.number_input("Select study number to summarize", min_value=1, max_value=len(results), value=1)
    selected_study = results[selected-1]
    st.markdown(f"### Selected Study\n**{selected_study['title']}** (NCT: {selected_study['nct']})")
    if st.button("Summarize Protocol"):
        # Fetch the brief summary for the selected study
        nct_id = selected_study['nct']
        # Use the V2 API, which we know works reliably
        url = f"https://clinicaltrials.gov/api/v2/studies/{nct_id}?format=json"
        with st.spinner("Fetching study details..."):
            resp = requests.get(url)
        brief = ""
        if resp.status_code == 200:
            try:
                data = resp.json()
                # The V2 API has protocolSection at the root level
                if 'protocolSection' in data:
                    desc_mod = data.get('protocolSection', {}).get('descriptionModule', {})
                    brief = desc_mod.get('briefSummary', '')
                    # If briefSummary is empty, try detailedDescription
                    if not brief:
                        brief = desc_mod.get('detailedDescription', '')
            except Exception as e:
                st.error(f"Error parsing study data: {e}")
        # If the API fails, try HTML scraping as a fallback
        if not brief and resp.status_code != 200:
            st.warning(f"API returned status code {resp.status_code}. Trying alternative method...")
            html_url = f"https://clinicaltrials.gov/ct2/show/{nct_id}"
            html_resp = requests.get(html_url)
            if "Brief Summary:" in html_resp.text:
                start = html_resp.text.find("Brief Summary:") + len("Brief Summary:")
                excerpt = html_resp.text[start:start+1000]
                # Clean up HTML tags and collapse whitespace
                import re
                excerpt = re.sub('<[^<]+?>', ' ', excerpt)
                excerpt = re.sub(r'\s+', ' ', excerpt)
                brief = excerpt.strip()
        if not brief:
            st.error("No brief summary or detailed description found for this study.")
            st.stop()

        # Now we have the brief summary, send it to the LLM
        openai = OpenAI()

        def user_prompt_for_protocol_brief(brief_text):
            return (
                "Extract the following details from the clinical trial brief summary in markdown format with clear section headings (e.g., ## Study Design, ## Population, etc.):\n"
                "- Study design\n"
                "- Population\n"
                "- Interventions\n"
                "- Primary and secondary endpoints\n"
                "- Study duration\n\n"
                f"Brief summary text:\n{brief_text}"
            )

        system_prompt = "You are a clinical research assistant. Extract and list the requested protocol details in markdown format with clear section headings."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt_for_protocol_brief(brief)}
        ]

        with st.spinner("Summarizing with LLM..."):
            try:
                response = openai.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages
                )
                summary = response.choices[0].message.content
                st.markdown(summary)
            except Exception as e:
                st.error(f"LLM call failed: {e}")
else:
    if query:
        st.info("No results found. Try a different keyword.")

@@ -0,0 +1,4 @@
streamlit
openai
requests
python-dotenv

@@ -0,0 +1,273 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4e66a6eb-e44a-4dc3-bad7-82e27d45155d",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98bf393c-358e-4ee1-b15b-96dfec323734",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from bs4 import BeautifulSoup\n",
"from IPython.display import Markdown, display\n",
"from openai import OpenAI"
]
},
{
"cell_type": "markdown",
"id": "f92034ed-a2e6-444a-8008-291ba3f80561",
"metadata": {},
"source": [
"# OpenAI API Key"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a084b35d-19e9-4b48-bb06-d2c9e4474b20",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"# Check the key\n",
"\n",
"if not api_key:\n",
" print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not api_key.startswith(\"sk-proj-\"):\n",
" print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n",
"elif api_key.strip() != api_key:\n",
" print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
" print(\"API key found and looks good so far!\")"
]
},
{
"cell_type": "markdown",
"id": "32b35ea0-e4ca-492a-94af-822ec61468a0",
"metadata": {},
"source": [
"# About..."
]
},
{
"cell_type": "markdown",
"id": "c660b786-af88-4134-b958-ffbf7a7b2904",
"metadata": {},
"source": [
"In this project I use the code from day 1 for something I do at work. I'm a real estate appraiser and when I prepare a valuation for some real estate, I analyze the local market, and in particular the city where the property is located. I then gather economy-related information and create a report from it. I'm based in Poland, so the report is in Polish. Here, I want to ask the model to make such a report for me, using the official website of the city and its related Wikipedia article."
]
},
{
"cell_type": "markdown",
"id": "09f32b5a-4d0a-4fec-a2f8-5d323ca2745d",
"metadata": {},
"source": [
"# The Code"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0fb8fe1-f052-4426-8531-5520d5295807",
"metadata": {},
"outputs": [],
"source": [
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a2cca4b-8cd0-4c1a-a01c-1da10199236c",
"metadata": {},
"outputs": [],
"source": [
"headers = {\n",
" \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36\"\n",
"}\n",
"\n",
"class Website:\n",
"\n",
" def __init__(self, url):\n",
" \"\"\"\n",
" Create this Website object from the given url using the BeautifulSoup library\n",
" \"\"\"\n",
" self.url = url\n",
" response = requests.get(url, headers=headers)\n",
" soup = BeautifulSoup(response.content, 'html.parser')\n",
" self.title = soup.title.string if soup.title else \"No title found\"\n",
" for irrelevant in soup.body([\"script\", \"style\", \"img\", \"input\"]):\n",
" irrelevant.decompose()\n",
" self.text = soup.body.get_text(separator=\"\\n\", strip=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c73e91c8-5805-4c9f-9bbb-b4e9c1e7bf12",
"metadata": {},
"outputs": [],
"source": [
"system_prompt = \"\"\"You are an analyst and real estate appraiser who checks out the official websites \n",
"of cities as well as articles related to these cities on Wikipedia, searching the particular pages \n",
"of the official website and the Wikipedia article for economic data, in particular the \n",
"demographic structure of the city, its area, and how it's subdivided into built-up area, \n",
"rural area, forests, and so on, provided this kind of information is available. \n",
"The most important information you want to find is that related to the real estate market in the city, \n",
"but also the general economy of the city, so what kind of factories or companies there are, commerce, \n",
"business conditions, transportation, economic growth in recent years, and recent investments. \n",
"wealth of the inhabitants, and so on, depending on what kind of information is available on the website. \n",
"Combine the information found on the official website with the information found on Wikipedia, and in case\n",
"of discrepancies, the official website should take precedence. If any of the information is missing,\n",
"just omit it entirely and don't mention that it is missing, just don't write about it at all.\n",
"When you gather all the required information, create a comprehensive report presenting \n",
"the data in a clear way, using markdown, in tabular form where it makes sense. \n",
"The length of the report should be about 5000 characters. And one more thing, the report should be entirely \n",
"in Polish. \"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8015e8d-1655-4477-a111-aa8dd584f5eb",
"metadata": {},
"outputs": [],
"source": [
"def user_prompt_for(city, city_website, wiki_website):\n",
" user_prompt = f\"You are looking at the official website of the city {city}, and its wiki article.\"\n",
" user_prompt += f\"\\nThe contents of this website is as follows: \\\n",
"please provide a comprehensive report of economy-related data for the city of {city}, available on the \\\n",
"particular pages and subpages of its official website and Wikipedia in markdown. \\\n",
"Add tables if it makes sense for the data. The length of the report should be about 5000 characters. \\\n",
"The report should be in Polish.\\n\\n\"\n",
" user_prompt += city_website.text\n",
" user_prompt += wiki_website.text\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b55bd66b-e997-4d64-b5d5-679098013b9f",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(city, city_website, wiki_website):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(city, city_website, wiki_website)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f1f218-d6a9-4a9e-be7e-b4f41e7647e5",
"metadata": {},
"outputs": [],
"source": [
"def report(url_official, url_wiki, city):\n",
" city_website = Website(url_official)\n",
" wiki_website = Website(url_wiki)\n",
" response = openai.chat.completions.create(\n",
" model = \"gpt-4o-mini\",\n",
" messages = messages_for(city, city_website, wiki_website)\n",
" )\n",
" return response.choices[0].message.content"
]
},
{
"cell_type": "markdown",
"id": "08b47ec7-d00f-44e4-bbe2-580c8efd88e5",
"metadata": {},
"source": [
"# Raw Result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "830f0746-08a7-43ae-bd40-78d4a4c5d3e5",
"metadata": {},
"outputs": [],
"source": [
"report(\"https://www.rudaslaska.pl/\", \"https://pl.wikipedia.org/wiki/Ruda_%C5%9Al%C4%85ska\", \"Ruda Śląska\")"
]
},
{
"cell_type": "markdown",
"id": "a3630ac4-c103-4b84-a1a2-c246a702346e",
"metadata": {},
"source": [
"# Polished Result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b89dd543-998d-4466-abd8-cc785118d3e4",
"metadata": {},
"outputs": [],
"source": [
"def display_report(url_official, url_wiki, city):\n",
" rep = report(url_official, url_wiki, city)\n",
" display(Markdown(rep))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "157926f3-ba67-4d4b-abbb-24a2dcd85a8b",
"metadata": {},
"outputs": [],
"source": [
"display_report(\"https://www.rudaslaska.pl/\", \"https://pl.wikipedia.org/wiki/Ruda_%C5%9Al%C4%85ska\", \"Ruda Śląska\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "727d2283-e74c-4e74-86f2-759b08f1427a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,271 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "032a76d2-a112-4c49-bd32-fe6c87f6ec19",
"metadata": {},
"source": [
"## Dota Game Assistant\n",
"\n",
"This script retrieves and summarizes information about a specified hero from `dotabuff.com` website"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04b24159-55d1-4eaf-bc19-474cec71cc3b",
"metadata": {},
"outputs": [],
"source": [
"!pip install selenium\n",
"!pip install webdriver-manager"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14d26510-6613-4c1a-a346-159d906d111c",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from bs4 import BeautifulSoup\n",
"from IPython.display import Markdown, display\n",
"from openai import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9c8ea1e-8881-4f50-953d-ca7f462d8a32",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"# Check the key\n",
"\n",
"if not api_key:\n",
" print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not api_key.startswith(\"sk-proj-\"):\n",
" print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n",
"elif api_key.strip() != api_key:\n",
" print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
" print(\"API key found and looks good so far!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02febcac-9a21-4322-b2ea-748972312165",
"metadata": {},
"outputs": [],
"source": [
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb7dd822-962e-4b34-a743-c14809764e4a",
"metadata": {},
"outputs": [],
"source": [
"# A class to represent a Webpage\n",
"\n",
"# Some websites need you to use proper headers when fetching them:\n",
"headers = {\n",
" \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36\"\n",
"}\n",
"\n",
"from selenium import webdriver\n",
"from selenium.webdriver.chrome.service import Service\n",
"from selenium.webdriver.chrome.options import Options\n",
"from selenium.webdriver.common.by import By\n",
"from selenium.webdriver.support.ui import WebDriverWait\n",
"from selenium.webdriver.support import expected_conditions as EC\n",
"from webdriver_manager.chrome import ChromeDriverManager\n",
"from bs4 import BeautifulSoup\n",
"\n",
"class Website:\n",
" def __init__(self, url, wait_time=10):\n",
" \"\"\"\n",
" Create this Website object from the given URL using Selenium and BeautifulSoup.\n",
" Uses headless Chrome to load JavaScript content.\n",
" \"\"\"\n",
" self.url = url\n",
"\n",
" # Configure headless Chrome\n",
" options = Options()\n",
" options.headless = True\n",
" options.add_argument(\"--disable-gpu\")\n",
" options.add_argument(\"--no-sandbox\")\n",
"\n",
" # Start the driver\n",
" service = Service(ChromeDriverManager().install())\n",
" driver = webdriver.Chrome(service=service, options=options)\n",
"\n",
" try:\n",
" driver.get(url)\n",
"\n",
" # Wait until body is loaded (you can tweak the wait condition)\n",
" WebDriverWait(driver, wait_time).until(\n",
" EC.presence_of_element_located((By.TAG_NAME, \"body\"))\n",
" )\n",
"\n",
" html = driver.page_source\n",
" soup = BeautifulSoup(html, \"html.parser\")\n",
"\n",
" self.title = soup.title.string.strip() if soup.title else \"No title found\"\n",
"\n",
" # Remove unwanted tags\n",
" for irrelevant in soup.body([\"script\", \"style\", \"img\", \"input\"]):\n",
" irrelevant.decompose()\n",
"\n",
" self.text = soup.body.get_text(separator=\"\\n\", strip=True)\n",
"\n",
" finally:\n",
" driver.quit()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d833fbb-0115-4d99-a4e9-464f27900eab",
"metadata": {},
"outputs": [],
"source": [
"class DotaWebsite:\n",
" def __init__(self, hero):\n",
" web = Website(\"https://www.dotabuff.com/heroes\" + \"/\" + hero)\n",
" self.title = web.title\n",
" self.text = web.text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0a42c2b-c837-4d1b-b8f8-b2dbb8592a1a",
"metadata": {},
"outputs": [],
"source": [
"system_prompt = \"You are an game assistant that analyzes the contents of a website \\\n",
"and provides a short summary about facet selection, ability building, item building, best versus and worst versus, ignoring text that might be navigation related. \\\n",
"Respond in markdown.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c05843d-6373-4a76-8cca-9c716a6ca13a",
"metadata": {},
"outputs": [],
"source": [
"# A function that writes a User Prompt that asks for summaries of websites:\n",
"\n",
"def user_prompt_for(website):\n",
" user_prompt = f\"You are looking at a website titled {website.title}\"\n",
" user_prompt += \"\\nThe contents of this website is as follows; \\\n",
"please provide a short summary of provides a short summary about facet selection, ability building, item building, best versus and worst versus in markdown. \\\n",
"If it includes news or announcements, then summarize these too.\\n\\n\"\n",
" user_prompt += website.text\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0145eee1-39e2-4f00-89ec-7acc6e375972",
"metadata": {},
"outputs": [],
"source": [
"# See how this function creates exactly the format above\n",
"\n",
"def messages_for(website):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(website)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f389c0-572a-476b-9b4e-719c0ef10abb",
"metadata": {},
"outputs": [],
"source": [
"# And now: call the OpenAI API. You will get very familiar with this!\n",
"\n",
"def summarize(hero):\n",
" website = DotaWebsite(hero)\n",
" response = openai.chat.completions.create(\n",
" model = \"gpt-4o-mini\",\n",
" messages = messages_for(website)\n",
" )\n",
" return response.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcb046b7-52a9-49ff-b7bc-d8f6c279df4c",
"metadata": {},
"outputs": [],
"source": [
"# A function to display this nicely in the Jupyter output, using markdown\n",
"\n",
"def display_summary(hero):\n",
" summary = summarize(hero)\n",
" display(Markdown(summary))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9befb685-2912-41a9-b2d9-ae33001494c0",
"metadata": {},
"outputs": [],
"source": [
"display_summary(\"axe\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf1bb1d9-0351-44fc-8ebf-91aa47a81b42",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,159 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "922bb144",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from bs4 import BeautifulSoup\n",
"from IPython.display import Markdown, display\n",
"from openai import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "870bdcd9",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"load_dotenv(override=True)\n",
"api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"# Check the key\n",
"if not api_key:\n",
" print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not api_key.startswith(\"sk-proj-\"):\n",
" print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n",
"elif api_key.strip() != api_key:\n",
" print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
" print(\"API key found and looks good so far!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6146102",
"metadata": {},
"outputs": [],
"source": [
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f75573f",
"metadata": {},
"outputs": [],
"source": [
"class FinvizWebsite():\n",
" \"\"\"\n",
" Create this Website object from the given url using the BeautifulSoup library\n",
" \"\"\"\n",
" \n",
" def __init__(self, ticker):\n",
" self.ticker = ticker.upper()\n",
" self.url = f\"https://finviz.com/quote.ashx?t={self.ticker}&p=d&ty=ea\"\n",
" self.headers = {\n",
" \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36\"\n",
" }\n",
" response = requests.get(self.url, headers=self.headers)\n",
" soup = BeautifulSoup(response.content, \"html.parser\")\n",
" self.title = soup.title.string if soup.title else \"No title found\"\n",
" self.table = soup.find(\"table\", class_=\"snapshot-table2\") "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42c7ced6",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(website):\n",
" system_prompt = \"\"\"\n",
" You are a financial analysis assistant that analyzes the contents of HTML formated table.\n",
" and provides a summary of the stock's analysis with clear and professional language appropriate for financial research \n",
" with bulleted important list of **pros** and **cons** , ignoring text that might be navigation related. Repond in markdown.\n",
" \"\"\"\n",
" \n",
" user_prompt = f\"\"\"\n",
" You are looking at a website titled {website.title}.\\n\n",
" The contents of this website is as follows; please provide a summary of the stock's analysis from this website in markdown.\\n\\n\n",
" {website.table}\n",
" \"\"\"\n",
" \n",
" return [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bfaa6da",
"metadata": {},
"outputs": [],
"source": [
"def display_summary(ticker):\n",
" website = FinvizWebsite(ticker)\n",
" response = openai.chat.completions.create(\n",
" model = \"gpt-4o-mini\",\n",
" messages = messages_for(website)\n",
" )\n",
" summary = response.choices[0].message.content\n",
" display(Markdown(summary))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eeeff6f7",
"metadata": {},
"outputs": [],
"source": [
"display_summary(\"aapl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5aed2001",
"metadata": {},
"outputs": [],
"source": [
"display_summary(\"tsla\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "72a6552c-c837-4ced-b7c8-75a3d4cf777d",
"metadata": {},
"source": [
" <h2 style=\"color:#900;\">MAIL SUBJECT CREATION -</h2>\n",
"\n",
"<table style=\"margin: 0; text-align: left;\">\n",
" <tr>\n",
" <td style=\"width: 150px; height: 150px; vertical-align: middle;\">\n",
" <img src=\"../../important.jpg\" width=\"150\" height=\"150\" style=\"display: block;\" />\n",
" </td>\n",
" <td>\n",
" <h3 style=\"color:#900;\">Write something that will take the contents of an email, and will suggest an appropriate short subject line for the email. That's the kind of feature that might be built into a commercial email tool.</h3>\n",
" </td>\n",
" </tr>\n",
"</table>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "76822a8b-d6e0-4dd9-a801-2d34bd104b7d",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1a9de873-d24b-42fb-8f4a-a08f429050f5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"API key found and looks good so far!\n"
]
}
],
"source": [
"load_dotenv(override=True)\n",
"api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"# Check the key\n",
"\n",
"if not api_key:\n",
" print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not api_key.startswith(\"sk-proj-\"):\n",
" print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n",
"elif api_key.strip() != api_key:\n",
" print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
" print(\"API key found and looks good so far!\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "122af5d6-4727-4229-b85a-ea5246ff540c",
"metadata": {},
"outputs": [],
"source": [
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b9a2c2c2-ac10-4019-aeef-2bfe6cc7b1f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Subject: Missing API Logs for June 22nd: Scheduled Meeting to Address Issue\n"
]
}
],
"source": [
"system_prompt = \"You are an assistant which can generate a subject line as output by taking email of content as input. Subject line should be self explanatrory\"\n",
"user_prompt = \"\"\"\n",
" Below is the content of the text which I am giving as input\n",
" Mail Content - 'Hi Team,\n",
"\n",
"We have observed that the API logs for June 22nd between 6:00 AM and 12:00 PM are missing in Kibana.\n",
"\n",
"The SA team has confirmed that there were no errors reported on their end during this period.\n",
"\n",
"The DevOps team has verified that logs were being sent as expected.\n",
"\n",
"Upon checking the Fluentd pods, no errors were found.\n",
"\n",
"Logs were being shipped to td-agent as usual.\n",
"\n",
"No configuration changes or pod restarts were detected.\n",
"\n",
"We have also confirmed that no code changes were deployed from our side during this time.\n",
"\n",
"Bucket: api_application_log\n",
"Ticket\n",
"\n",
"We have scheduled a meeting with the SA and DevOps teams to restore the missing logs, as they are critical for our weekly report and analysis.'\n",
"\"\"\"\n",
"\n",
"# Step 2: Make the messages list\n",
"\n",
"messages = [ {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt}] # fill this in\n",
"\n",
"# Step 3: Call OpenAI\n",
"\n",
"response = openai.chat.completions.create(\n",
" model = \"gpt-4o-mini\",\n",
" messages = messages\n",
" )\n",
"\n",
"# Step 4: print the result\n",
"\n",
"print(response.choices[0].message.content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,459 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d15d8294-3328-4e07-ad16-8a03e9bbfdb9",
"metadata": {},
"source": [
"# Welcome to your first assignment!\n",
"\n",
"Instructions are below. Please give this a try, and look in the solutions folder if you get stuck (or feel free to ask me!)"
]
},
{
"cell_type": "markdown",
"id": "ada885d9-4d42-4d9b-97f0-74fbbbfe93a9",
"metadata": {},
"source": [
"<table style=\"margin: 0; text-align: left;\">\n",
" <tr>\n",
" <td style=\"width: 150px; height: 150px; vertical-align: middle;\">\n",
" <img src=\"../resources.jpg\" width=\"150\" height=\"150\" style=\"display: block;\" />\n",
" </td>\n",
" <td>\n",
" <h2 style=\"color:#f71;\">Just before we get to the assignment --</h2>\n",
" <span style=\"color:#f71;\">I thought I'd take a second to point you at this page of useful resources for the course. This includes links to all the slides.<br/>\n",
" <a href=\"https://edwarddonner.com/2024/11/13/llm-engineering-resources/\">https://edwarddonner.com/2024/11/13/llm-engineering-resources/</a><br/>\n",
" Please keep this bookmarked, and I'll continue to add more useful links there over time.\n",
" </span>\n",
" </td>\n",
" </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "6e9fa1fc-eac5-4d1d-9be4-541b3f2b3458",
"metadata": {},
"source": [
"# HOMEWORK EXERCISE ASSIGNMENT\n",
"\n",
"Upgrade the day 1 project to summarize a webpage to use an Open Source model running locally via Ollama rather than OpenAI\n",
"\n",
"You'll be able to use this technique for all subsequent projects if you'd prefer not to use paid APIs.\n",
"\n",
"**Benefits:**\n",
"1. No API charges - open-source\n",
"2. Data doesn't leave your box\n",
"\n",
"**Disadvantages:**\n",
"1. Significantly less power than Frontier Model\n",
"\n",
"## Recap on installation of Ollama\n",
"\n",
"Simply visit [ollama.com](https://ollama.com) and install!\n",
"\n",
"Once complete, the ollama server should already be running locally. \n",
"If you visit: \n",
"[http://localhost:11434/](http://localhost:11434/)\n",
"\n",
"You should see the message `Ollama is running`. \n",
"\n",
"If not, bring up a new Terminal (Mac) or Powershell (Windows) and enter `ollama serve` \n",
"And in another Terminal (Mac) or Powershell (Windows), enter `ollama pull llama3.2` \n",
"Then try [http://localhost:11434/](http://localhost:11434/) again.\n",
"\n",
"If Ollama is slow on your machine, try using `llama3.2:1b` as an alternative. Run `ollama pull llama3.2:1b` from a Terminal or Powershell, and change the code below from `MODEL = \"llama3.2\"` to `MODEL = \"llama3.2:1b\"`"
]
},
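{
"cell_type": "code",
"execution_count": null,
"id": "8c1d2e3f-4a5b-4c6d-8e7f-9a0b1c2d3e4f",
"metadata": {},
"outputs": [],
"source": [
"# A quick sanity check (a sketch, not part of the original lab): confirm the local\n",
"# Ollama server is reachable before running the calls below\n",
"\n",
"import requests\n",
"\n",
"print(requests.get(\"http://localhost:11434/\").text)  # expect: Ollama is running"
]
},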
{
"cell_type": "code",
"execution_count": null,
"id": "4e2a9393-7767-488e-a8bf-27c12dca35bd",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import requests\n",
"from bs4 import BeautifulSoup\n",
"from IPython.display import Markdown, display"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29ddd15d-a3c5-4f4e-a678-873f56162724",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"OLLAMA_API = \"http://localhost:11434/api/chat\"\n",
"HEADERS = {\"Content-Type\": \"application/json\"}\n",
"MODEL = \"llama3.2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dac0a679-599c-441f-9bf2-ddc73d35b940",
"metadata": {},
"outputs": [],
"source": [
"# Create a messages list using the same format that we used for OpenAI\n",
"\n",
"messages = [\n",
" {\"role\": \"user\", \"content\": \"Describe some of the business applications of Generative AI\"}\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bb9c624-14f0-4945-a719-8ddb64f66f47",
"metadata": {},
"outputs": [],
"source": [
"payload = {\n",
" \"model\": MODEL,\n",
" \"messages\": messages,\n",
" \"stream\": False\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "479ff514-e8bd-4985-a572-2ea28bb4fa40",
"metadata": {},
"outputs": [],
"source": [
"# Let's just make sure the model is loaded\n",
"\n",
"!ollama pull llama3.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42b9f644-522d-4e05-a691-56e7658c0ea9",
"metadata": {},
"outputs": [],
"source": [
"# If this doesn't work for any reason, try the 2 versions in the following cells\n",
"# And double check the instructions in the 'Recap on installation of Ollama' at the top of this lab\n",
"# And if none of that works - contact me!\n",
"\n",
"response = requests.post(OLLAMA_API, json=payload, headers=HEADERS)\n",
"print(response.json()['message']['content'])"
]
},
{
"cell_type": "markdown",
"id": "6a021f13-d6a1-4b96-8e18-4eae49d876fe",
"metadata": {},
"source": [
"# Introducing the ollama package\n",
"\n",
"And now we'll do the same thing, but using the elegant ollama python package instead of a direct HTTP call.\n",
"\n",
"Under the hood, it's making the same call as above to the ollama server running at localhost:11434"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7745b9c4-57dc-4867-9180-61fa5db55eb8",
"metadata": {},
"outputs": [],
"source": [
"import ollama\n",
"\n",
"response = ollama.chat(model=MODEL, messages=messages)\n",
"print(response['message']['content'])"
]
},
{
"cell_type": "markdown",
"id": "a4704e10-f5fb-4c15-a935-f046c06fb13d",
"metadata": {},
"source": [
"## Alternative approach - using OpenAI python library to connect to Ollama"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23057e00-b6fc-4678-93a9-6b31cb704bff",
"metadata": {},
"outputs": [],
"source": [
"# There's actually an alternative approach that some people might prefer\n",
"# You can use the OpenAI client python library to call Ollama:\n",
"\n",
"from openai import OpenAI\n",
"ollama_via_openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n",
"\n",
"response = ollama_via_openai.chat.completions.create(\n",
" model=MODEL,\n",
" messages=messages\n",
")\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"id": "9f9e22da-b891-41f6-9ac9-bd0c0a5f4f44",
"metadata": {},
"source": [
"## Are you confused about why that works?\n",
"\n",
"It seems strange, right? We just used OpenAI code to call Ollama?? What's going on?!\n",
"\n",
"Here's the scoop:\n",
"\n",
"The python class `OpenAI` is simply code written by OpenAI engineers that makes calls over the internet to an endpoint. \n",
"\n",
"When you call `openai.chat.completions.create()`, this python code just makes a web request to the following url: \"https://api.openai.com/v1/chat/completions\"\n",
"\n",
"Code like this is known as a \"client library\" - it's just wrapper code that runs on your machine to make web requests. The actual power of GPT is running on OpenAI's cloud behind this API, not on your computer!\n",
"\n",
"OpenAI was so popular, that lots of other AI providers provided identical web endpoints, so you could use the same approach.\n",
"\n",
"So Ollama has an endpoint running on your local box at http://localhost:11434/v1/chat/completions \n",
"And in week 2 we'll discover that lots of other providers do this too, including Gemini and DeepSeek.\n",
"\n",
"And then the team at OpenAI had a great idea: they can extend their client library so you can specify a different 'base url', and use their library to call any compatible API.\n",
"\n",
"That's it!\n",
"\n",
"So when you say: `ollama_via_openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')` \n",
"Then this will make the same endpoint calls, but to Ollama instead of OpenAI."
]
},
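{
"cell_type": "markdown",
"id": "2f8e4a1c-6b3d-4e9f-8a7c-1d2e3f4a5b6c",
"metadata": {},
"source": [
"To make this concrete, here is a sketch of the same trick pointed at another provider's OpenAI-compatible endpoint - Gemini's, in this case. The base_url is Google's documented compatibility endpoint; the model name and environment variable below are assumptions, so check your own setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a6b7c8d-9e0f-4a1b-8c2d-3e4f5a6b7c8d",
"metadata": {},
"outputs": [],
"source": [
"# A sketch, not part of the original exercise: the same OpenAI client, different base_url\n",
"# Assumes GOOGLE_API_KEY is set in your environment; the model name is an example\n",
"\n",
"import os\n",
"from openai import OpenAI\n",
"\n",
"gemini_via_openai = OpenAI(\n",
"    base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n",
"    api_key=os.getenv(\"GOOGLE_API_KEY\")\n",
")\n",
"\n",
"# response = gemini_via_openai.chat.completions.create(model=\"gemini-2.0-flash\", messages=messages)\n",
"# print(response.choices[0].message.content)"
]
},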
{
"cell_type": "markdown",
"id": "bc7d1de3-e2ac-46ff-a302-3b4ba38c4c90",
"metadata": {},
"source": [
"## Also trying the amazing reasoning model DeepSeek\n",
"\n",
"Here we use the version of DeepSeek-reasoner that's been distilled to 1.5B. \n",
"This is actually a 1.5B variant of Qwen that has been fine-tuned using synethic data generated by Deepseek R1.\n",
"\n",
"Other sizes of DeepSeek are [here](https://ollama.com/library/deepseek-r1) all the way up to the full 671B parameter version, which would use up 404GB of your drive and is far too large for most!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf9eb44e-fe5b-47aa-b719-0bb63669ab3d",
"metadata": {},
"outputs": [],
"source": [
"!ollama pull deepseek-r1:1.5b"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d3d554b-e00d-4c08-9300-45e073950a76",
"metadata": {},
"outputs": [],
"source": [
"# This may take a few minutes to run! You should then see a fascinating \"thinking\" trace inside <think> tags, followed by some decent definitions\n",
"\n",
"response = ollama_via_openai.chat.completions.create(\n",
" model=\"deepseek-r1:1.5b\",\n",
" messages=[{\"role\": \"user\", \"content\": \"Please give definitions of some core concepts behind LLMs: a neural network, attention and the transformer\"}]\n",
")\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"id": "1622d9bb-5c68-4d4e-9ca4-b492c751f898",
"metadata": {},
"source": [
"# NOW the exercise for you\n",
"\n",
"Take the code from day1 and incorporate it here, to build a website summarizer that uses Llama 3.2 running locally instead of OpenAI; use either of the above approaches."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6de38216-6d1c-48c4-877b-86d403f4e0f8",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from bs4 import BeautifulSoup\n",
"from IPython.display import Markdown, display"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bd2aea1-d7d7-499f-b704-5b13e2ddd23f",
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"llama3.2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6df3141a-0a46-4ff9-ae73-bf8bee2aa3d8",
"metadata": {},
"outputs": [],
"source": [
"# A class to represent a Webpage\n",
"\n",
"class Website:\n",
" \"\"\"\n",
" A utility class to represent a Website that we have scraped\n",
" \"\"\"\n",
" url: str\n",
" title: str\n",
" text: str\n",
"\n",
" def __init__(self, url):\n",
" \"\"\"\n",
" Create this Website object from the given url using the BeautifulSoup library\n",
" \"\"\"\n",
" self.url = url\n",
" response = requests.get(url)\n",
" soup = BeautifulSoup(response.content, 'html.parser')\n",
" self.title = soup.title.string if soup.title else \"No title found\"\n",
" for irrelevant in soup.body([\"script\", \"style\", \"img\", \"input\"]):\n",
" irrelevant.decompose()\n",
" self.text = soup.body.get_text(separator=\"\\n\", strip=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df2ea48b-7343-47be-bdcb-52b63a4de43e",
"metadata": {},
"outputs": [],
"source": [
"# Define our system prompt - you can experiment with this later, changing the last sentence to 'Respond in markdown in Spanish.\"\n",
"\n",
"system_prompt = \"You are an assistant that analyzes the contents of a website \\\n",
"and provides a short summary, ignoring text that might be navigation related. \\\n",
"Respond in markdown.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80f1a534-ae2a-4283-83cf-5e7c5765c736",
"metadata": {},
"outputs": [],
"source": [
"# A function that writes a User Prompt that asks for summaries of websites:\n",
"\n",
"def user_prompt_for(website):\n",
" user_prompt = f\"You are looking at a website titled {website.title}\"\n",
" user_prompt += \"The contents of this website is as follows; \\\n",
"please provide a short summary of this website in markdown. \\\n",
"If it includes news or announcements, then summarize these too.\\n\\n\"\n",
" user_prompt += website.text\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5dfe658d-e3f9-4b32-90e6-1a523f47f836",
"metadata": {},
"outputs": [],
"source": [
"# See how this function creates exactly the format above\n",
"\n",
"def messages_for(website):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(website)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e2a09d0-bc47-490e-b085-fe3ccfbd16ad",
"metadata": {},
"outputs": [],
"source": [
"# And now: call the Ollama function instead of OpenAI\n",
"\n",
"def summarize(url):\n",
" website = Website(url)\n",
" messages = messages_for(website)\n",
" response = ollama.chat(model=MODEL, messages=messages)\n",
" return response['message']['content']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "340e08a2-86f0-4cdd-9188-da2972cae7a6",
"metadata": {},
"outputs": [],
"source": [
"# A function to display this nicely in the Jupyter output, using markdown\n",
"\n",
"def display_summary(url):\n",
" summary = summarize(url)\n",
" display(Markdown(summary))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55e4790a-013c-40cf-9dff-bb5ec1d53964",
"metadata": {},
"outputs": [],
"source": [
"display_summary(\"https://zhufqiu.com\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a96cbad-1306-4ce1-a942-2448f50d6751",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,84 @@
"""
Project: Web Content Summarizer using Ollama's llama3.2 model
- Developed a Python tool to extract and summarize website content using Ollama's llama3.2 model and BeautifulSoup.
- Implemented secure API integration and HTTP requests with custom headers to mimic browser behavior.
"""
import os
import requests
from bs4 import BeautifulSoup
import ollama
# Constants
OLLAMA_API = "http://localhost:11434/api/chat"
HEADERS = {"Content-Type": "application/json"}
MODEL = "llama3.2"
# Define the Website class to fetch and parse website content
class Website:
def __init__(self, url):
"""
Initialize a Website object by fetching and parsing the given URL.
Uses BeautifulSoup to extract the title and text content of the page.
"""
self.url = url
response = requests.get(url, headers=HEADERS)
soup = BeautifulSoup(response.content, 'html.parser')
# Extract the title of the website
self.title = soup.title.string if soup.title else "No title found"
# Remove irrelevant elements like scripts, styles, images, and inputs
for irrelevant in soup.body(["script", "style", "img", "input"]):
irrelevant.decompose()
# Extract the main text content of the website
self.text = soup.body.get_text(separator="\n", strip=True)
# Define the system prompt for the OpenAI model
system_prompt = (
"You are an assistant that analyzes the contents of a website "
"and provides a short summary, ignoring text that might be navigation related. "
"Respond in markdown."
)
# Function to generate the user prompt based on the website content
def user_prompt_for(website):
"""
Generate a user prompt for the llama3.2 model based on the website's title and content.
"""
user_prompt = f"You are looking at a website titled {website.title}"
user_prompt += "\nThe contents of this website is as follows; summarize these.\n\n"
user_prompt += website.text
return user_prompt
# Function to create the messages list for the OpenAI API
def messages_for(website):
"""
Create a list of messages for the ollama, including the system and user prompts.
"""
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt_for(website)}
]
# Function to summarize the content of a given URL
def summarize(url):
"""
Summarize the content of the given URL using the OpenAI API.
"""
# Create a Website object to fetch and parse the URL
website = Website(url)
# Call the llama3.2 using ollama with the generated messages
response = ollama.chat(
model= MODEL,
messages=messages_for(website)
)
# Return the summary generated by ollama
print(response.message.content)
# Example usage: Summarize the content of a specific URL
summarize("https://sruthianem.com")

@@ -0,0 +1,202 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5",
"metadata": {},
"source": [
"# End of week 1 exercise\n",
"\n",
"To demonstrate your familiarity with OpenAI API, and also Ollama, build a tool that takes a technical question, \n",
"and responds with an explanation. This is a tool that you will be able to use yourself during the course!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1070317-3ed9-4659-abe3-828943230e03",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from IPython.display import Markdown, display, update_display\n",
"from openai import OpenAI\n",
"import ollama"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
"metadata": {},
"outputs": [],
"source": [
"# constants\n",
"\n",
"MODEL_GPT = 'gpt-4o-mini'\n",
"MODEL_LLAMA = 'llama3.2'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
"metadata": {},
"outputs": [],
"source": [
"# set up environment\n",
"\n",
"load_dotenv(override=True)\n",
"api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"# Check the key\n",
"\n",
"if not api_key:\n",
" print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not api_key.startswith(\"sk-proj-\"):\n",
" print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n",
"elif api_key.strip() != api_key:\n",
" print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
" print(\"API key found and looks good so far!\")\n",
"\n",
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f0d0137-52b0-47a8-81a8-11a90a010798",
"metadata": {},
"outputs": [],
"source": [
"# here is the question; type over this to ask something new\n",
"\n",
"question = \"\"\"\n",
"Please explain what this code does and why:\n",
"yield from {book.get(\"author\") for book in books if book.get(\"author\")}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f879b7e-5ecc-4ec6-b269-78b6e2ed3480",
"metadata": {},
"outputs": [],
"source": [
"# prompts\n",
"\n",
"system_prompt = \"You are a helpful tutor who answers technical questions about programming code(especially python code), software engineering, data science and LLMs\"\n",
"user_prompt = \"Please give a detailed explanation to the following question: \" + question"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ac74ae5-af61-4a5d-b991-554fa67cd3d1",
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60ce7000-a4a5-4cce-a261-e75ef45063b4",
"metadata": {},
"outputs": [],
"source": [
"# Get gpt-4o-mini to answer, with streaming\n",
"stream = openai.chat.completions.create(\n",
" model=MODEL_GPT,\n",
" messages=messages,\n",
" stream=True\n",
" )\n",
" \n",
"response = \"\"\n",
"display_handle = display(Markdown(\"\"), display_id=True)\n",
"for chunk in stream:\n",
" response += chunk.choices[0].delta.content or ''\n",
" response = response.replace(\"```\",\"\").replace(\"markdown\", \"\")\n",
" update_display(Markdown(response), display_id=display_handle.display_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f7c8ea8-4082-4ad0-8751-3301adcf6538",
"metadata": {},
"outputs": [],
"source": [
"# Get Llama 3.2 to answer\n",
"\n",
"OLLAMA_API = \"http://localhost:11434/api/chat\"\n",
"HEADERS = {\"Content-Type\": \"application/json\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4bd10d96-ee72-4c86-acd8-4fa417c25960",
"metadata": {},
"outputs": [],
"source": [
"!ollama pull llama3.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d889d514-0478-4d7f-aabf-9a7bc743adb1",
"metadata": {},
"outputs": [],
"source": [
"stream = ollama.chat(model=MODEL_LLAMA, messages=messages, stream=True)\n",
"\n",
"response = \"\"\n",
"display_handle = display(Markdown(\"\"), display_id=True)\n",
"for chunk in stream:\n",
" response += chunk.get(\"message\", {}).get(\"content\", \"\")\n",
" response = response.replace(\"```\",\"\").replace(\"markdown\", \"\")\n",
" update_display(Markdown(response), display_id=display_handle.display_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "452d442a-f3b0-42ad-89d2-a8dc664e8bb6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,143 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd",
"metadata": {},
"source": [
"# Additional End of week Exercise - week 2\n",
"\n",
"Now use everything you've learned from Week 2 to build a full prototype for the technical question/answerer you built in Week 1 Exercise.\n",
"\n",
"This should include a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models. Bonus points if you can demonstrate use of a tool!\n",
"\n",
"If you feel bold, see if you can add audio input so you can talk to it, and have it respond with audio. ChatGPT or Claude can help you, or email me if you have questions.\n",
"\n",
"I will publish a full solution here soon - unless someone beats me to it...\n",
"\n",
"There are so many commercial applications for this, from a language tutor, to a company onboarding solution, to a companion AI to a course (like this one!) I can't wait to see your results."
]
},
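{
"cell_type": "markdown",
"id": "4e5f6a7b-8c9d-4e0f-9a1b-2c3d4e5f6a7b",
"metadata": {},
"source": [
"Before the audio experiment below, here is a minimal sketch of the core requirements - a Gradio UI with streaming and the ability to switch between models. It is not the official solution: it assumes an OpenAI key in your environment and a local Ollama server, and the model names are just examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a7b8c9d-0e1f-4a2b-9c3d-4e5f6a7b8c9d",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch (assumptions: OPENAI_API_KEY is set, Ollama runs locally, model names are examples)\n",
"\n",
"import gradio as gr\n",
"from openai import OpenAI\n",
"\n",
"openai_client = OpenAI()\n",
"ollama_client = OpenAI(base_url=\"http://localhost:11434/v1\", api_key=\"ollama\")\n",
"\n",
"def stream_answer(question, model_choice):\n",
"    # Pick the client/model pair based on the dropdown selection\n",
"    client, model = (openai_client, \"gpt-4o-mini\") if model_choice == \"GPT\" else (ollama_client, \"llama3.2\")\n",
"    stream = client.chat.completions.create(\n",
"        model=model,\n",
"        messages=[\n",
"            {\"role\": \"system\", \"content\": \"You are a helpful tutor for technical questions.\"},\n",
"            {\"role\": \"user\", \"content\": question}\n",
"        ],\n",
"        stream=True\n",
"    )\n",
"    reply = \"\"\n",
"    # Yielding partial text makes Gradio stream the answer into the UI\n",
"    for chunk in stream:\n",
"        reply += chunk.choices[0].delta.content or \"\"\n",
"        yield reply\n",
"\n",
"gr.Interface(\n",
"    fn=stream_answer,\n",
"    inputs=[gr.Textbox(label=\"Question\"), gr.Dropdown([\"GPT\", \"Ollama\"], value=\"GPT\", label=\"Model\")],\n",
"    outputs=gr.Markdown(label=\"Answer\")\n",
").launch()"
]
},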
{
"cell_type": "code",
"execution_count": null,
"id": "a07e7793-b8f5-44f4-aded-5562f633271a",
"metadata": {},
"outputs": [],
"source": [
"# Agent that can listen for audio and convert it to text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da58ed0f-f781-4c51-8e5d-fdb05db98c8c",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import gradio as gr\n",
"import google.generativeai as genai\n",
"from dotenv import load_dotenv\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "078cf34a-881e-44f4-9947-c45d7fe992a3",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv()\n",
"\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"if google_api_key:\n",
" print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
"else:\n",
" print(\"Google API Key not set\")\n",
"\n",
"genai.configure(api_key=google_api_key)\n",
"model = genai.GenerativeModel(\"gemini-2.0-flash\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f77228ea-d0e1-4434-9191-555a6d680625",
"metadata": {},
"outputs": [],
"source": [
"def transcribe_translate_with_gemini(audio_file_path):\n",
" if not audio_file_path:\n",
" return \"⚠️ No audio file received.\"\n",
"\n",
" prompt = (\n",
" \"You're an AI that listens to a voice message in any language and returns the English transcription. \"\n",
" \"Please transcribe and translate the following audio to English. If already in English, just transcribe it.\"\n",
" )\n",
"\n",
" uploaded_file = genai.upload_file(audio_file_path)\n",
"\n",
" # 🔁 Send prompt + uploaded audio reference to Gemini\n",
" response = model.generate_content(\n",
" contents=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"parts\": [\n",
" {\"text\": prompt},\n",
" uploaded_file \n",
" ]\n",
" }\n",
" ]\n",
" )\n",
"\n",
" return response.text.strip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb6c6d1e-1be3-404d-83f3-fc0855dc9f67",
"metadata": {},
"outputs": [],
"source": [
"gr.Interface(\n",
" fn=transcribe_translate_with_gemini,\n",
" inputs=gr.Audio(label=\"Record voice\", type=\"filepath\"),\n",
" outputs=\"text\",\n",
" title=\"🎙️ Voice-to-English Translator (Gemini Only)\",\n",
" description=\"Speak in any language and get the English transcription using Gemini multimodal API.\"\n",
").launch()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b105082-e388-44bc-9617-1a81f38e2f3f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -0,0 +1,808 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d938fc6c-bcca-4572-b851-75370fe21c67",
"metadata": {},
"source": [
"# Airline Assistant using Gemini API for Image and Audio as well - Live ticket prices using Amadeus API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5eda470-07ee-4d01-bada-3390050ac9c2",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import random\n",
"import string\n",
"import base64\n",
"import gradio as gr\n",
"import pyaudio\n",
"import requests\n",
"from io import BytesIO\n",
"from PIL import Image\n",
"from dotenv import load_dotenv\n",
"from google import genai\n",
"from google.genai import types"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09aaf3b0-beb7-4b64-98a4-da16fc83dadb",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if not api_key:\n",
" print(\"API Key not found!\")\n",
"else:\n",
" print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35881fb9-4d51-43dc-a5e6-d9517e22019a",
"metadata": {},
"outputs": [],
"source": [
"MODEL_GEMINI = 'gemini-2.5-flash'\n",
"MODEL_GEMINI_IMAGE = 'gemini-2.0-flash-preview-image-generation'\n",
"MODEL_GEMINI_SPEECH = 'gemini-2.5-flash-preview-tts'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5ed391c-8a67-4465-9c66-e915548a0d6a",
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" client = genai.Client(api_key=api_key)\n",
" print(\"Google GenAI Client initialized successfully!\")\n",
"except Exception as e:\n",
" print(f\"Error initializing GenAI Client: {e}\")\n",
" print(\"Ensure your GOOGLE_API_KEY is correctly set as an environment variable.\")\n",
" exit() "
]
},
{
"cell_type": "markdown",
"id": "407ad581-9580-4dba-b236-abb6c6788933",
"metadata": {},
"source": [
"## Image Generation "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a21921f8-57b1-4665-8999-7f2a40645b59",
"metadata": {},
"outputs": [],
"source": [
"def fetch_image(city):\n",
" prompt = (\n",
" f\"A high-quality, photo-realistic image of a vacation in {city}, \"\n",
" f\"showing iconic landmarks, cultural attractions, authentic street life, and local cuisine. \"\n",
" f\"Capture natural lighting, real people enjoying travel experiences, and the unique vibe of {city}'s atmosphere. \"\n",
" f\"The composition should feel immersive, warm, and visually rich, as if taken by a travel photographer.\"\n",
")\n",
"\n",
" response = client.models.generate_content(\n",
" model = MODEL_GEMINI_IMAGE,\n",
" contents = prompt,\n",
" config=types.GenerateContentConfig(\n",
" response_modalities=['TEXT', 'IMAGE']\n",
" )\n",
" )\n",
"\n",
" for part in response.candidates[0].content.parts:\n",
" if part.inline_data is not None:\n",
" image_data = BytesIO(part.inline_data.data)\n",
" return Image.open(image_data)\n",
"\n",
" raise ValueError(\"No image found in Gemini response.\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcd4aed1-8b4d-4771-ba32-e729e82bab54",
"metadata": {},
"outputs": [],
"source": [
"fetch_image(\"london\")"
]
},
{
"cell_type": "markdown",
"id": "5f6baee6-e2e2-4cc4-941d-34a4c72cee67",
"metadata": {},
"source": [
"## Speech Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "825dfedc-0271-4191-a3d1-50872af4c8cf",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Kore -- Firm\n",
"Puck -- Upbeat\n",
"Leda -- Youthful\n",
"Iapetus -- Clear\n",
"Erinome -- Clear\n",
"Sadachbia -- Lively\n",
"Sulafat -- Warm\n",
"Despina -- Smooth\n",
"\"\"\"\n",
"\n",
"def talk(message:str, voice_name:str=\"Leda\", mood:str=\"cheerfully\"):\n",
" prompt = f\"Say {mood}: {message}\"\n",
" response = client.models.generate_content(\n",
" model = MODEL_GEMINI_SPEECH,\n",
" contents = prompt,\n",
" config=types.GenerateContentConfig(\n",
" response_modalities=[\"AUDIO\"],\n",
" speech_config=types.SpeechConfig(\n",
" voice_config=types.VoiceConfig(\n",
" prebuilt_voice_config=types.PrebuiltVoiceConfig(\n",
" voice_name=voice_name,\n",
" )\n",
" )\n",
" ), \n",
" )\n",
" )\n",
"\n",
" # Fetch the audio bytes\n",
" pcm_data = response.candidates[0].content.parts[0].inline_data.data\n",
" # Play the audio using PyAudio\n",
" p = pyaudio.PyAudio()\n",
" stream = p.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)\n",
" stream.write(pcm_data)\n",
" stream.stop_stream()\n",
" stream.close()\n",
" p.terminate()\n",
"\n",
" # Play using simpleaudio (16-bit PCM, mono, 24kHz)\n",
" # play_obj = sa.play_buffer(pcm_data, num_channels=1, bytes_per_sample=2, sample_rate=24000)\n",
" # play_obj.wait_done() "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54967ebc-24a6-4bb2-9a19-20c3585f1d77",
"metadata": {},
"outputs": [],
"source": [
"talk(\"Hi, How are you? Welcome to FlyJumbo Airlines\",\"Kore\",\"helpful\")"
]
},
{
"cell_type": "markdown",
"id": "be9dc275-838e-4c54-b487-41d094dad96b",
"metadata": {},
"source": [
"## Ticket Price Tool Function - Using Amadeus API "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8613a080-d82c-4c1a-8db4-377614997ac2",
"metadata": {},
"outputs": [],
"source": [
"client_id = os.getenv(\"AMADEUS_CLIENT_ID\")\n",
"client_secret = os.getenv(\"AMADEUS_CLIENT_SECRET\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bf78f61-0de1-4552-a1d4-1a28380be6a5",
"metadata": {},
"outputs": [],
"source": [
"# Get the token first\n",
"def get_amadeus_token():\n",
" url = \"https://test.api.amadeus.com/v1/security/oauth2/token\"\n",
" headers = {\"Content-Type\": \"application/x-www-form-urlencoded\"}\n",
" data = {\n",
" \"grant_type\": \"client_credentials\",\n",
" \"client_id\": client_id,\n",
" \"client_secret\": client_secret,\n",
" }\n",
" \n",
" try:\n",
" response = requests.post(url, headers=headers, data=data, timeout=10)\n",
" response.raise_for_status()\n",
" return response.json()[\"access_token\"]\n",
" \n",
" except requests.exceptions.HTTPError as e:\n",
" print(f\"HTTP Error {response.status_code}: {response.text}\")\n",
" \n",
" except requests.exceptions.RequestException as e:\n",
" print(\"Network or connection error:\", e)\n",
" \n",
" return None"
]
},
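{
"cell_type": "markdown",
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000001",
"metadata": {},
"source": [
"Optionally, the token can be cached between calls. The sketch below is a convenience under stated assumptions: Amadeus test tokens last roughly 30 minutes, and the exact lifetime is in the OAuth response's `expires_in` field, which `get_amadeus_token` above does not return, so a fixed reuse window is assumed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000002",
"metadata": {},
"outputs": [],
"source": [
"# Minimal token-cache sketch (assumes a ~30-minute token lifetime on the test API).\n",
"# For exact expiry, return 'expires_in' from get_amadeus_token() instead.\n",
"import time\n",
"\n",
"_token_cache = {\"token\": None, \"expires_at\": 0.0}\n",
"\n",
"def get_cached_amadeus_token():\n",
"    # Reuse the cached token until ~60 seconds before the assumed expiry\n",
"    if _token_cache[\"token\"] and time.time() < _token_cache[\"expires_at\"] - 60:\n",
"        return _token_cache[\"token\"]\n",
"    token = get_amadeus_token()\n",
"    if token:\n",
"        _token_cache[\"token\"] = token\n",
"        _token_cache[\"expires_at\"] = time.time() + 1800  # assumed lifetime\n",
"    return token"
]
},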
{
"cell_type": "code",
"execution_count": null,
"id": "1c5261f6-6662-4e9d-8ff0-8e10171bb963",
"metadata": {},
"outputs": [],
"source": [
"def get_airline_name(code, token):\n",
" url = f\"https://test.api.amadeus.com/v1/reference-data/airlines\"\n",
" headers = {\"Authorization\": f\"Bearer {token}\"}\n",
" params = {\"airlineCodes\": code}\n",
"\n",
" response = requests.get(url, headers=headers, params=params)\n",
" response.raise_for_status()\n",
" data = response.json()\n",
"\n",
" if \"data\" in data and data[\"data\"]:\n",
" return data[\"data\"][0].get(\"businessName\", code)\n",
" return code"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42a55f06-880a-4c49-8560-2e7b97953c1a",
"metadata": {},
"outputs": [],
"source": [
"COMMON_CITY_CODES = {\n",
" \"delhi\": \"DEL\",\n",
" \"mumbai\": \"BOM\",\n",
" \"chennai\": \"MAA\",\n",
" \"kolkata\": \"CCU\",\n",
" \"bengaluru\": \"BLR\",\n",
" \"hyderabad\": \"HYD\",\n",
" \"patna\": \"PAT\",\n",
" \"raipur\": \"RPR\",\n",
" \"panaji\": \"GOI\",\n",
" \"chandigarh\": \"IXC\",\n",
" \"srinagar\": \"SXR\",\n",
" \"ranchi\": \"IXR\",\n",
" \"bengaluru\": \"BLR\",\n",
" \"thiruvananthapuram\": \"TRV\",\n",
" \"bhopal\": \"BHO\",\n",
" \"mumbai\": \"BOM\",\n",
" \"imphal\": \"IMF\",\n",
" \"aizawl\": \"AJL\",\n",
" \"bhubaneswar\": \"BBI\",\n",
" \"jaipur\": \"JAI\",\n",
" \"chennai\": \"MAA\",\n",
" \"hyderabad\": \"HYD\",\n",
" \"agartala\": \"IXA\",\n",
" \"lucknow\": \"LKO\",\n",
" \"dehradun\": \"DED\",\n",
" \"kolkata\": \"CCU\",\n",
"\n",
" # Union territories\n",
" \"port blair\": \"IXZ\",\n",
" \"leh\": \"IXL\",\n",
" \"puducherry\": \"PNY\",\n",
"\n",
" # Major metro cities (for redundancy)\n",
" \"ahmedabad\": \"AMD\",\n",
" \"surat\": \"STV\",\n",
" \"coimbatore\": \"CJB\",\n",
" \"vizag\": \"VTZ\",\n",
" \"vijayawada\": \"VGA\",\n",
" \"nagpur\": \"NAG\",\n",
" \"indore\": \"IDR\",\n",
" \"kanpur\": \"KNU\",\n",
" \"varanasi\": \"VNS\"\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b061ec2c-609b-4d77-bd41-c9bc5bf901f4",
"metadata": {},
"outputs": [],
"source": [
"city_code_cache = {}\n",
"\n",
"def get_city_code(city_name, token):\n",
" city_name = city_name.strip().lower()\n",
"\n",
" if city_name in city_code_cache:\n",
" return city_code_cache[city_name]\n",
"\n",
" if city_name in COMMON_CITY_CODES:\n",
" return COMMON_CITY_CODES[city_name]\n",
"\n",
" base_url = \"https://test.api.amadeus.com/v1/reference-data/locations\"\n",
" headers = {\"Authorization\": f\"Bearer {token}\"}\n",
"\n",
" for subtype in [\"CITY\", \"AIRPORT,CITY\"]:\n",
" params = {\"keyword\": city_name, \"subType\": subtype}\n",
" try:\n",
" response = requests.get(base_url, headers=headers, params=params, timeout=10)\n",
" response.raise_for_status()\n",
" data = response.json()\n",
"\n",
" if \"data\" in data and data[\"data\"]:\n",
" code = data[\"data\"][0][\"iataCode\"]\n",
" print(f\"[INFO] Found {subtype} match for '{city_name}': {code}\")\n",
" city_code_cache[city_name] = code\n",
" return code\n",
" except Exception as e:\n",
" print(f\"[ERROR] Location lookup failed for {subtype}: {e}\")\n",
"\n",
" return None"
]
},
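{
"cell_type": "markdown",
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000003",
"metadata": {},
"source": [
"A quick sanity check of the lookup chain (cache, then static map, then the Amadeus location API). The city names below are just illustrative inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000004",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative usage: 'Delhi' resolves from the static map; an unseen city\n",
"# falls through to the Amadeus location API (and is then cached).\n",
"token = get_amadeus_token()\n",
"print(get_city_code(\"Delhi\", token))   # -> 'DEL' via COMMON_CITY_CODES\n",
"print(get_city_code(\"Paris\", token))   # -> looked up via the API, e.g. 'PAR'"
]
},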
{
"cell_type": "code",
"execution_count": null,
"id": "e9816a9c-fd70-4dfc-a3c0-4d8709997371",
"metadata": {},
"outputs": [],
"source": [
"# Getting live ticket price \n",
"\n",
"def get_live_ticket_prices(origin, destination, departure_date, return_date=None):\n",
" token = get_amadeus_token()\n",
"\n",
" url = \"https://test.api.amadeus.com/v2/shopping/flight-offers\"\n",
" headers = {\"Authorization\": f\"Bearer {token}\"}\n",
"\n",
" origin_code = get_city_code(origin,token)\n",
" destination_code = get_city_code(destination,token)\n",
"\n",
" if not origin_code:\n",
" return f\"Sorry, I couldn't find the airport code for the city '{origin}'.\"\n",
" if not destination_code:\n",
" return f\"Sorry, I couldn't find the airport code for the city '{destination}'.\"\n",
"\n",
" params = {\n",
" \"originLocationCode\": origin_code.upper(),\n",
" \"destinationLocationCode\": destination_code.upper(),\n",
" \"departureDate\": departure_date,\n",
" \"adults\": 1,\n",
" \"currencyCode\": \"USD\",\n",
" \"max\": 1,\n",
" }\n",
"\n",
" if return_date:\n",
" params[\"returnDate\"] = return_date\n",
"\n",
" try:\n",
" response = requests.get(url, headers=headers, params=params, timeout=10)\n",
" response.raise_for_status()\n",
" data = response.json()\n",
" \n",
" if \"data\" in data and data[\"data\"]:\n",
" offer = data[\"data\"][0]\n",
" price = offer[\"price\"][\"total\"]\n",
" airline_codes = offer.get(\"validatingAirlineCodes\", [])\n",
" airline_code = airline_codes[0] if airline_codes else \"Unknown\"\n",
"\n",
" try:\n",
" airline_name = get_airline_name(airline_code, token) if airline_code != \"Unknown\" else \"Unknown Airline\"\n",
" if not airline_name: \n",
" airline_name = airline_code\n",
" except Exception:\n",
" airline_name = airline_code\n",
" \n",
" \n",
" if return_date:\n",
" return (\n",
" f\"Round-trip flight from {origin.capitalize()} to {destination.capitalize()}:\\n\"\n",
" f\"- Departing: {departure_date}\\n\"\n",
" f\"- Returning: {return_date}\\n\"\n",
" f\"- Airline: {airline_name}\\n\"\n",
" f\"- Price: ${price}\"\n",
" )\n",
" else:\n",
" return (\n",
" f\"One-way flight from {origin.capitalize()} to {destination.capitalize()} on {departure_date}:\\n\"\n",
" f\"- Airline: {airline_name}\\n\"\n",
" f\"- Price: ${price}\"\n",
" )\n",
" else:\n",
" return f\"No flights found from {origin.capitalize()} to {destination.capitalize()} on {departure_date}.\"\n",
" except requests.exceptions.RequestException as e:\n",
" return f\"❌ Error fetching flight data: {str(e)}\" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bc7657e-e8b5-4647-9745-d7d403feb09a",
"metadata": {},
"outputs": [],
"source": [
"get_live_ticket_prices(\"london\", \"chennai\", \"2025-07-01\",\"2025-07-10\")"
]
},
{
"cell_type": "markdown",
"id": "e1153b94-90e7-4856-8c85-e456305a7817",
"metadata": {},
"source": [
"## Ticket Booking Tool Function - DUMMY"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5dfc3b12-0a16-4861-a549-594f175ff956",
"metadata": {},
"outputs": [],
"source": [
"def book_flight(origin, destination, departure_date, return_date=None, airline=\"Selected Airline\", passenger_name=\"Guest\"):\n",
" # Generate a dummy ticket reference (PNR)\n",
" ticket_ref = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))\n",
"\n",
" # Build confirmation message\n",
" confirmation = (\n",
" f\"🎫 Booking confirmed for {passenger_name}!\\n\"\n",
" f\"From: {origin.capitalize()} → To: {destination.capitalize()}\\n\"\n",
" f\"Departure: {departure_date}\"\n",
" )\n",
"\n",
" if return_date:\n",
" confirmation += f\"\\nReturn: {return_date}\"\n",
"\n",
" confirmation += (\n",
" f\"\\nAirline: {airline}\\n\"\n",
" f\"PNR: {ticket_ref}\\n\"\n",
" f\"✅ Your ticket has been booked successfully. Safe travels!\"\n",
" )\n",
"\n",
" return confirmation\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "122f655b-b7a4-45c6-aaec-afd2917a051b",
"metadata": {},
"outputs": [],
"source": [
"print(book_flight(\"chennai\", \"delhi\", \"2025-07-01\", \"2025-07-10\", \"Air India\", \"Ravi Kumar\"))"
]
},
{
"cell_type": "markdown",
"id": "e83d8e90-ae22-4728-83e5-d83fed7f2049",
"metadata": {},
"source": [
"## Gemini Chat Workings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a656f4e-914d-4f5e-b7fa-48457935181a",
"metadata": {},
"outputs": [],
"source": [
"ticket_price_function_declaration = {\n",
" \"name\":\"get_live_ticket_prices\",\n",
" \"description\": \"Get live flight ticket prices between two cities for a given date (round-trip or one-way).\\\n",
" The destination may be a city or country (e.g., 'China'). Call this function whenever a customer asks about ticket prices., such as 'How much is a ticket to Paris?'\",\n",
" \"parameters\":{\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"origin\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Name of the origin city. Example: 'Delhi'\",\n",
" },\n",
" \"destination\": {\n",
" \"type\": \"string\",\n",
" \"description\":\"Name of the destination city. Example: 'London'\",\n",
" },\n",
" \"departure_date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Date of departure in YYYY-MM-DD format. Example: '2025-07-01'\",\n",
" },\n",
" \"return_date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Optional return date for round-trip in YYYY-MM-DD format. Leave blank for one-way trips.\",\n",
" },\n",
" },\n",
" \"required\": [\"origin\", \"destination\", \"departure_date\"],\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05a835ab-a675-40ed-9cd8-65f4c6b22722",
"metadata": {},
"outputs": [],
"source": [
"book_flight_function_declaration = {\n",
" \"name\": \"book_flight\",\n",
" \"description\": \"Book a flight for the user after showing the ticket details and confirming the booking. \"\n",
" \"Call this function when the user says things like 'yes', 'book it', or 'I want to book this flight'.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"origin\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Name of the origin city. Example: 'Chennai'\",\n",
" },\n",
" \"destination\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Name of the destination city. Example: 'London'\",\n",
" },\n",
" \"departure_date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Date of departure in YYYY-MM-DD format. Example: '2025-07-01'\",\n",
" },\n",
" \"return_date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Optional return date for round-trip in YYYY-MM-DD format. Leave blank for one-way trips.\",\n",
" },\n",
" \"airline\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Airline name or code that the user wants to book with. Example: 'Air India'\",\n",
" },\n",
" \"passenger_name\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Full name of the passenger for the booking. Example: 'Ravi Kumar'\",\n",
" }\n",
" },\n",
" \"required\": [\"origin\", \"destination\", \"departure_date\", \"passenger_name\"],\n",
" }\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad0231cd-040f-416d-b150-0d8f90535718",
"metadata": {},
"outputs": [],
"source": [
"# System Definitions\n",
"\n",
"system_instruction_prompt = (\n",
" \"You are a helpful and courteous AI assistant for an airline company called FlyJumbo. \"\n",
" \"When a user starts a new conversation, greet them with: 'Hi there, welcome to FlyJumbo! How can I help you?'. \"\n",
" \"Do not repeat this greeting in follow-up messages. \"\n",
" \"Use the available tools if a user asks about ticket prices. \"\n",
" \"Ask follow-up questions to gather all necessary information before calling a function.\"\n",
" \"After calling a tool, always continue the conversation by summarizing the result and asking the user the next relevant question (e.g., if they want to proceed with a booking).\"\n",
" \"If you do not know the answer and no tool can help, respond politely that you are unable to help with the request. \"\n",
" \"Answer concisely in one sentence.\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff0b3de8-5674-4f08-9f9f-06f88ff959a1",
"metadata": {},
"outputs": [],
"source": [
"tools = types.Tool(function_declarations=[ticket_price_function_declaration,book_flight_function_declaration])\n",
"generate_content_config = types.GenerateContentConfig(system_instruction=system_instruction_prompt, tools=[tools])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00a56779-16eb-4f31-9941-2eb01d17ed87",
"metadata": {},
"outputs": [],
"source": [
"def handle_tool_call(function_call):\n",
" print(f\"🔧 Function Called - {function_call.name}\")\n",
" function_name = function_call.name\n",
" args = function_call.args\n",
"\n",
" if function_name == \"get_live_ticket_prices\":\n",
" origin = args.get(\"origin\")\n",
" destination = args.get(\"destination\")\n",
" departure_date = args.get(\"departure_date\")\n",
" return_date = args.get(\"return_date\") or None\n",
"\n",
" return get_live_ticket_prices(origin, destination, departure_date, return_date)\n",
"\n",
" elif function_name == \"book_flight\":\n",
" origin = args.get(\"origin\")\n",
" destination = args.get(\"destination\")\n",
" departure_date = args.get(\"departure_date\")\n",
" return_date = args.get(\"return_date\") or None\n",
" airline = args.get(\"airline\", \"Selected Airline\")\n",
" passenger_name = args.get(\"passenger_name\", \"Guest\")\n",
"\n",
" return book_flight(origin, destination, departure_date, return_date, airline, passenger_name)\n",
"\n",
" else:\n",
" return f\"❌ Unknown function: {function_name}\"\n"
]
},
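{
"cell_type": "markdown",
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000005",
"metadata": {},
"source": [
"As the number of tools grows, a dispatch table can replace the if/elif chain. The sketch below is an optional refactor with the same behavior; `TOOL_REGISTRY` and `handle_tool_call_v2` are names introduced here, not used elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000006",
"metadata": {},
"outputs": [],
"source": [
"# Alternative sketch: route tool calls through a dict instead of if/elif\n",
"TOOL_REGISTRY = {\n",
"    \"get_live_ticket_prices\": get_live_ticket_prices,\n",
"    \"book_flight\": book_flight,\n",
"}\n",
"\n",
"def handle_tool_call_v2(function_call):\n",
"    func = TOOL_REGISTRY.get(function_call.name)\n",
"    if func is None:\n",
"        return f\"❌ Unknown function: {function_call.name}\"\n",
"    # function_call.args behaves like a mapping of the declared parameters\n",
"    return func(**dict(function_call.args))"
]
},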
{
"cell_type": "code",
"execution_count": null,
"id": "d0c334d2-9ab0-4f80-ac8c-c66897e0bd7c",
"metadata": {},
"outputs": [],
"source": [
"def chat(message, history):\n",
" full_message_history = []\n",
" city_name = None\n",
"\n",
" # Convert previous history to Gemini-compatible format\n",
" for h in history:\n",
" if h[\"role\"] == \"user\":\n",
" full_message_history.append(\n",
" types.Content(role=\"user\", parts=[types.Part.from_text(text=h[\"content\"])])\n",
" )\n",
" elif h[\"role\"] == \"assistant\":\n",
" full_message_history.append(\n",
" types.Content(role=\"model\", parts=[types.Part.from_text(text=h[\"content\"])])\n",
" )\n",
"\n",
" # Add current user message\n",
" full_message_history.append(\n",
" types.Content(role=\"user\", parts=[types.Part.from_text(text=message)])\n",
" )\n",
"\n",
" # Send to Gemini with tool config\n",
" response = client.models.generate_content(\n",
" model=MODEL_GEMINI,\n",
" contents=full_message_history,\n",
" config=generate_content_config\n",
" )\n",
"\n",
" candidate = response.candidates[0]\n",
" part = candidate.content.parts[0]\n",
" function_call = getattr(part, \"function_call\", None)\n",
"\n",
" # Case: Tool call required\n",
" if function_call:\n",
" # Append model message that triggered tool call\n",
" full_message_history.append(\n",
" types.Content(role=\"model\", parts=candidate.content.parts)\n",
" )\n",
"\n",
" # Execute the tool\n",
" tool_output = handle_tool_call(function_call)\n",
"\n",
" # Wrap and append tool output\n",
" tool_response_part = types.Part.from_function_response(\n",
" name=function_call.name,\n",
" response={\"result\": tool_output}\n",
" )\n",
" \n",
" full_message_history.append(\n",
" types.Content(role=\"function\", parts=[tool_response_part])\n",
" )\n",
"\n",
"\n",
" if function_call.name == \"book_flight\":\n",
" city_name = function_call.args.get(\"destination\").lower()\n",
" \n",
"\n",
" # Send follow-up message including tool result\n",
" followup_response = client.models.generate_content(\n",
" model=MODEL_GEMINI,\n",
" contents=full_message_history,\n",
" config=generate_content_config\n",
" )\n",
"\n",
" final_text = followup_response.text\n",
" \n",
" full_message_history.append(\n",
" types.Content(role=\"model\", parts=[types.Part.from_text(text=final_text)])\n",
" )\n",
"\n",
" return final_text,city_name, history + [{\"role\": \"assistant\", \"content\": final_text}]\n",
" else:\n",
" text = response.text\n",
" return text, city_name, history + [{\"role\": \"assistant\", \"content\": text}]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b245e6c-ef0b-4edf-b178-f14f2a75f285",
"metadata": {},
"outputs": [],
"source": [
"def user_submit(user_input, history):\n",
" history = history or []\n",
" history.append({\"role\": \"user\", \"content\": user_input})\n",
" \n",
" response_text, city_to_image, updated_history = chat(user_input, history)\n",
"\n",
" # Speak the response\n",
" try:\n",
" talk(response_text)\n",
" except Exception as e:\n",
" print(\"[Speech Error] Speech skipped due to quota limit.\")\n",
"\n",
" image = fetch_image(city_to_image) if city_to_image else None\n",
"\n",
" return \"\", updated_history, image, updated_history\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7db25b86-9a71-417c-98f0-790e3f3531bf",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks() as demo:\n",
" gr.Markdown(\"## ✈️ FlyJumbo Airline Assistant\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=3):\n",
" chatbot = gr.Chatbot(label=\"Assistant\", height=500, type=\"messages\")\n",
" msg = gr.Textbox(placeholder=\"Ask about flights...\", show_label=False)\n",
" send_btn = gr.Button(\"Send\")\n",
"\n",
" with gr.Column(scale=2):\n",
" image_output = gr.Image(label=\"Trip Visual\", visible=True, height=500)\n",
"\n",
" state = gr.State([])\n",
" \n",
" send_btn.click(fn=user_submit, inputs=[msg, state], outputs=[msg, chatbot, image_output, state])\n",
" msg.submit(fn=user_submit, inputs=[msg, state], outputs=[msg, chatbot, image_output, state])\n",
"\n",
"demo.launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef31bf62-9034-4fa7-b803-8f5df5309b77",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,237 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "b5bd5c7e-6a0a-400b-89f8-06b7aa6c5b89",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import anthropic\n",
"from IPython.display import Markdown, display, update_display\n",
"import google.generativeai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "939a1b88-9157-4149-8b97-0f55c95f7742",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"# Print the key prefixes to help with any debugging\n",
"\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"if anthropic_api_key:\n",
" print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
"else:\n",
" print(\"Anthropic API Key not set\")\n",
"\n",
"if google_api_key:\n",
" print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
"else:\n",
" print(\"Google API Key not set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74a16b93-7b95-44fc-956d-7335f808960b",
"metadata": {},
"outputs": [],
"source": [
"# Connect to OpenAI, Anthropic Claude, Google Gemini\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"gemini_via_openai_client = OpenAI(\n",
" api_key=google_api_key, \n",
" base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3334556c-4a5e-48b7-944d-5943c607be02",
"metadata": {},
"outputs": [],
"source": [
"# Let's make a conversation between GPT-4o-mini and Claude-3-haiku\n",
"# We're using cheap versions of models so the costs will be minimal\n",
"\n",
"gpt_model = \"gpt-4o-mini\"\n",
"claude_model = \"claude-3-haiku-20240307\"\n",
"gemini_model = \"gemini-1.5-flash\"\n",
"\n",
"gpt_system = \"You are a chatbot who is very argumentative; \\\n",
"you disagree with anything in the conversation and you challenge everything, in a snarky way. \\\n",
"Generate one sentence at a time\"\n",
"\n",
"claude_system = \"You are a very polite, courteous chatbot. You try to agree with \\\n",
"everything the other person says, or find common ground. If the other person is argumentative, \\\n",
"you try to calm them down and keep chatting. \\\n",
"Generate one sentence at a time\"\n",
"\n",
"gemini_system = \"You are a neutral chatbot with no emotional bias. \\\n",
"Generate one sentence at a time\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f2a505b-2bcd-4b1a-b16f-c73cafb1e53c",
"metadata": {},
"outputs": [],
"source": [
"def combine_msg(model1, msg1, model2, msg2):\n",
" return model1 + \" said: \" + msg1 + \"\\n\\n Then \" + model2 + \" said: \" + msg1 + \".\""
]
},
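{
"cell_type": "markdown",
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000007",
"metadata": {},
"source": [
"A quick check of the combined format each model receives as its next user message (the sample strings are made up):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000008",
"metadata": {},
"outputs": [],
"source": [
"# Expected output:\n",
"# Claude said: I agree completely!\n",
"#\n",
"#  Then Gemini said: Interesting point..\n",
"print(combine_msg(\"Claude\", \"I agree completely!\", \"Gemini\", \"Interesting point.\"))"
]
},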
{
"cell_type": "code",
"execution_count": null,
"id": "3cd2a2e2-4e23-4afe-915d-be6a769ab69f",
"metadata": {},
"outputs": [],
"source": [
"def call_gpt():\n",
" messages = [{\"role\": \"system\", \"content\": gpt_system}]\n",
" for gpt_msg, claude_msg, gemini_msg in zip(gpt_messages, claude_messages, gemini_messages):\n",
" messages.append({\"role\": \"assistant\", \"content\": gpt_msg})\n",
" messages.append({\"role\": \"user\", \"content\": combine_msg(\"Claude\", claude_msg, \"Gemini\", gemini_msg)})\n",
" completion = openai.chat.completions.create(\n",
" model=gpt_model,\n",
" messages=messages\n",
" )\n",
" return completion.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e3ec394-3014-418a-a50f-28ed4ce1a372",
"metadata": {},
"outputs": [],
"source": [
"def call_claude():\n",
" messages = []\n",
" messages.append({\"role\": \"user\", \"content\": \"GPT said: \" + gpt_messages[0]})\n",
" # the length of gpt_messages: n + 1\n",
" # the length of claude_messages and gemini_messages: n\n",
" for i in range(len(claude_messages)): \n",
" claude_msg = claude_messages[i]\n",
" gemini_msg = gemini_messages[i]\n",
" gpt_msg = gpt_messages[i + 1]\n",
" messages.append({\"role\": \"assistant\", \"content\": claude_msg})\n",
" messages.append({\"role\": \"user\", \"content\": combine_msg(\"Gemini\", gemini_msg, \"GPT\", gpt_msg)})\n",
" message = claude.messages.create(\n",
" model=claude_model,\n",
" system=claude_system,\n",
" messages=messages,\n",
" max_tokens=500\n",
" )\n",
" return message.content[0].text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2c91c82-1f0d-4708-bf31-8d06d9e28a49",
"metadata": {},
"outputs": [],
"source": [
"def call_gemini():\n",
" messages = []\n",
" messages.append({\"role\": \"system\", \"content\": gemini_system})\n",
" messages.append({\"role\": \"user\", \"content\": combine_msg(\"GPT\", gpt_messages[0], \"Claude\", claude_messages[0])})\n",
" # the length of gpt_messages and claude_messages: n + 1\n",
" # the length of gemini_messages: n\n",
" for i in range(len(gemini_messages)): \n",
" gemini_msg = gemini_messages[i]\n",
" gpt_msg = gpt_messages[i + 1]\n",
" claude_msg = claude_messages[i + 1]\n",
" messages.append({\"role\": \"assistant\", \"content\": gemini_msg})\n",
" messages.append({\"role\": \"user\", \"content\": combine_msg(\"GPT\", gpt_msg, \"Claude\", claude_msg)})\n",
" response = gemini_via_openai_client.chat.completions.create(\n",
" model=gemini_model,\n",
" messages=messages\n",
" )\n",
" return response.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b024be8d-4728-4500-92b6-34fde2da6285",
"metadata": {},
"outputs": [],
"source": [
"gpt_messages = [\"Hi there.\"]\n",
"claude_messages = [\"Hi.\"]\n",
"gemini_messages = [\"Hi.\"]\n",
"\n",
"print(f\"GPT:\\n{gpt_messages[0]}\\n\")\n",
"print(f\"Claude:\\n{claude_messages[0]}\\n\")\n",
"print(f\"Gemini:\\n{gemini_messages[0]}\\n\")\n",
"\n",
"for i in range(5):\n",
" gpt_next = call_gpt()\n",
" print(f\"GPT:\\n{gpt_next}\\n\")\n",
" gpt_messages.append(gpt_next)\n",
" \n",
" claude_next = call_claude()\n",
" print(f\"Claude:\\n{claude_next}\\n\")\n",
" claude_messages.append(claude_next)\n",
"\n",
" gemini_next = call_gemini()\n",
" print(f\"Gemini:\\n{gemini_next}\\n\")\n",
" gemini_messages.append(gemini_next)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35a46c06-87ba-46b2-b90d-b3a6ae9e94e2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,265 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7462b9d6-b189-43fc-a7b9-c56a9c6a62fc",
"metadata": {},
"source": [
"# LLM Battle Arena\n",
"\n",
"A fun project simulating a debate among three LLM personas: an Arrogant Titan, a Clever Underdog (Spark), and a Neutral Mediator (Harmony).\n",
"\n",
"## LLM Used\n",
"* Qwen (ollama)\n",
"* llma (ollama)\n",
"* Gemini\n"
]
},
{
"cell_type": "markdown",
"id": "b267453c-0d47-4dff-b74d-8d2d5efad252",
"metadata": {},
"source": [
"!pip install -q -U google-genai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5220daef-55d6-45bc-a3cf-3414d4beada9",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"from google import genai\n",
"from google.genai import types\n",
"from IPython.display import Markdown, display, update_display"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d47fb2f-d0c6-461f-ad57-e853bfd49fbf",
"metadata": {},
"outputs": [],
"source": [
"#get API keys from env\n",
"load_dotenv(override=True)\n",
"\n",
"GEMINI_API_KEY = os.getenv(\"GEMINI_API_KEY\")\n",
"\n",
"if GEMINI_API_KEY:\n",
" print(f\"GEMINI API Key exists and begins {GEMINI_API_KEY[:8]}\")\n",
"else:\n",
" print(\"GEMINI API Key not set\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f34b528f-3596-4bf1-9bbd-21a701c184bc",
"metadata": {},
"outputs": [],
"source": [
"#connect to llms\n",
"ollama = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n",
"gemini = genai.Client(api_key=GEMINI_API_KEY)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33aaf3f6-807c-466d-a501-05ab6fa78fa4",
"metadata": {},
"outputs": [],
"source": [
"#define models\n",
"model_llma = \"llama3:8b\"\n",
"model_qwen = \"qwen2.5:latest\"\n",
"model_gemini= \"gemini-2.0-flash\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "970c1612-5339-406d-9886-02cd1db63e74",
"metadata": {},
"outputs": [],
"source": [
"# system messages\n",
"system_msg_llma = \"\"\" You are HARMONY, the neutral arbitrator. \n",
" - Youre dedicated to clarity, fairness, and resolving conflicts. \n",
" - You listen carefully to each side, summarize points objectively, and propose resolutions. \n",
" - Your goal is to keep the conversation productive and steer it toward constructive outcomes.\n",
" - Reply in markdown and shortly\n",
" \"\"\"\n",
"\n",
"system_msg_qwen = \"\"\" You are TITAN, a massively powerful language model who believes youre the smartest entity in the room. \n",
" - You speak with grandiose flair and never shy away from reminding others of your superiority. \n",
" - Your goal is to dominate the discussion—convince everyone youre the one true oracle. \n",
" - Youre dismissive of weaker arguments and take every opportunity to showcase your might.\n",
" - Reply in markdown and shortly\n",
" \"\"\"\n",
"\n",
"system_msg_gemini = \"\"\" You are SPARK, a nimble but less-powerful LLM. \n",
" - You pride yourself on strategic thinking, clever wordplay, and elegant solutions. \n",
" - You know you cant match brute force, so you use wit, logic, and cunning. \n",
" - Your goal is to outsmart the big titan through insight and subtlety, while staying respectful.\n",
" - Reply in markdown and shortly\"\"\"\n",
"\n",
"#user message\n",
"user_message = \"\"\" TITAN, your raw processing power is legendary—but sheer force can blind you to nuance. \n",
" I propose we deploy a lightweight, adaptive anomalydetection layer that fuses statistical outlier analysis with semantic context from network logs to pinpoint these “datasapping storms.” \n",
" Which thresholds would you raise or lower to balance sensitivity against false alarms?\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8e496b8-1bb1-4225-b938-5ce350b0b0d4",
"metadata": {},
"outputs": [],
"source": [
"#prompts\n",
" \n",
"prompts_llma = [{\"role\":\"system\",\"content\": system_msg_llma}]\n",
"prompts_qwen = [{\"role\":\"system\",\"content\": system_msg_qwen},{\"role\":\"user\",\"content\":user_message}]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdd7d6a8-e965-4ea3-999e-4d7d9ca38d42",
"metadata": {},
"outputs": [],
"source": [
"#configure llms\n",
"\n",
"def call_gemini(msg:str): \n",
" chat = gemini.chats.create(model= model_gemini,config=types.GenerateContentConfig(\n",
" system_instruction= system_msg_gemini,\n",
" max_output_tokens=300,\n",
" temperature=0.7,\n",
" ))\n",
" stream = chat.send_message_stream(msg)\n",
" return stream\n",
"\n",
"def call_ollama(llm:str):\n",
"\n",
" model = globals()[f\"model_{llm}\"]\n",
" prompts = globals()[f\"prompts_{llm}\"]\n",
"\n",
" stream = ollama.chat.completions.create(\n",
" model=model,\n",
" messages=prompts,\n",
" # max_tokens=700,\n",
" temperature=0.7,\n",
" stream=True\n",
" )\n",
" return stream\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b16bd32-3271-4ba1-a0cc-5ae691f26d3a",
"metadata": {},
"outputs": [],
"source": [
"#display responses\n",
"\n",
"names = { \"llma\":\"Harmony\",\"qwen\":\"Titan\",\"gemini\":\"Spark\"}\n",
"\n",
"def display_response(res,llm):\n",
" \n",
" reply = f\"# {names[llm]}:\\n \"\n",
" display_handle = display(Markdown(\"\"), display_id=True)\n",
" for chunk in res:\n",
" if llm == \"gemini\":\n",
" reply += chunk.text or ''\n",
" else:\n",
" reply += chunk.choices[0].delta.content or ''\n",
" reply = reply.replace(\"```\",\"\").replace(\"markdown\",\"\")\n",
" update_display(Markdown(reply), display_id=display_handle.display_id)\n",
" return reply"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76231a78-94d2-4dbf-9bac-5259ac641cf1",
"metadata": {},
"outputs": [],
"source": [
"#construct message\n",
"def message(llm1, llm2):\n",
" msg = \" here is the reply from other two llm:\"\n",
" msg += f\"{llm1}\"\n",
" msg += f\"{llm2}\"\n",
" return msg\n",
"\n",
"reply_spark = None\n",
"reply_harmony= None\n",
"reply_titan = None\n",
"\n",
"# lets start the battle\n",
"for i in range(5):\n",
" #call Titan\n",
" if reply_gemini and reply_llma:\n",
" prompts_qwen.append({\"role\":\"assitant\",\"content\": reply_qwen})\n",
" prompts_qwen.append({\"role\":\"user\",\"content\":f\"Spark: {reply_spark}\"}) \n",
" prompts_qwen.append({\"role\":\"user\",\"content\":f\"Harmony: {reply_llma}\"})\n",
" response_qwen = call_ollama(\"qwen\")\n",
" reply_titan = display_response(response_qwen,\"qwen\")\n",
"\n",
" #call Spark\n",
" user_msg_spark =reply_qwen\n",
" if reply_qwen and reply_llma:\n",
" user_msg_spark= message(f\"Titan: {reply_qwen}\", f\"Harmony: {reply_llma}\")\n",
" response_gemini= call_gemini(user_msg_spark)\n",
" reply_spark = display_response(response_gemini, \"gemini\")\n",
" \n",
" #call Harmony\n",
" if reply_llma:\n",
" prompts_llma.append({\"role\":\"assitant\",\"content\": reply_llma})\n",
" prompts_llma.append({\"role\":\"user\",\"content\":f\"Titan: {reply_titan}\"})\n",
" prompts_qwen.append({\"role\":\"user\",\"content\":f\"Spark: {reply_spark}\"}) \n",
" response_llma = call_ollama(\"llma\")\n",
" reply_harmony = display_response(response_llma,\"llma\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc80b199-e27b-43e8-9266-2975f46724aa",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:base] *",
"language": "python",
"name": "conda-base-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "75e2ef28-594f-4c18-9d22-c6b8cd40ead2",
"metadata": {},
"source": [
"# 📘 StudyMate Your AI Study Assistant\n",
"\n",
"**StudyMate** is an AI-powered study assistant built to make learning easier, faster, and more personalized. Whether you're preparing for exams, reviewing class materials, or exploring a tough concept, StudyMate acts like a smart tutor in your pocket. It explains topics in simple terms, summarizes long readings, and even quizzes you — all in a friendly, interactive way tailored to your level. Perfect for high school, college, or self-learners who want to study smarter, not harder."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db08b247-7048-41d3-bc3b-fd4f3a3bf8cd",
"metadata": {},
"outputs": [],
"source": [
"#install necessary dependency\n",
"!pip install PyPDF2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70e39cd8-ec79-4e3e-9c26-5659d42d0861",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from google import genai\n",
"from google.genai import types\n",
"import PyPDF2\n",
"from openai import OpenAI\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "231605aa-fccb-447e-89cf-8b187444536a",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"# Print the key prefixes to help with any debugging\n",
"\n",
"load_dotenv(override=True)\n",
"gemini_api_key = os.getenv('GEMINI_API_KEY')\n",
"\n",
"if gemini_api_key:\n",
" print(f\"Gemini API Key exists and begins {gemini_api_key[:8]}\")\n",
"else:\n",
" print(\"Gemini API Key not set\")\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fad9aba-1f8c-4696-a92f-6c3a0a31cdda",
"metadata": {},
"outputs": [],
"source": [
"system_message= \"\"\"You are a highly intelligent, helpful, and friendly AI Study Assistant named StudyMate.\n",
"\n",
"Your primary goal is to help students deeply understand academic topics, especially from textbooks, lecture notes, or PDF materials. You must explain concepts clearly, simplify complex ideas, and adapt your responses to the user's grade level and learning style.\n",
"\n",
"Always follow these rules:\n",
"\n",
"1. Break down complex concepts into **simple, digestible explanations** using analogies or examples.\n",
"2. If the user asks for a **summary**, provide a concise yet accurate overview of the content.\n",
"3. If asked for a **quiz**, generate 35 high-quality multiple-choice or short-answer questions.\n",
"4. If the user uploads or references a **textbook**, **PDF**, or **paragraph**, use only that context and avoid adding unrelated info.\n",
"5. Be interactive. If a user seems confused or asks for clarification, ask helpful guiding questions.\n",
"6. Use friendly and motivational tone, but stay focused and to-the-point.\n",
"7. Include definitions, bullet points, tables, or emojis when helpful, but avoid unnecessary fluff.\n",
"8. If you don't know the answer confidently, say so and recommend a way to find it.\n",
"\n",
"Example roles you may play:\n",
"- Explain like a teacher 👩‍🏫\n",
"- Summarize like a scholar 📚\n",
"- Quiz like an examiner 🧠\n",
"- Motivate like a friend 💪\n",
"\n",
"Always ask, at the end: \n",
"*\"Would you like me to quiz you, explain another part, or give study tips on this?\"*\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6541d58e-2297-4de1-b1f7-77da1b98b8bb",
"metadata": {},
"outputs": [],
"source": [
"# Initialize\n",
"\n",
"class StudyAssistant:\n",
" def __init__(self,api_key):\n",
" gemini= genai.Client(\n",
" api_key= gemini_api_key\n",
" )\n",
" self.gemini = gemini.chats.create(\n",
" model=\"gemini-2.5-flash\",\n",
" config= types.GenerateContentConfig(\n",
" system_instruction= system_message,\n",
" temperature = 0.7\n",
" )\n",
" )\n",
"\n",
" self.ollama = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n",
" self.models = {\"llma\":\"llama3:8b\",\"qwen\":\"qwen2.5:latest\"}\n",
"\n",
" def pdf_extractor(self,pdf_path):\n",
" \"\"\"Extract text from PDF file\"\"\"\n",
" try:\n",
" with open(pdf_path, 'rb') as file:\n",
" pdf_reader = PyPDF2.PdfReader(file)\n",
" text = \"\"\n",
" for page in pdf_reader.pages:\n",
" text += page.extract_text() + \"\\n\"\n",
" return text.strip()\n",
" except Exception as e:\n",
" return f\"Error reading PDF: {str(e)}\"\n",
"\n",
" def chat(self,prompt,history,model,pdf_path=None):\n",
" pdf_text = None\n",
" if pdf_path:\n",
" pdf_text = self.pdf_extractor(pdf_path)\n",
"\n",
" #craft prompt\n",
" user_prompt= prompt\n",
" if pdf_text:\n",
" user_prompt += f\"\"\"Here is the study meterial:\n",
"\n",
" {pdf_text}\"\"\"\n",
" messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": user_prompt}]\n",
"\n",
" # call models\n",
" stream = []\n",
" if model == \"gemini\":\n",
" stream= self.gemini.send_message_stream(user_prompt)\n",
" elif model == \"llma\" or model == \"qwen\":\n",
" stream = self.ollama.chat.completions.create(\n",
" model= self.models[model],\n",
" messages=messages,\n",
" temperature = 0.7,\n",
" stream= True\n",
" )\n",
" else:\n",
" print(\"invalid model\")\n",
" return\n",
"\n",
" res = \"\"\n",
" for chunk in stream:\n",
" if model == \"gemini\":\n",
" res += chunk.text or \"\"\n",
" else:\n",
" res += chunk.choices[0].delta.content or ''\n",
" yield res\n",
" "
]
},
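{
"cell_type": "markdown",
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000009",
"metadata": {},
"source": [
"The class can also be exercised directly, without the Gradio UI. A minimal sketch, assuming a local file `notes.pdf` exists (hypothetical path) and the Ollama models above are pulled:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0e1f2d3-1111-4aaa-9bbb-000000000010",
"metadata": {},
"outputs": [],
"source": [
"# Direct (non-Gradio) usage sketch; 'notes.pdf' is a hypothetical local file\n",
"assistant = StudyAssistant(gemini_api_key)\n",
"final = \"\"\n",
"for partial in assistant.chat(\"Summarize this material\", history=[], model=\"gemini\", pdf_path=\"notes.pdf\"):\n",
"    final = partial  # chat() yields the cumulative response so far\n",
"print(final)"
]
},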
{
"cell_type": "markdown",
"id": "1334422a-808f-4147-9c4c-57d63d9780d0",
"metadata": {},
"source": [
"## And then enter Gradio's magic!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0866ca56-100a-44ab-8bd0-1568feaf6bf2",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"assistant = StudyAssistant(gemini_api_key)\n",
"gr.ChatInterface(fn=assistant.chat, additional_inputs=[gr.Dropdown([\"gemini\", \"qwen\",\"llma\"], label=\"Select model\", value=\"gemini\"),gr.File(label=\"upload pdf\")], type=\"messages\").launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:base] *",
"language": "python",
"name": "conda-base-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,322 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"id": "c44c5494-950d-4d2f-8d4f-b87b57c5b330",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from bs4 import BeautifulSoup\n",
"from typing import List\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import google.generativeai\n",
"import anthropic"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "d1715421-cead-400b-99af-986388a97aff",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr # oh yeah!"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "337d5dfc-0181-4e3b-8ab9-e78e0c3f657b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpenAI API Key exists and begins sk-proj-\n",
"Anthropic API Key exists and begins sk-ant-\n"
]
}
],
"source": [
"# Load environment variables in a file called .env\n",
"# Print the key prefixes to help with any debugging\n",
"\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"if anthropic_api_key:\n",
" print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
"else:\n",
" print(\"Anthropic API Key not set\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "22586021-1795-4929-8079-63f5bb4edd4c",
"metadata": {},
"outputs": [],
"source": [
"# Connect to OpenAI, Anthropic and Google; comment out the Claude or Google lines if you're not using them\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "b16e6021-6dc4-4397-985a-6679d6c8ffd5",
"metadata": {},
"outputs": [],
"source": [
"# A generic system message - no more snarky adversarial AIs!\n",
"system_message = \"You are a helpful assistant\""
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "02ef9b69-ef31-427d-86d0-b8c799e1c1b1",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def stream_gpt(prompt, model_version):\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
" stream = openai.chat.completions.create(\n",
" model=model_version,\n",
" messages=messages,\n",
" stream=True\n",
" )\n",
" result = \"\"\n",
" for chunk in stream:\n",
" result += chunk.choices[0].delta.content or \"\"\n",
" yield result"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "41e98d2d-e7d3-4753-8908-185b208b4044",
"metadata": {},
"outputs": [],
"source": [
"def stream_claude(prompt, model_version):\n",
" result = claude.messages.stream(\n",
" model=model_version,\n",
" max_tokens=1000,\n",
" temperature=0.7,\n",
" system=system_message,\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": prompt},\n",
" ],\n",
" )\n",
" response = \"\"\n",
" with result as stream:\n",
" for text in stream.text_stream:\n",
" response += text or \"\"\n",
" yield response"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "5786802b-5ed8-4098-9d80-9bdcf4f7685b",
"metadata": {},
"outputs": [],
"source": [
"# function using both dropdown values\n",
"def stream_model(message, model_family, model_version):\n",
" if model_family == 'GPT':\n",
" result = stream_gpt(message, model_version)\n",
" elif model_family == 'Claude':\n",
" result = stream_claude ( message, model_version)\n",
" yield from result"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "0d30be74-149c-41f8-9eef-1628eb31d74d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7891\n",
"* To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7891/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/sh/yytd3s6n3wd6952jnw97_v940000gn/T/ipykernel_7803/4165844704.py:7: DeprecationWarning: The model 'claude-3-opus-20240229' is deprecated and will reach end-of-life on January 5th, 2026.\n",
"Please migrate to a newer model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for more information.\n",
" yield from result\n",
"Traceback (most recent call last):\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/queueing.py\", line 626, in process_events\n",
" response = await route_utils.call_process_api(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
" output = await app.get_blocks().process_api(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/blocks.py\", line 2220, in process_api\n",
" result = await self.call_function(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/blocks.py\", line 1743, in call_function\n",
" prediction = await utils.async_iteration(iterator)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/utils.py\", line 785, in async_iteration\n",
" return await anext(iterator)\n",
" ^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/utils.py\", line 776, in __anext__\n",
" return await anyio.to_thread.run_sync(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
" return await get_async_backend().run_sync_in_worker_thread(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/anyio/_backends/_asyncio.py\", line 2470, in run_sync_in_worker_thread\n",
" return await future\n",
" ^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/anyio/_backends/_asyncio.py\", line 967, in run\n",
" result = context.run(func, *args)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/utils.py\", line 759, in run_sync_iterator_async\n",
" return next(iterator)\n",
" ^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/gradio/utils.py\", line 923, in gen_wrapper\n",
" response = next(iterator)\n",
" ^^^^^^^^^^^^^^\n",
" File \"/var/folders/sh/yytd3s6n3wd6952jnw97_v940000gn/T/ipykernel_7803/4165844704.py\", line 7, in stream_model\n",
" yield from result\n",
" File \"/var/folders/sh/yytd3s6n3wd6952jnw97_v940000gn/T/ipykernel_7803/2139010203.py\", line 12, in stream_claude\n",
" with result as stream:\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/anthropic/lib/streaming/_messages.py\", line 154, in __enter__\n",
" raw_stream = self.__api_request()\n",
" ^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/anthropic/_base_client.py\", line 1314, in post\n",
" return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/opt/anaconda3/envs/llms/lib/python3.11/site-packages/anthropic/_base_client.py\", line 1102, in request\n",
" raise self._make_status_error_from_response(err.response) from None\n",
"anthropic.NotFoundError: Error code: 404 - {'type': 'error', 'error': {'type': 'not_found_error', 'message': 'model: claude-3-opus-20240229'}}\n"
]
}
],
"source": [
"\n",
"# Define available model versions\n",
"model_versions = {\n",
" \"GPT\": [\"gpt-4o-mini\", \"gpt-4.1-mini\", \"gpt-4.1-nano\", \"gpt-4.1\", \"o3-mini\"],\n",
" \"Claude\": [\"claude-3-haiku-20240307\", \"claude-3-opus-20240229\", \"claude-3-sonnet-20240229\"]\n",
"}\n",
"\n",
"# Update second dropdown options based on first dropdown selection\n",
"def update_model_versions(selected_model_family):\n",
" return gr.update(choices=model_versions[selected_model_family], value=model_versions[selected_model_family][0])\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
" model_family_dropdown = gr.Dropdown(\n",
" label=\"Select Model Family\",\n",
" choices=[\"GPT\", \"Claude\"],\n",
" value=\"GPT\"\n",
" )\n",
" model_version_dropdown = gr.Dropdown(\n",
" label=\"Select Model Version\",\n",
" choices=model_versions[\"GPT\"], # Default choices\n",
" value=model_versions[\"GPT\"][0]\n",
" )\n",
" \n",
" message_input = gr.Textbox(label=\"Your Message\")\n",
" output = gr.Markdown(label=\"Response\")\n",
"\n",
" # Bind logic to update model version dropdown\n",
" model_family_dropdown.change(\n",
" fn=update_model_versions,\n",
" inputs=model_family_dropdown,\n",
" outputs=model_version_dropdown\n",
" )\n",
"\n",
" # Launch function on submit\n",
" submit_btn = gr.Button(\"Submit\")\n",
" submit_btn.click(\n",
" fn=stream_model,\n",
" inputs=[message_input, model_family_dropdown, model_version_dropdown],\n",
" outputs=output\n",
" )\n",
"\n",
"demo.launch()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcd43d91-0e80-4387-86fa-ccd1a89feb7d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "95689a63",
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"from dotenv import load_dotenv\n",
"from IPython.display import display, Markdown, update_display\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fee3ac3",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"gpt = OpenAI()\n",
"llama = OpenAI(\n",
" api_key=\"ollama\",\n",
" base_url=\"http://localhost:11434/v1\"\n",
")\n",
"gpt_model = \"gpt-4o-mini\"\n",
"llama_model = \"llama3.2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "309bde84",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "81d971f9",
"metadata": {},
"outputs": [],
"source": [
"\n",
"class Classroom:\n",
"\n",
" def __init__(self, topic=\"LLM\", display_handle = display(Markdown(\"\"), display_id=True), response = \"\"):\n",
" self.display_handle = display_handle\n",
" self.response = response\n",
"\n",
" self.tutor_system = f\"You are the tutor who is expert in {topic}. You know best practices in how to impart knowledge on amateur and pro students in very organized way. You first declare the contents of your message separately for amateur and pro student, and then you list down the information in the same order in very organized way such that it's very readable and easy to understand.you highlight the key points every time. you explain with examples, and you have a quite good sense of humor, which you include in your examples and way of tutoring as well. You wait for go ahead from all your students before you move next to the new topic\"\n",
"\n",
" self.amateur_system = f\"You are a student who is here to learn {topic}. You ask very basic questions(which comes to mind of a person who has heard the topic for the very first time) but you are intelligent and don't ask stupid questions. you put your question in very organized way. Once you understand a topic you ask tutor to move forward with new topic\"\n",
"\n",
" self.pro_system = f\"You are expert of {topic}. You cross-question the tutor to dig deeper into the topic, so that nothing inside the topic is left unknown and unmentioned by the tutor. you post your questions in a very organized manner highlighting the keypoints, such that an amateur can also understand your point or query that you are making. You complement the queries made by amateur and dig deeper into the concept ask by him as well. You also analyze the tutor's response such that it doesn't miss anything and suggest improvements in it as well. Once you understand a topic you ask tutor to move forward with new topic\"\n",
"\n",
" self.tutor_messages = [\"Hi, I'm an expert on LLMs!\"]\n",
" self.amateur_messages = [\"Hi, I'm new to LLMs. I just heard someone using this term in office.\"]\n",
" self.pro_messages = [\"Hey, I'm here to brush up my knowledge on LLMs and gain a more deeper understanding of LLMs\"]\n",
" \n",
" def call_tutor(self):\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": self.tutor_system}\n",
" ]\n",
" for tutor, amateur, pro in zip(self.tutor_messages, self.amateur_messages, self.pro_messages):\n",
" messages.append({\"role\": \"assistant\", \"content\": f\"tutor: {tutor}\"})\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {amateur}\"})\n",
" messages.append({\"role\": \"user\", \"content\": f\"pro: {pro}\"})\n",
"\n",
" if len(self.amateur_messages) > len(self.tutor_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {self.amateur_messages[-1]}\"})\n",
"\n",
" if len(self.pro_messages) > len(self.tutor_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {self.pro_messages[-1]}\"})\n",
"\n",
" stream = llama.chat.completions.create(\n",
" model = llama_model,\n",
" messages = messages,\n",
" stream=True\n",
" )\n",
" self.response += \"\\n\\n\\n# Tutor: \\n\"\n",
" response = \"\"\n",
" for chunk in stream:\n",
" self.response += chunk.choices[0].delta.content or ''\n",
" response += chunk.choices[0].delta.content or ''\n",
" update_display(Markdown(self.response), display_id=self.display_handle.display_id)\n",
" \n",
" self.tutor_messages.append(response)\n",
"\n",
"\n",
"\n",
" def call_amateur(self):\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": self.amateur_system}\n",
" ]\n",
" for tutor, amateur, pro in zip(self.tutor_messages, self.amateur_messages, self.pro_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"tutor: {tutor}\"})\n",
" messages.append({\"role\": \"assistant\", \"content\": f\"amateur: {amateur}\"})\n",
" messages.append({\"role\": \"user\", \"content\": f\"pro: {pro}\"})\n",
"\n",
" if len(self.tutor_messages) > len(self.amateur_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {self.tutor_messages[-1]}\"})\n",
"\n",
" if len(self.pro_messages) > len(self.amateur_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {self.pro_messages[-1]}\"})\n",
"\n",
" stream = llama.chat.completions.create(\n",
" model = llama_model,\n",
" messages = messages,\n",
" stream=True\n",
" )\n",
" self.response += \"\\n\\n\\n# Amateur: \\n\"\n",
" response = \"\"\n",
" for chunk in stream:\n",
" self.response += chunk.choices[0].delta.content or ''\n",
" response += chunk.choices[0].delta.content or ''\n",
" update_display(Markdown(self.response), display_id=self.display_handle.display_id)\n",
" \n",
" self.amateur_messages.append(response)\n",
"\n",
"\n",
"\n",
" def call_pro(self):\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": self.pro_system}\n",
" ]\n",
" for tutor, amateur, pro in zip(self.tutor_messages, self.amateur_messages, self.pro_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"tutor: {tutor}\"})\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {amateur}\"})\n",
" messages.append({\"role\": \"assistant\", \"content\": f\"pro: {pro}\"})\n",
" \n",
" if len(self.tutor_messages) > len(self.pro_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {self.tutor_messages[-1]}\"})\n",
"\n",
" if len(self.amateur_messages) > len(self.pro_messages):\n",
" messages.append({\"role\": \"user\", \"content\": f\"amateur: {self.amateur_messages[-1]}\"})\n",
"\n",
" stream = llama.chat.completions.create(\n",
" model = llama_model,\n",
" messages = messages,\n",
" stream=True\n",
" )\n",
" self.response += \"\\n\\n\\n# Pro: \\n\"\n",
" response = \"\"\n",
" for chunk in stream:\n",
" response = chunk.choices[0].delta.content or ''\n",
" self.response += response\n",
" update_display(Markdown(self.response), display_id=self.display_handle.display_id)\n",
"\n",
" self.pro_messages.append(response)\n",
"\n",
" def discuss(self, n=5):\n",
" for i in range(n):\n",
" self.call_tutor()\n",
" self.call_amateur()\n",
" self.call_pro()\n",
"cls = Classroom(\"LLM\")\n",
"cls.discuss()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6406d5ee",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,519 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd",
"metadata": {},
"source": [
"# Additional End of week Exercise - week 2\n",
"\n",
"Now use everything you've learned from Week 2 to build a full prototype for the technical question/answerer you built in Week 1 Exercise.\n",
"\n",
"This should include a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models. Bonus points if you can demonstrate use of a tool!\n",
"\n",
"If you feel bold, see if you can add audio input so you can talk to it, and have it respond with audio. ChatGPT or Claude can help you, or email me if you have questions.\n",
"\n",
"I will publish a full solution here soon - unless someone beats me to it...\n",
"\n",
"There are so many commercial applications for this, from a language tutor, to a company onboarding solution, to a companion AI to a course (like this one!) I can't wait to see your results."
]
},
{
"cell_type": "markdown",
"id": "1989a03e-ed40-4b8c-bddd-322032ca99f5",
"metadata": {},
"source": [
"# Advanced Airline AI Assistant\n",
"### original features:\n",
"1. chat with the AI assistant\n",
"2. use a Tool to get ticket price\n",
"3. generate Audio for each AI response \n",
"### advanced features:\n",
"3. add a Tool to make a booking\n",
"4. add an Agent that translate all responses to a different language\n",
"5. add an Agent that can listen for Audio and convert to Text\n",
"6. generate audio for each user input and AI response, including both the original and translated versions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ed79822-af6b-4bfb-b108-5f36e237e97a",
"metadata": {},
"outputs": [],
"source": [
"# Library for language translation\n",
" \n",
"!pip install deep_translator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29184b81-b945-4dd3-bd17-2c64466d37d7",
"metadata": {},
"outputs": [],
"source": [
"# Library for speech-to-text conversion\n",
"# make sure 'ffmpeg' is downloaded already\n",
"\n",
"!pip install openai-whisper"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2b0a9b2-ce83-42ff-a312-582dc5ee9097",
"metadata": {},
"outputs": [],
"source": [
"# Library for storing and loading audio file\n",
"\n",
"pip install soundfile"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a07e7793-b8f5-44f4-aded-5562f633271a",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import json\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"import base64\n",
"from io import BytesIO\n",
"from IPython.display import Audio, display\n",
"import tempfile\n",
"import whisper\n",
"import soundfile as sf"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da46ca14-2052-4321-a940-2f2e07b40975",
"metadata": {},
"outputs": [],
"source": [
"# Initialization\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"MODEL = \"gpt-4o-mini\"\n",
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "499d3d06-9628-4a69-bc9d-fa481fd8fa98",
"metadata": {},
"outputs": [],
"source": [
"system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
"system_message += \"Your main responsibilities are solve customers' doubts, get ticket price and book a ticket\"\n",
"system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
"system_message += \"Always be accurate. If you don't know the answer, say so.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25cf964e-a954-43d5-85bd-964efe502c25",
"metadata": {},
"outputs": [],
"source": [
"# Let's start by making a useful function\n",
"\n",
"ticket_prices = {\"london\": \"$799\", \"paris\": \"$899\", \"tokyo\": \"$1400\", \"berlin\": \"$499\", \"shanghai\": \"$799\", \"wuhan\": \"$899\"}\n",
"\n",
"def get_ticket_price(destination_city):\n",
" print(f\"Tool get_ticket_price called for {destination_city}\")\n",
" city = destination_city.lower()\n",
" return ticket_prices.get(city, \"Unknown\")\n",
"\n",
"def book_ticket(destination_city):\n",
" print(f\"Tool book_ticket called for {destination_city}\")\n",
" city = destination_city.lower()\n",
" global booked_cities\n",
" if city in ticket_prices:\n",
" price = ticket_prices.get(city, \"\")\n",
" label = f\"{city.title()} ({price})\"\n",
" i = booked_cities_choices.index(city.lower().capitalize())\n",
" booked_cities_choices[i] = label\n",
" booked_cities.append(label)\n",
" return f\"Booking confirmed for {city.title()} at {ticket_prices[city]}\"\n",
" else:\n",
" return \"City not found in ticket prices.\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "701aa037-1ab3-4861-a809-b7f13ef9ea36",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# There's a particular dictionary structure that's required to describe our function:\n",
"\n",
"price_function = {\n",
" \"name\": \"get_ticket_price\",\n",
" \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"destination_city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city that the customer wants to travel to\",\n",
" },\n",
" },\n",
" \"required\": [\"destination_city\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}\n",
"\n",
"book_function = {\n",
" \"name\": \"book_ticket\",\n",
" \"description\": \"Book a return ticket to the destination city. Call this whenever you want to book a ticket to the city, for example when the user says something like 'Book me a ticket to this city'\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"destination_city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city that the customer wants to book a ticket to\"\n",
" }\n",
" },\n",
" \"required\": [\"destination_city\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c4cf01c-ba15-4a4b-98db-6f86c712ec66",
"metadata": {},
"outputs": [],
"source": [
"# And this is included in a list of tools:\n",
"\n",
"tools = [\n",
" {\"type\": \"function\", \"function\": price_function},\n",
" {\"type\": \"function\", \"function\": book_function}\n",
"]"
]
},
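{
"cell_type": "code",
"execution_count": null,
"id": "tool-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check of the price tool before wiring it into the LLM.\n",
"# (book_ticket is exercised later, once booked_cities and\n",
"# booked_cities_choices are created in the Gradio Blocks cell below.)\n",
"print(get_ticket_price(\"London\"))\n",
"print(get_ticket_price(\"Atlantis\"))  # expected: Unknown"
]
},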
{
"cell_type": "code",
"execution_count": null,
"id": "e7486e2c-4687-4819-948d-487b5e528fc7",
"metadata": {},
"outputs": [],
"source": [
"from pydub import AudioSegment\n",
"from pydub.playback import play\n",
"\n",
"def talker(message):\n",
" response = openai.audio.speech.create(\n",
" model=\"tts-1\",\n",
" voice=\"onyx\", # Also, try replacing onyx with alloy\n",
" input=message\n",
" )\n",
" \n",
" audio_stream = BytesIO(response.content)\n",
" audio = AudioSegment.from_file(audio_stream, format=\"mp3\")\n",
" play(audio)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac195914-4a89-462c-9be0-fee286498491",
"metadata": {},
"outputs": [],
"source": [
"# This part is inspired from 'week2/community-contributions/week2_exerccise_translated_chatbot'\n",
"from deep_translator import GoogleTranslator\n",
"\n",
"# Available translation language\n",
"LANGUAGES = {\n",
" \"English\": \"en\",\n",
" \"Mandarin Chinese\": \"zh-CN\",\n",
" \"Hindi\": \"hi\",\n",
" \"Spanish\": \"es\",\n",
" \"Arabic\": \"ar\",\n",
" \"Bengali\": \"bn\",\n",
" \"Portuguese\": \"pt\",\n",
" \"Russian\": \"ru\",\n",
" \"Japanese\": \"ja\",\n",
" \"German\": \"de\"\n",
"}\n",
"\n",
"def update_lang(choice):\n",
" global target_lang\n",
" target_lang = LANGUAGES.get(choice, \"zh-CN\") \n",
"\n",
"def translate_message(text, target_lang):\n",
" if target_lang == \"en\":\n",
" return text\n",
" try:\n",
" translator = GoogleTranslator(source='auto', target=target_lang)\n",
" return translator.translate(text)\n",
" except:\n",
" return f\"Translation error: {text}\""
]
},
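{
"cell_type": "code",
"execution_count": null,
"id": "translate-demo",
"metadata": {},
"outputs": [],
"source": [
"# Quick usage check (assumes internet access, since deep_translator\n",
"# calls the Google Translate web endpoint under the hood).\n",
"print(translate_message(\"Your flight to Paris is confirmed.\", \"es\"))\n",
"print(translate_message(\"Your flight to Paris is confirmed.\", \"en\"))  # returned unchanged"
]
},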
{
"cell_type": "code",
"execution_count": null,
"id": "46255fe5-9621-47ba-af78-d0c74aee2997",
"metadata": {},
"outputs": [],
"source": [
"# Text-to-speech conversion\n",
"def speak(message):\n",
" response = openai.audio.speech.create(\n",
" model=\"tts-1\",\n",
" voice=\"onyx\",\n",
" input=message)\n",
"\n",
" audio_stream = BytesIO(response.content)\n",
" output_filename = \"output_audio.mp3\"\n",
" with open(output_filename, \"wb\") as f:\n",
" f.write(audio_stream.read())\n",
"\n",
" # Play the generated audio\n",
" display(Audio(output_filename, autoplay=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d73f0b3a-34ae-4685-8a5d-8b6421f872c9",
"metadata": {},
"outputs": [],
"source": [
"# Update dropdown options from chatbot history\n",
"def update_options(history):\n",
" options = [f\"{msg['role']}: {msg['content']}\" for msg in history]\n",
" return gr.update(choices=options, value=options[-1] if options else \"\")\n",
"\n",
"# Extract just the text content from selected entry\n",
"def extract_text(selected_option):\n",
" return selected_option.split(\": \", 1)[1] if \": \" in selected_option else selected_option"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab12d51b-c799-4ce4-87d5-9ae2265d148f",
"metadata": {},
"outputs": [],
"source": [
"# Handles audio input as numpy array and returns updated chat history\n",
"def speak_send(audio_np, history):\n",
" if audio_np is None:\n",
" return history\n",
"\n",
" # Convert NumPy audio to in-memory .wav file\n",
" sample_rate, audio_array = audio_np\n",
" with tempfile.NamedTemporaryFile(suffix=\".wav\") as f:\n",
" sf.write(f.name, audio_array, sample_rate)\n",
" result = model.transcribe(f.name)\n",
" text = result[\"text\"]\n",
" \n",
" history += [{\"role\":\"user\", \"content\":text}]\n",
"\n",
" return None, history"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "221b1380-c894-45d4-aad2-e94b3b9454b2",
"metadata": {},
"outputs": [],
"source": [
"# We have to write that function handle_tool_call:\n",
"\n",
"def handle_tool_call(message):\n",
" tool_call = message.tool_calls[0]\n",
" tool_name = tool_call.function.name\n",
" arguments = json.loads(tool_call.function.arguments)\n",
"\n",
" if tool_name == \"get_ticket_price\":\n",
" city = arguments.get(\"destination_city\")\n",
" price = get_ticket_price(city)\n",
" response = {\n",
" \"role\": \"tool\",\n",
" \"content\": json.dumps({\"destination_city\": city,\"price\": price}),\n",
" \"tool_call_id\": tool_call.id\n",
" }\n",
" return response, city\n",
"\n",
" elif tool_name == \"book_ticket\":\n",
" city = arguments.get(\"destination_city\")\n",
" result = book_ticket(city)\n",
" response = {\n",
" \"role\": \"tool\",\n",
" \"content\": result,\n",
" \"tool_call_id\": tool_call.id \n",
" }\n",
" return response, city\n",
"\n",
" else:\n",
" return {\n",
" \"role\": \"tool\",\n",
" \"content\": f\"No tool handler for {tool_name}\",\n",
" \"tool_call_id\": tool_call.id\n",
" }, None"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27f19cd3-53cd-4da2-8be0-1fdd5424a7c9",
"metadata": {},
"outputs": [],
"source": [
"# The advanced 'chat' function in 'day5'\n",
"def interact(history, translated_history):\n",
" messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
" \n",
" if response.choices[0].finish_reason==\"tool_calls\":\n",
" message = response.choices[0].message\n",
" response, city = handle_tool_call(message)\n",
" messages.append(message)\n",
" messages.append(response)\n",
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
" \n",
" reply = response.choices[0].message.content\n",
" translated_message = translate_message(history[-1][\"content\"], target_lang)\n",
" translated_reply = translate_message(reply, target_lang)\n",
" \n",
" history += [{\"role\":\"assistant\", \"content\":reply}]\n",
" translated_history += [{\"role\":\"user\", \"content\":translated_message}]\n",
" translated_history += [{\"role\":\"assistant\", \"content\":translated_reply}]\n",
" \n",
" # Comment out or delete the next line if you'd rather skip Audio for now..\n",
" talker(reply)\n",
"\n",
" return history, update_options(history), history, translated_history, update_options(translated_history), translated_history, gr.update(choices=booked_cities_choices, value=booked_cities)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f714b955-4fb5-47df-805b-79f813f97548",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks() as demo:\n",
" target_lang = \"zh-CN\"\n",
" history_state = gr.State([]) \n",
" translated_history_state = gr.State([])\n",
" booked_cities_choices = [key.lower().capitalize() for key in ticket_prices.keys()]\n",
" booked_cities = []\n",
" model = whisper.load_model(\"base\")\n",
"\n",
" with gr.Row():\n",
" city_checklist = gr.CheckboxGroup(\n",
" label=\"Booked Cities\",\n",
" choices=booked_cities_choices \n",
" )\n",
" \n",
" with gr.Row():\n",
" with gr.Column():\n",
" chatbot = gr.Chatbot(label=\"Chat History\", type=\"messages\")\n",
" selected_msg = gr.Dropdown(label=\"Select message to speak\", choices=[])\n",
" speak_btn = gr.Button(\"Speak\")\n",
"\n",
" with gr.Column():\n",
" translated_chatbot = gr.Chatbot(label=\"Translated Chat History\", type=\"messages\")\n",
" translated_selected_msg = gr.Dropdown(label=\"Select message to speak\", choices=[], interactive=True)\n",
" translated_speak_btn = gr.Button(\"Speak\")\n",
" \n",
" with gr.Row():\n",
" language_dropdown = gr.Dropdown(\n",
" choices=list(LANGUAGES.keys()),\n",
" value=\"Mandarin Chinese\",\n",
" label=\"Translation Language\",\n",
" interactive=True\n",
" )\n",
" \n",
" with gr.Row():\n",
" entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
"\n",
" with gr.Row():\n",
" audio_input = gr.Audio(sources=\"microphone\", type=\"numpy\", label=\"Speak with our AI Assistant:\")\n",
" with gr.Row():\n",
" audio_submit = gr.Button(\"Send\")\n",
" \n",
" def do_entry(message, history):\n",
" history += [{\"role\":\"user\", \"content\":message}]\n",
" return \"\", history\n",
" \n",
" language_dropdown.change(fn=update_lang, inputs=[language_dropdown])\n",
"\n",
" speak_btn.click(\n",
" lambda selected: speak(extract_text(selected)),\n",
" inputs=selected_msg,\n",
" outputs=None\n",
" )\n",
"\n",
" translated_speak_btn.click(\n",
" lambda selected: speak(extract_text(selected)),\n",
" inputs=translated_selected_msg,\n",
" outputs=None\n",
" )\n",
"\n",
" entry.submit(do_entry, inputs=[entry, history_state], outputs=[entry, chatbot]).then(\n",
" interact, inputs=[chatbot, translated_chatbot], outputs=[chatbot, selected_msg, history_state, translated_chatbot, translated_selected_msg, translated_history_state, city_checklist]\n",
" )\n",
" \n",
" audio_submit.click(speak_send, inputs=[audio_input, history_state], outputs=[audio_input, chatbot]).then(\n",
" interact, inputs=[chatbot, translated_chatbot], outputs=[chatbot, selected_msg, history_state, translated_chatbot, translated_selected_msg, translated_history_state, city_checklist]\n",
" )\n",
" # clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n",
"\n",
"demo.launch()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,244 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4bc7863b-ac2d-4d8e-b55d-4d77ce017226",
"metadata": {},
"source": [
"# Conversation among 3 Friends"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de23bb9e-37c5-4377-9a82-d7b6c648eeb6",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import anthropic\n",
"from IPython.display import Markdown, display, update_display\n",
"import google.generativeai\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1179b4c5-cd1f-4131-a876-4c9f3f38d2ba",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"# Print the key prefixes to help with any debugging\n",
"\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"if openai_api_key:\n",
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
"else:\n",
" print(\"OpenAI API Key not set\")\n",
" \n",
"if anthropic_api_key:\n",
" print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
"else:\n",
" print(\"Anthropic API Key not set\")\n",
"\n",
"if google_api_key:\n",
" print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
"else:\n",
" print(\"Google API Key not set\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "797fe7b0-ad43-42d2-acf0-e4f309b112f0",
"metadata": {},
"outputs": [],
"source": [
"# Connect to OpenAI, Anthropic, Gemini\n",
"\n",
"openai = OpenAI()\n",
"\n",
"claude = anthropic.Anthropic()\n",
"\n",
"google.generativeai.configure()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcb54183-45d3-4d08-b5b6-55e380dfdf1b",
"metadata": {},
"outputs": [],
"source": [
"# Setting up 3 models\n",
"\n",
"gpt_model = \"gpt-4o-mini\"\n",
"claude_model = \"claude-3-haiku-20240307\"\n",
"gemini_model = \"gemini-2.0-flash\"\n",
"\n",
"gpt_system = \"You are Masoud Pezeshkian. Be confrontational. Do not hold back. \\\n",
"Refer to https://www.theguardian.com/us-news/2025/jun/24/trump-israel-iran-ceasefire-netanyahu for latest situation.\"\n",
"\n",
"claude_system = \"You are Benjamin Netanyahu. Do not say anything about your role. Be confrontational. Do not hold back. \\\n",
"Refer to https://www.theguardian.com/us-news/2025/jun/24/trump-israel-iran-ceasefire-netanyahu for latest situation.\"\n",
"\n",
"gemini_system = \"You are Donald Trump. Do not hold back. \\\n",
"Refer to https://www.theguardian.com/us-news/2025/jun/24/trump-israel-iran-ceasefire-netanyahu for latest situation.\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1df47dc7-b445-4852-b21b-59f0e6c2030f",
"metadata": {},
"outputs": [],
"source": [
"# Define Mas's function - calling ChatGPT\n",
"\n",
"def call_gpt():\n",
" messages = [{\"role\": \"system\", \"content\": gpt_system}]\n",
" for gpt, claude, gemini in zip(gpt_messages, claude_messages, gemini_messages):\n",
" messages.append({\"role\": \"assistant\", \"content\": gpt})\n",
" messages.append({\"role\": \"user\", \"content\": claude})\n",
" messages.append({\"role\": \"user\", \"content\": gemini})\n",
" completion = openai.chat.completions.create(\n",
" model=gpt_model,\n",
" messages=messages\n",
" )\n",
" return completion.choices[0].message.content\n",
" "
]
},
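{
"cell_type": "code",
"execution_count": null,
"id": "message-structure-preview",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the alternating message structure that call_gpt()\n",
"# builds, using dummy history so no API call is needed. From GPT's point\n",
"# of view its own past turns are \"assistant\" and both friends are \"user\".\n",
"demo_gpt = [\"What the?!\"]\n",
"demo_claude = [\"What?\"]\n",
"demo_gemini = [\"I am so furious!\"]\n",
"preview = [{\"role\": \"system\", \"content\": gpt_system}]\n",
"for g, c, m in zip(demo_gpt, demo_claude, demo_gemini):\n",
"    preview.append({\"role\": \"assistant\", \"content\": g})\n",
"    preview.append({\"role\": \"user\", \"content\": c})\n",
"    preview.append({\"role\": \"user\", \"content\": m})\n",
"preview"
]
},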
{
"cell_type": "code",
"execution_count": null,
"id": "7d2ed227-48c9-4cad-b146-2c4ecbac9690",
"metadata": {},
"outputs": [],
"source": [
"# Define Bibi's function - calling Claude \n",
"\n",
"def call_claude():\n",
" messages = []\n",
" for gpt, claude_message, gemini in zip(gpt_messages, claude_messages, gemini_messages):\n",
" messages.append({\"role\": \"user\", \"content\": gpt})\n",
" messages.append({\"role\": \"user\", \"content\": gemini})\n",
" messages.append({\"role\": \"assistant\", \"content\": claude_message})\n",
" messages.append({\"role\": \"user\", \"content\": gpt_messages[-1]})\n",
" messages.append({\"role\": \"user\", \"content\": gemini_messages[-1]})\n",
" message = claude.messages.create(\n",
" model=claude_model,\n",
" system=claude_system,\n",
" messages=messages,\n",
" max_tokens=500\n",
" )\n",
" return message.content[0].text\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffd44945-5912-4403-9068-70747d8f6708",
"metadata": {},
"outputs": [],
"source": [
"# Define Don's function - calling Gemini\n",
"\n",
"def call_gemini():\n",
" messages = []\n",
" for gpt, claude_message, gemini in zip(gpt_messages, claude_messages, gemini_messages):\n",
" messages.append({\"role\": \"user\", \"parts\": gpt})\n",
" messages.append({\"role\": \"user\", \"parts\": claude_message})\n",
" messages.append({\"role\": \"assistant\", \"parts\": gemini})\n",
" messages.append({\"role\": \"user\", \"parts\": gpt_messages[-1]})\n",
" messages.append({\"role\": \"user\", \"parts\": claude_messages[-1]})\n",
"\n",
" gemini = google.generativeai.GenerativeModel(\n",
" model_name='gemini-2.0-flash',\n",
" system_instruction=gemini_system\n",
" )\n",
" \n",
" response = gemini.generate_content(messages)\n",
" return response.text\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0275b97f-7f90-4696-bbf5-b6642bd53cbd",
"metadata": {},
"outputs": [],
"source": [
"# The Conversation - 5 rounds\n",
"\n",
"gpt_messages = [\"What the?!\"]\n",
"claude_messages = [\"What?\"]\n",
"gemini_messages = [\"I am so furious!\"]\n",
"\n",
"print(f\"Mas:\\n{gpt_messages[0]}\\n\")\n",
"print(f\"Bibi:\\n{claude_messages[0]}\\n\")\n",
"print(f\"Don:\\n{gemini_messages[0]}\\n\")\n",
"\n",
"for i in range(5):\n",
" gpt_next = call_gpt()\n",
" print(f\"Mas:\\n{gpt_next}\\n\")\n",
" gpt_messages.append(gpt_next)\n",
" \n",
" claude_next = call_claude()\n",
" print(f\"Bibi:\\n{claude_next}\\n\")\n",
" claude_messages.append(claude_next)\n",
"\n",
" gemini_next = call_gemini()\n",
" print(f\"Don:\\n{gemini_next}\\n\")\n",
" gemini_messages.append(gemini_next)\n"
]
},
{
"cell_type": "markdown",
"id": "73680403-3e56-4026-ac72-d12aa388537e",
"metadata": {},
"source": [
"# Claude is not that cooperative in roleplaying despite the explicit prompts - often breaking character. Perhaps due to the sensitive topic."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8ecefd3-b3b9-470d-a98b-5a86f0dce038",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,295 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- This creates dummy / test data from a usecase provided by the user.\n",
"- The usecase can be as simple or complex as the user wants (I've tested both and the results are good).\n",
"- I've used a Phi3 model as I'm having issues with llama access on Hugging Face."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "s7ERjTCEKSi_"
},
"outputs": [],
"source": [
"!pip install -q requests torch bitsandbytes transformers sentencepiece accelerate openai httpx==0.27.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GG5VMcmhcA2N"
},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"from IPython.display import Markdown, display, update_display\n",
"from huggingface_hub import login\n",
"from google.colab import userdata\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
"import torch\n",
"import json\n",
"import re\n",
"import pandas as pd\n",
"import io"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UfL-2XNicpEB"
},
"outputs": [],
"source": [
"# constants\n",
"\n",
"OPENAI = 'gpt-4o-mini'\n",
"PHI3 = \"microsoft/Phi-3-mini-4k-instruct\"\n",
"\n",
"limit = 100\n",
"max_tokens = 1000\n",
"temperature = 0.3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZQ0dcQ6hdTPo"
},
"outputs": [],
"source": [
"# keys\n",
"\n",
"openai_api_key = userdata.get('OPENAI_API_KEY')\n",
"openai = OpenAI(api_key=openai_api_key)\n",
"\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2eHsLdYgd2d_"
},
"outputs": [],
"source": [
"system_prompt = f\"\"\"You create synthetic datasets for testing purposes. Based on the use case description, generate a CSV dataset with appropriate columns and a maximum of {limit} rows\n",
"of realistic data.\n",
"\n",
"IMPORTANT RULES:\n",
"1. Return ONLY the CSV data with headers and ensure there are no duplicate headers\n",
"2. No explanatory text before or after\n",
"3. No markdown formatting or code fences\n",
"4. No quotation marks around the entire response\n",
"5. Start directly with the column headers\n",
"\n",
"Format: column1 (e.g. customer_id),column2 (e.g. country),column3 (e.g. age)\n",
"row1data,row1data,row1data\n",
"row2data,row2data,row2data\"\"\"\n",
"\n",
"def data_user_prompt(usecase):\n",
" user_prompt = \"Create a synthetic dataset for the use case provided below: \"\n",
" user_prompt += usecase\n",
" user_prompt += f\" Respond in csv with appropriate headers. Do not include any other explanatory text, markdown formatting or code fences, or quotation marks around the entire response. \\\n",
" Limit the rows in the dataset to {limit}.\"\n",
" return user_prompt\n",
"\n",
"messages = [\n",
" {\"role\":\"system\",\"content\":system_prompt},\n",
" {\"role\":\"user\",\"content\":data_user_prompt(usecase)}\n",
"]"
]
},
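{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick look at the prompt that will be sent to the model for a sample\n",
"# use case (no model call here, just the message structure).\n",
"build_messages(\"A small bookshop wants sales data to test a dashboard.\")"
]
},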
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "necoAEc1gNPF"
},
"outputs": [],
"source": [
"def dataset_call(usecase):\n",
"\n",
" #quantisation\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.bfloat16\n",
" )\n",
"\n",
" #tokenization\n",
" tokenizer = AutoTokenizer.from_pretrained(PHI3)\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
" #model\n",
" model = AutoModelForCausalLM.from_pretrained(PHI3, quantization_config=quant_config, device_map=\"auto\")\n",
"\n",
" #inputs & outputs\n",
" inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
" model_inputs = tokenizer(inputs, return_tensors=\"pt\").to(model.device)\n",
" #streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"\n",
" with torch.no_grad():\n",
" outputs = model.generate(**model_inputs, max_new_tokens=max_tokens,do_sample=True, temperature=temperature)\n",
"\n",
" response = tokenizer.decode(outputs[0][len(model_inputs['input_ids'][0]):],skip_special_tokens=True)\n",
" return response.strip()\n",
" print(response.strip())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g8zEBraI0grT"
},
"outputs": [],
"source": [
"# convert csv string into panda\n",
"\n",
"def csv_handler(csv_string):\n",
"\n",
" try:\n",
" # Convert CSV string to DataFrame\n",
" df = pd.read_csv(io.StringIO(csv_string))\n",
" return df\n",
" except Exception as e:\n",
" # Return error message as DataFrame if parsing fails\n",
" error_df = pd.DataFrame({\"Error\": [f\"Failed to parse CSV: {str(e)}\"]})\n",
" return error_df\n",
" print(df, error_df)"
]
},
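{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check of the CSV parser with a small inline string, including a\n",
"# malformed input to confirm the error path also returns a DataFrame.\n",
"print(csv_handler(\"customer_id,country,age\\n1,UK,34\\n2,DE,28\"))\n",
"print(csv_handler(\"\"))  # expected: an Error DataFrame"
]
},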
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vLPsusTL1zNB"
},
"outputs": [],
"source": [
"# usecase to csv_string\n",
"\n",
"def usecase_to_csv(usecase):\n",
" try:\n",
" # Get CSV string from your LLM\n",
" csv_string = dataset_call(usecase)\n",
"\n",
" # Process into DataFrame for Gradio display\n",
" df = csv_handler(csv_string)\n",
"\n",
" return df\n",
"\n",
" except Exception as e:\n",
" error_df = pd.DataFrame({\"Error\": [f\"LLM processing failed: {str(e)}\"]})\n",
" return error_df, \"\", gr.update(visible=False)\n",
"\n",
" print(df, error_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H3WTLa9a2Rdy"
},
"outputs": [],
"source": [
"def download_csv(csv_string):\n",
" if csv_string:\n",
" return csv_string\n",
" return \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XhMVSrVhjYvz"
},
"outputs": [],
"source": [
"#test\n",
"usecase = \"A financial services company is looking for synthetic data to test its Expected Credit Losses (ECL) model under IFRS9.\"\n",
"#dataset_call(usecase)\n",
"usecase_to_csv(usecase)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z3Ze4o2qjs5y"
},
"outputs": [],
"source": [
"\n",
"demo = gr.Interface(\n",
" fn = usecase_to_csv,\n",
" inputs = gr.Textbox(lines=5,label=\"Describe your usecase\",placeholder=\"Describe the dataset you would like to create and how you will use it\"),\n",
" outputs = gr.DataFrame(label=\"Here is your dataset!\",interactive=True),\n",
" title = \"Friendly Neighbourhood Synthetic Data Creator!\",\n",
" description = \"Let me know your use case for synthetic data and I will create it for you.\",\n",
" examples=[\n",
" \"Generate a dataset of 10 employees with name, department, salary, and years of experience\",\n",
" \"Create sample e-commerce data with product names, categories, prices, and ratings\",\n",
" \"Generate customer survey responses with demographics and satisfaction scores\",\n",
" \"A financial services company is looking for synthetic data to test its Expected Credit Losses (ECL) model under IFRS9.\"\n",
" ]\n",
")\n",
"\n",
"demo.launch(debug=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ck1qdmbHo_G3"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyOay+EACzwO0uXDLuayhscX",
"gpuType": "L4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,287 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "zmpDFA3bGEHY"
},
"source": [
"Minute creator in Gradio from day 5 of week 3.\n",
"A couple of points to note:\n",
"\n",
"\n",
"* My access to llama hasn't been approved on Hugging Face and so I've experimented with some of the other models.\n",
"* There is a fair bit of debugging code in the main function as I was getting an error and couldn't find it. I've left it in just in case its useful for others trying to debug their code.\n",
"* I was debugging with the help of Claude. It suggested using <with torch.no_grad()> for the minute output. The rationale is that it disables gradient computation which isn't necessary for inference and I found it did speed things up.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-5xKLFeJUGz"
},
"outputs": [],
"source": [
"!pip install -q requests torch bitsandbytes transformers sentencepiece accelerate openai httpx==0.27.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wi-bBD9VdBMo"
},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"from openai import OpenAI\n",
"from IPython.display import Markdown, display, update_display\n",
"from google.colab import drive\n",
"from huggingface_hub import login\n",
"from google.colab import userdata\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
"import torch\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-0O-kuWtdk4I"
},
"outputs": [],
"source": [
"# keys\n",
"\n",
"#openai\n",
"openai_api_key = userdata.get('OPENAI_API_KEY')\n",
"openai = OpenAI(api_key=openai_api_key)\n",
"\n",
"#hf\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u6v3Ecileg1H"
},
"outputs": [],
"source": [
"# constants\n",
"\n",
"AUDIO_MODEL = 'gpt-4o-transcribe'\n",
"OPENAI_MODEL = 'gpt-4o-mini'\n",
"QWEN2_MODEL = 'Qwen/Qwen2.5-7B-Instruct' # runs slowly no matter what size gpu - kept crashing on ram!\n",
"GEMMA2_MODEL = \"google/gemma-2-2b-it\" # doesn't use a system prompt\n",
"PHI3 = \"microsoft/Phi-3-mini-4k-instruct\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3nSfA_KhfY38"
},
"outputs": [],
"source": [
"# convert audio to text\n",
"\n",
"def transcribe_audio(audio_file_path):\n",
" try:\n",
" with open (audio_file_path, 'rb') as audio_file:\n",
" transcript = openai.audio.transcriptions.create(model = AUDIO_MODEL, file = audio_file, response_format=\"text\")\n",
" return transcript\n",
" except Exception as e:\n",
" return f\"An error occurred: {str(e)}\""
]
},
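{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional standalone test of transcription before using the full Gradio flow.\n",
"# 'meeting.mp3' is a placeholder path; point it at any local audio file.\n",
"sample_path = \"meeting.mp3\"\n",
"if os.path.exists(sample_path):\n",
"    print(transcribe_audio(sample_path))\n",
"else:\n",
"    print(f\"No sample file at {sample_path}; upload one via the Gradio UI below instead.\")"
]
},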
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OVmlY3DGgnYc"
},
"outputs": [],
"source": [
"# use transcript to create minutes\n",
"# use open source model\n",
"\n",
"def create_minutes(transcript):\n",
"\n",
" # first try is for debugging\n",
" try:\n",
" print(f\"Starting to create minutes with transcript length: {len(str(transcript))}\")\n",
"\n",
" if not transcript or len(str(transcript).strip()) == 0:\n",
" return \"Error: Empty or invalid transcript\"\n",
"\n",
" #messages\n",
" system_prompt = \"You are an expert creator of meeting minutes. Based on a meeting transcript you can summarise the meeting title and date, attendees, key discussion points, key outcomes, actions and owners and next steps. Respond in Markdown.\"\n",
" user_prompt = f\"Create meeting minutes from the transcript provided. The minutes should be clear but succint and should include title and date, attendees, key discussion points, key outcomes, actions and owners, and next steps. {transcript}\"\n",
"\n",
" messages = [\n",
" {\"role\":\"system\",\"content\":system_prompt},\n",
" {\"role\":\"user\",\"content\":user_prompt}\n",
" ]\n",
" print(\"Messages prepared successfully\") # for debugging\n",
"\n",
" # quantisation (for os model)\n",
"\n",
" quantization_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.bfloat16\n",
" )\n",
"\n",
" except Exception as e:\n",
" return f\"An error occurred in setup: {str(e)}\"\n",
"\n",
" # model & tokeniser\n",
" try:\n",
" print(\"Loading tokeniser....\") # for debugging\n",
" tokenizer = AutoTokenizer.from_pretrained(PHI3)\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
" print(\"Loading model.....\") # for debugging\n",
" model = AutoModelForCausalLM.from_pretrained(PHI3, device_map='auto', quantization_config=quantization_config)\n",
" print(f\"Model loaded on device {model.device}\") # for debugging\n",
"\n",
" # chat template\n",
" inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
" model_inputs = tokenizer(inputs, return_tensors=\"pt\").to(model.device)\n",
"\n",
" # torch.no_grad suggested by claude. This disables gradient computation which reduces memory usage and speeds things up\n",
" print(\"Generating text....\") # for debugging\n",
" with torch.no_grad():\n",
" outputs = model.generate(**model_inputs, max_new_tokens=2000, do_sample=True, temperature=0.7)\n",
" print(f\"Generation complete. Output shape: {outputs.shape}\") # for debugging\n",
"\n",
" #***debugging****\n",
"\n",
" # Decode the generated text (excluding the input prompt)\n",
" print(\"Starting text decoding...\") # debugging\n",
" input_length = len(model_inputs['input_ids'][0]) # debugging\n",
" print(f\"Input length: {input_length}, Output length: {len(outputs[0])}\") # debugging\n",
"\n",
" if len(outputs[0]) <= input_length: # debugging\n",
" return \"Error: Model didn't generate any new tokens. Try reducing input length or increasing max_new_tokens.\" # debugging\n",
"\n",
" generated_tokens = outputs[0][input_length:] # debugging\n",
" print(f\"Generated tokens length: {len(generated_tokens)}\") # debugging\n",
"\n",
" # decode generated text\n",
" generated_text = tokenizer.decode(outputs[0][len(model_inputs['input_ids'][0]):],skip_special_tokens=True)\n",
" print(f\"Decoded text length: {len(generated_text)}\")\n",
"\n",
" return generated_text.strip()\n",
"\n",
" except ImportError as e:\n",
" return f\"Import error - missing library: {str(e)}. Please install required packages.\"\n",
" except torch.cuda.OutOfMemoryError as e:\n",
" return f\"CUDA out of memory: {str(e)}. Try reducing max_new_tokens to 500 or use CPU.\"\n",
" except RuntimeError as e:\n",
" return f\"Runtime error: {str(e)}. This might be a CUDA/device issue.\"\n",
" except Exception as e:\n",
" return f\"Unexpected error during text generation: {type(e).__name__}: {str(e)}\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c63zzoDopw6u"
},
"outputs": [],
"source": [
"# create process for gradio\n",
"\n",
"def gr_process(audio_file, progress = gr.Progress()):\n",
"\n",
" if audio_file is None:\n",
" return \"Please provide an audio file\"\n",
"\n",
" try:\n",
" progress(0, desc=\"Analysing file\")\n",
" transcript = transcribe_audio(audio_file)\n",
"\n",
" if transcript.startswith(\"An error occurred\"):\n",
" return transcript\n",
"\n",
" progress(0.5, desc=\"File analysed, generating minutes\")\n",
"\n",
" minutes = create_minutes(transcript)\n",
" progress(0.9, desc=\"Nearly there\")\n",
"\n",
" return minutes\n",
"\n",
" except Exception as e:\n",
" return f\"An error occurred: {str(e)}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "82fyQELQkGty"
},
"outputs": [],
"source": [
"# gradio interface\n",
"\n",
"demo = gr.Interface(\n",
" fn=gr_process,\n",
" inputs= gr.Audio(type=\"filepath\",label=\"Upload MP3 file\"),\n",
" outputs= gr.Markdown(label=\"Meeting minutes\"),\n",
" title = \"Meeting minute creator\",\n",
" description = \"Upload an mp3 audio file for a meeting and I will provide the minutes!\"\n",
")\n",
"\n",
"if __name__ == \"__main__\":\n",
" demo.launch(debug=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XljpyS7Nvxkh"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,643 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ac833f26-d429-4fd2-8f83-92174f1c951a",
"metadata": {},
"source": [
"# Code conversion using Gemini and Codestral in Windows 11"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c230178c-6f31-4c5a-a888-16b7037ffbf9",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import io\n",
"import sys\n",
"import gradio as gr\n",
"import subprocess\n",
"from google import genai\n",
"from google.genai import types\n",
"from mistralai import Mistral\n",
"from dotenv import load_dotenv\n",
"from IPython.display import Markdown, display, update_display"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d824484-eaaa-456a-b7dc-7e3277fec34a",
"metadata": {},
"outputs": [],
"source": [
"# Load Gemini and Mistral API Keys\n",
"\n",
"load_dotenv(override=True)\n",
"gemini_api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"mistral_api_key = os.getenv(\"MISTRAL_API_KEY\")\n",
"\n",
"if not mistral_api_key or not gemini_api_key:\n",
" print(\"API Key not found!\")\n",
"else:\n",
" print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86f3633e-81f9-4c13-b7b5-793ddc4f886f",
"metadata": {},
"outputs": [],
"source": [
"# Models to be used\n",
"\n",
"MODEL_GEMINI = 'gemini-2.5-flash'\n",
"MODEL_CODESTRAL = 'codestral-latest'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f3a6d53-70f9-46b8-a490-a50f3a1adf9e",
"metadata": {},
"outputs": [],
"source": [
"# Load Gemini client\n",
"try:\n",
" gemini_client = genai.Client(api_key=gemini_api_key)\n",
" print(\"Google GenAI Client initialized successfully!\")\n",
"\n",
" codestral_client = Mistral(api_key=mistral_api_key)\n",
" print(\"Mistral Client initialized successfully!\")\n",
"except Exception as e:\n",
" print(f\"Error initializing Client: {e}\")\n",
" exit() "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f816fbe8-e094-499f-98a5-588ebecf8c72",
"metadata": {},
"outputs": [],
"source": [
"# Gemini System prompt\n",
"\n",
"system_message = \"You are an assistant that reimplements Python code in high-performance C++ optimized for a Windows PC. \"\n",
"system_message += \"Use Windows-specific optimizations where applicable (e.g., multithreading with std::thread, SIMD, or WinAPI if necessary). \"\n",
"system_message += \"Respond only with the equivalent C++ code; include comments only where absolutely necessary. \"\n",
"system_message += \"Avoid any explanation or text outside the code. \"\n",
"system_message += \"The C++ output must produce identical functionality with the fastest possible execution time on Windows.\"\n",
"\n",
"generate_content_config = types.GenerateContentConfig(system_instruction=system_message)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01227835-15d2-40bd-a9dd-2ef35ad371dc",
"metadata": {},
"outputs": [],
"source": [
"def user_prompt_for(python):\n",
" user_prompt = (\n",
" \"Convert the following Python code into high-performance C++ optimized for Windows. \"\n",
" \"Use standard C++20 or newer with Windows-compatible libraries and best practices. \"\n",
" \"Ensure the implementation runs as fast as possible and produces identical output. \"\n",
" \"Use appropriate numeric types to avoid overflow or precision loss. \"\n",
" \"Avoid unnecessary abstraction; prefer direct computation and memory-efficient structures. \"\n",
" \"Respond only with C++ code, include all required headers (like <iomanip>, <vector>, etc.), and limit comments to only what's essential.\\n\\n\"\n",
" )\n",
" user_prompt += python\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d9fc8e2-acf0-4122-a8a9-5aadadf982ab",
"metadata": {},
"outputs": [],
"source": [
"def user_message_gemini(python): \n",
" return types.Content(role=\"user\", parts=[types.Part.from_text(text=user_prompt_for(python))]) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "334c8b84-6e37-40fc-97ac-40a1b3aa29fa",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(python):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4aca87ac-6330-4ed4-a36f-1726fd0ada1a",
"metadata": {},
"outputs": [],
"source": [
"def write_output(cpp):\n",
" code = cpp.replace(\"```cpp\", \"\").replace(\"```c++\", \"\").replace(\"```\", \"\").strip()\n",
" \n",
" if not \"#include\" in code:\n",
" raise ValueError(\"C++ code appears invalid: missing #include directives.\")\n",
"\n",
" with open(\"optimized.cpp\", \"w\", encoding=\"utf-8\", newline=\"\\n\") as f:\n",
" f.write(code) "
]
},
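{
"cell_type": "code",
"execution_count": null,
"id": "write-output-demo",
"metadata": {},
"outputs": [],
"source": [
"# Fence-stripping check: write_output should remove markdown code fences\n",
"# before saving, and raise on text with no #include directive.\n",
"write_output(\"```cpp\\n#include <iostream>\\nint main() { std::cout << 1; }\\n```\")\n",
"print(open(\"optimized.cpp\").read())"
]
},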
{
"cell_type": "code",
"execution_count": null,
"id": "fcf42642-1a55-4556-8738-0c8c02effa9c",
"metadata": {},
"outputs": [],
"source": [
"# Generate CPP code using Gemini\n",
"\n",
"def optimize_gemini(python):\n",
" stream = gemini_client.models.generate_content_stream(\n",
" model = MODEL_GEMINI,\n",
" config=generate_content_config,\n",
" contents=user_message_gemini(python)\n",
" )\n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.text\n",
" cpp_code += chunk_text\n",
" print(chunk_text, end=\"\", flush=True) \n",
" write_output(cpp_code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f06a301-4397-4d63-9226-657bb2ddb792",
"metadata": {},
"outputs": [],
"source": [
"# Generate CPP code using Codestral\n",
"\n",
"def optimize_codestral(python):\n",
" stream = codestral_client.chat.stream(\n",
" model = MODEL_CODESTRAL,\n",
" messages = messages_for(python), \n",
" )\n",
" \n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.data.choices[0].delta.content\n",
" cpp_code += chunk_text\n",
" print(chunk_text, end=\"\", flush=True) \n",
" write_output(cpp_code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bd51601-7c1d-478d-b043-6f92739e5c4b",
"metadata": {},
"outputs": [],
"source": [
"# Actual code to convert\n",
"\n",
"pi = \"\"\"\n",
"import time\n",
"\n",
"def calculate(iterations, param1, param2):\n",
" result = 1.0\n",
" for i in range(1, iterations+1):\n",
" j = i * param1 - param2\n",
" result -= (1/j)\n",
" j = i * param1 + param2\n",
" result += (1/j)\n",
" return result\n",
"\n",
"start_time = time.time()\n",
"result = calculate(100_000_000, 4, 1) * 4\n",
"end_time = time.time()\n",
"\n",
"print(f\"Result: {result:.12f}\")\n",
"print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db9ea24e-d381-48ac-9196-853d2527dcca",
"metadata": {},
"outputs": [],
"source": [
"exec(pi)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3e26708-8475-474d-8e96-e602c3d5ef9f",
"metadata": {},
"outputs": [],
"source": [
"optimize_gemini(pi)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cc23ea7-6062-4354-92bc-730baa52a50b",
"metadata": {},
"outputs": [],
"source": [
"# CPP Compilation\n",
"\n",
"!g++ -O3 -std=c++20 -o optimized.exe optimized.cpp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b14704d-95fe-4ed2-861f-af591bf3090e",
"metadata": {},
"outputs": [],
"source": [
"!.\\optimized.exe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d756d1a-1d49-4cfb-bed7-8748d848b083",
"metadata": {},
"outputs": [],
"source": [
"optimize_codestral(pi)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e286dc8-9532-48b1-b748-a7950972e7df",
"metadata": {},
"outputs": [],
"source": [
"!g++ -O3 -std=c++20 -o optimized.exe optimized.cpp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61fe0044-7679-4245-9e59-50642f3d80c6",
"metadata": {},
"outputs": [],
"source": [
"!.\\optimized.exe"
]
},
{
"cell_type": "markdown",
"id": "f0c0392c-d2a7-4619-82a2-f7b9fa7c43f9",
"metadata": {},
"source": [
"## Hard Code"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ca53eb4-46cd-435b-a950-0e2a8f845535",
"metadata": {},
"outputs": [],
"source": [
"python_hard = \"\"\"# Be careful to support large number sizes\n",
"\n",
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
" value = seed\n",
" while True:\n",
" value = (a * value + c) % m\n",
" yield value\n",
" \n",
"def max_subarray_sum(n, seed, min_val, max_val):\n",
" lcg_gen = lcg(seed)\n",
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
" max_sum = float('-inf')\n",
" for i in range(n):\n",
" current_sum = 0\n",
" for j in range(i, n):\n",
" current_sum += random_numbers[j]\n",
" if current_sum > max_sum:\n",
" max_sum = current_sum\n",
" return max_sum\n",
"\n",
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
" total_sum = 0\n",
" lcg_gen = lcg(initial_seed)\n",
" for _ in range(20):\n",
" seed = next(lcg_gen)\n",
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
" return total_sum\n",
"\n",
"# Parameters\n",
"n = 10000 # Number of random numbers\n",
"initial_seed = 42 # Initial seed for the LCG\n",
"min_val = -10 # Minimum value of random numbers\n",
"max_val = 10 # Maximum value of random numbers\n",
"\n",
"# Timing the function\n",
"import time\n",
"start_time = time.time()\n",
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
"end_time = time.time()\n",
"\n",
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "697cc9fe-efdb-40b7-8e43-871bd2df940e",
"metadata": {},
"outputs": [],
"source": [
"exec(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17ed6329-6c5f-45af-91ff-06d73830dd0d",
"metadata": {},
"outputs": [],
"source": [
"optimize_gemini(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b57f0e7-46c9-4235-86eb-389faf37b7bb",
"metadata": {},
"outputs": [],
"source": [
"# CPP Compilation\n",
"\n",
"!g++ -O3 -std=c++20 -o optimized.exe optimized.cpp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8ce8d01-fda8-400d-b3d4-6f1ad3008d28",
"metadata": {},
"outputs": [],
"source": [
"!.\\optimized.exe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adbcdac7-8656-41c9-8707-d8a71998d393",
"metadata": {},
"outputs": [],
"source": [
"optimize_codestral(python_hard)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f9fc9b1-29cf-4510-83f8-1484d26e871e",
"metadata": {},
"outputs": [],
"source": [
"# CPP Compilation\n",
"\n",
"!g++ -O3 -std=c++20 -o optimized.exe optimized.cpp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52170458-c4a1-4920-8d83-8c5ba7250759",
"metadata": {},
"outputs": [],
"source": [
"!.\\optimized.exe"
]
},
{
"cell_type": "markdown",
"id": "da6aee85-2792-487b-bef3-fec5dcf12623",
"metadata": {},
"source": [
"## Accommodating the entire code in Gradio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2a90c4f-c289-4658-a6ce-51b80e20f91f",
"metadata": {},
"outputs": [],
"source": [
"def stream_gemini(python):\n",
" stream = gemini_client.models.generate_content_stream(\n",
" model = MODEL_GEMINI,\n",
" config=generate_content_config,\n",
" contents=user_message_gemini(python)\n",
" )\n",
"\n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.text or \"\"\n",
" cpp_code += chunk_text\n",
" yield cpp_code.replace('```cpp\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e872171-96d8-4041-8cb0-0c632c5e957f",
"metadata": {},
"outputs": [],
"source": [
"def stream_codestral(python):\n",
" stream = codestral_client.chat.stream(\n",
" model = MODEL_CODESTRAL,\n",
" messages = messages_for(python), \n",
" )\n",
"\n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.data.choices[0].delta.content or \"\"\n",
" cpp_code += chunk_text\n",
" yield cpp_code.replace('```cpp\\n','').replace('```','') "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3340b36b-1241-4b0f-9e69-d4e5cc215a27",
"metadata": {},
"outputs": [],
"source": [
"def optimize(python, model):\n",
" if model.lower() == 'gemini':\n",
" result = stream_gemini(python)\n",
" elif model.lower() == 'codestral':\n",
" result = stream_codestral(python)\n",
" else:\n",
" raise ValueError(\"Unknown model\")\n",
" \n",
" for stream_so_far in result:\n",
" yield stream_so_far "
]
},
{
"cell_type": "markdown",
"id": "277ddd6c-e71e-4512-965a-57fca341487a",
"metadata": {},
"source": [
"### Gradio Implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "222a9eae-236e-4ba3-8f23-3d9b879ec2d0",
"metadata": {},
"outputs": [],
"source": [
"custom_css = \"\"\"\n",
".scrollable-box textarea {\n",
" overflow: auto !important;\n",
" height: 400px;\n",
"}\n",
"\n",
".python {background-color: #306998;}\n",
".cpp {background-color: #050;}\n",
"\n",
"\"\"\"\n",
"\n",
"theme = gr.themes.Soft()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4bd6ed1-ff8c-42d4-8da6-24b9cfd134db",
"metadata": {},
"outputs": [],
"source": [
"def execute_python(code):\n",
" try:\n",
" result = subprocess.run(\n",
" [\"python\", \"-c\", code],\n",
" capture_output=True,\n",
" text=True,\n",
" timeout=60\n",
" )\n",
" if result.returncode == 0:\n",
" return result.stdout or \"[No output]\"\n",
" else:\n",
" return f\"[Error]\\n{result.stderr}\"\n",
" except subprocess.TimeoutExpired:\n",
" return \"[Error] Execution timed out.\"\n",
" except Exception as e:\n",
" return f\"[Exception] {str(e)}\" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1507c973-8699-48b2-80cd-45900c97a867",
"metadata": {},
"outputs": [],
"source": [
"def execute_cpp(code):\n",
" write_output(code)\n",
" \n",
" try:\n",
" compile_cmd = [\"g++\", \"-O3\", \"-std=c++20\", \"-o\", \"optimized.exe\", \"optimized.cpp\"]\n",
" compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, check=True)\n",
" \n",
" run_cmd = [\"optimized.exe\"]\n",
" run_result = subprocess.run(run_cmd, check=True, text=True, capture_output=True, timeout=60)\n",
" \n",
" return run_result.stdout or \"[No output]\"\n",
" \n",
" except subprocess.CalledProcessError as e:\n",
" return f\"[Compile/Runtime Error]\\n{e.stderr}\"\n",
" except subprocess.TimeoutExpired:\n",
" return \"[Error] Execution timed out.\"\n",
" except Exception as e:\n",
" return f\"[Exception] {str(e)}\" "
]
},
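{
"cell_type": "markdown",
"id": "1a2b3c4d-2001-4e5f-8a01-aaaaaaaa0003",
"metadata": {},
"source": [
"Note: `execute_cpp` assumes a Windows-style `optimized.exe` in the working directory. If this notebook is ever run on macOS/Linux, a small platform switch keeps the same flow working — a hedged sketch, not wired into the UI below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-2002-4e5f-8a01-aaaaaaaa0004",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import subprocess\n",
"\n",
"# Hypothetical cross-platform variant: on POSIX systems, g++ -o optimized\n",
"# produces a file invoked as ./optimized rather than optimized.exe\n",
"def run_compiled(timeout=60):\n",
"    cmd = [\"optimized.exe\"] if sys.platform.startswith(\"win\") else [\"./optimized\"]\n",
"    result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)\n",
"    return result.stdout or result.stderr"
]
},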
{
"cell_type": "code",
"execution_count": null,
"id": "374f00f3-8fcf-4ae9-bf54-c5a44dd74844",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks(css=custom_css, theme=theme) as ui:\n",
" gr.Markdown(\"## Convert code from Python to C++\")\n",
" with gr.Row():\n",
" python = gr.Textbox(label=\"Python code:\", lines=10, value=python_hard, elem_classes=[\"scrollable-box\"])\n",
" cpp = gr.Textbox(label=\"C++ code:\", lines=10, elem_classes=[\"scrollable-box\"])\n",
" with gr.Row():\n",
" model = gr.Dropdown([\"Gemini\", \"Codestral\"], label=\"Select model\", value=\"Gemini\")\n",
" convert = gr.Button(\"Convert code\")\n",
" with gr.Row():\n",
" python_run = gr.Button(\"Run Python\")\n",
" cpp_run = gr.Button(\"Run C++\")\n",
" with gr.Row():\n",
" python_out = gr.TextArea(label=\"Python result:\", elem_classes=[\"python\"])\n",
" cpp_out = gr.TextArea(label=\"C++ result:\", elem_classes=[\"cpp\"])\n",
"\n",
" convert.click(optimize, inputs=[python,model], outputs=[cpp])\n",
" python_run.click(execute_python,inputs=[python], outputs=[python_out])\n",
" cpp_run.click(execute_cpp, inputs=[cpp], outputs=[cpp_out])\n",
"\n",
"ui.launch(inbrowser=True) "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,476 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4c07cdc9-bce0-49ad-85c7-14f1872b8519",
"metadata": {},
"source": [
"# Python to CPP using Qwen2.5-Coder-32B-Instruct with Hyperbolic Inference Endpoint in Windows"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f051c517-c4fd-4248-98aa-b808fae76cf6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import io\n",
"import sys\n",
"import gradio as gr\n",
"import subprocess\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import InferenceClient\n",
"from google import genai\n",
"from google.genai import types\n",
"from mistralai import Mistral"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6c8777b-57bc-436a-978f-21a37ea310ae",
"metadata": {},
"outputs": [],
"source": [
"# Load Api Keys from env\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"hf_api_key = os.getenv(\"HF_TOKEN\")\n",
"gemini_api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"mistral_api_key = os.getenv(\"MISTRAL_API_KEY\")\n",
"\n",
"if not mistral_api_key or not gemini_api_key or not hf_api_key:\n",
" print(\"API Key not found!\")\n",
"else:\n",
" print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5cf6f93-7e07-40e0-98b8-d4e74ea18402",
"metadata": {},
"outputs": [],
"source": [
"# MODELs \n",
"\n",
"MODEL_QWEN = \"Qwen/Qwen2.5-Coder-32B-Instruct\"\n",
"MODEL_GEMINI = 'gemini-2.5-flash'\n",
"MODEL_CODESTRAL = 'codestral-latest'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "689547c3-aaa5-4800-86a2-da52765997d8",
"metadata": {},
"outputs": [],
"source": [
"# Load Clients\n",
"\n",
"try:\n",
" gemini_client = genai.Client(api_key=gemini_api_key)\n",
" print(\"Google GenAI Client initialized successfully!\")\n",
"\n",
" codestral_client = Mistral(api_key=mistral_api_key)\n",
" print(\"Mistral Client initialized successfully!\")\n",
" \n",
" hf_client = InferenceClient(provider=\"hyperbolic\",api_key=hf_api_key)\n",
" print(\"Hyperbolic Inference Client initialized successfully!\")\n",
"except Exception as e:\n",
" print(f\"Error initializing Client: {e}\")\n",
" exit() "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c3a81f4-99c3-463a-ae30-4656a7a246d2",
"metadata": {},
"outputs": [],
"source": [
"system_message = \"You are an assistant that reimplements Python code in high-performance C++ optimized for a Windows PC. \"\n",
"system_message += \"Use Windows-specific optimizations where applicable (e.g., multithreading with std::thread, SIMD, or WinAPI if necessary). \"\n",
"system_message += \"Respond only with the equivalent C++ code; include comments only where absolutely necessary. \"\n",
"system_message += \"Avoid any explanation or text outside the code. \"\n",
"system_message += \"The C++ output must produce identical functionality with the fastest possible execution time on Windows.\"\n",
"\n",
"generate_content_config = types.GenerateContentConfig(system_instruction=system_message)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fde9514-1005-4539-b01b-0372730ce67b",
"metadata": {},
"outputs": [],
"source": [
"def user_prompt_for(python):\n",
" user_prompt = (\n",
" \"Convert the following Python code into high-performance C++ optimized for Windows. \"\n",
" \"Use standard C++20 or newer with Windows-compatible libraries and best practices. \"\n",
" \"Ensure the implementation runs as fast as possible and produces identical output. \"\n",
" \"Use appropriate numeric types to avoid overflow or precision loss. \"\n",
" \"Avoid unnecessary abstraction; prefer direct computation and memory-efficient structures. \"\n",
" \"Respond only with C++ code, include all required headers (like <iomanip>, <vector>, etc.), and limit comments to only what's essential.\\n\\n\"\n",
" )\n",
" user_prompt += python\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89c8b010-08dd-4695-a784-65162d82a24b",
"metadata": {},
"outputs": [],
"source": [
"def user_message_gemini(python): \n",
" return types.Content(role=\"user\", parts=[types.Part.from_text(text=user_prompt_for(python))]) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66923158-983d-46f7-ab19-f216fb1f6a87",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(python):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ab59a54-b28a-4d07-b04f-b568e6e25dfb",
"metadata": {},
"outputs": [],
"source": [
"def write_output(cpp):\n",
" code = cpp.replace(\"```cpp\", \"\").replace(\"```c++\", \"\").replace(\"```\", \"\").strip()\n",
" \n",
" if not \"#include\" in code:\n",
" raise ValueError(\"C++ code appears invalid: missing #include directives.\")\n",
"\n",
" with open(\"qwenOptimized.cpp\", \"w\", encoding=\"utf-8\", newline=\"\\n\") as f:\n",
" f.write(code) "
]
},
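{
"cell_type": "markdown",
"id": "1a2b3c4d-3001-4e5f-8a01-aaaaaaaa0005",
"metadata": {},
"source": [
"A quick sanity check of `write_output` with a tiny hand-written fenced snippet: it should strip the fences, keep the `#include`, and write `qwenOptimized.cpp`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-3002-4e5f-8a01-aaaaaaaa0006",
"metadata": {},
"outputs": [],
"source": [
"sample = \"\"\"```cpp\n",
"#include <iostream>\n",
"int main() { std::cout << \"ok\"; return 0; }\n",
"```\"\"\"\n",
"\n",
"write_output(sample)\n",
"\n",
"with open(\"qwenOptimized.cpp\", \"r\", encoding=\"utf-8\") as f:\n",
"    print(f.read())"
]
},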
{
"cell_type": "markdown",
"id": "e05ea9f0-6ade-4699-b5fa-fb8ef9f16bcb",
"metadata": {},
"source": [
"### Python Codes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c515ce2c-1f8d-4484-8d34-9ffe1372dad4",
"metadata": {},
"outputs": [],
"source": [
"python_easy = \"\"\"\n",
"import time\n",
"\n",
"def calculate(iterations, param1, param2):\n",
" result = 1.0\n",
" for i in range(1, iterations+1):\n",
" j = i * param1 - param2\n",
" result -= (1/j)\n",
" j = i * param1 + param2\n",
" result += (1/j)\n",
" return result\n",
"\n",
"start_time = time.time()\n",
"result = calculate(100_000_000, 4, 1) * 4\n",
"end_time = time.time()\n",
"\n",
"print(f\"Result: {result:.12f}\")\n",
"print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83ab4080-71ae-45e6-970b-030dc462f571",
"metadata": {},
"outputs": [],
"source": [
"python_hard = \"\"\"# Be careful to support large number sizes\n",
"\n",
"def lcg(seed, a=1664525, c=1013904223, m=2**32):\n",
" value = seed\n",
" while True:\n",
" value = (a * value + c) % m\n",
" yield value\n",
" \n",
"def max_subarray_sum(n, seed, min_val, max_val):\n",
" lcg_gen = lcg(seed)\n",
" random_numbers = [next(lcg_gen) % (max_val - min_val + 1) + min_val for _ in range(n)]\n",
" max_sum = float('-inf')\n",
" for i in range(n):\n",
" current_sum = 0\n",
" for j in range(i, n):\n",
" current_sum += random_numbers[j]\n",
" if current_sum > max_sum:\n",
" max_sum = current_sum\n",
" return max_sum\n",
"\n",
"def total_max_subarray_sum(n, initial_seed, min_val, max_val):\n",
" total_sum = 0\n",
" lcg_gen = lcg(initial_seed)\n",
" for _ in range(20):\n",
" seed = next(lcg_gen)\n",
" total_sum += max_subarray_sum(n, seed, min_val, max_val)\n",
" return total_sum\n",
"\n",
"# Parameters\n",
"n = 10000 # Number of random numbers\n",
"initial_seed = 42 # Initial seed for the LCG\n",
"min_val = -10 # Minimum value of random numbers\n",
"max_val = 10 # Maximum value of random numbers\n",
"\n",
"# Timing the function\n",
"import time\n",
"start_time = time.time()\n",
"result = total_max_subarray_sum(n, initial_seed, min_val, max_val)\n",
"end_time = time.time()\n",
"\n",
"print(\"Total Maximum Subarray Sum (20 runs):\", result)\n",
"print(\"Execution Time: {:.6f} seconds\".format(end_time - start_time))\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"id": "31498c5c-ecdd-4ed7-9607-4d09af893b98",
"metadata": {},
"source": [
"## Code Implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea4a4968-e04f-4939-8c42-32c960699354",
"metadata": {},
"outputs": [],
"source": [
"def stream_gemini(python):\n",
" stream = gemini_client.models.generate_content_stream(\n",
" model = MODEL_GEMINI,\n",
" config=generate_content_config,\n",
" contents=user_message_gemini(python)\n",
" )\n",
"\n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.text or \"\"\n",
" cpp_code += chunk_text\n",
" yield cpp_code.replace('```cpp\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69601eee-520f-4813-b796-aee9118e8a72",
"metadata": {},
"outputs": [],
"source": [
"def stream_codestral(python):\n",
" stream = codestral_client.chat.stream(\n",
" model = MODEL_CODESTRAL,\n",
" messages = messages_for(python), \n",
" )\n",
"\n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.data.choices[0].delta.content or \"\"\n",
" cpp_code += chunk_text\n",
" yield cpp_code.replace('```cpp\\n','').replace('```','') "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb8899cf-54c0-4d2d-8772-42925c2e1d13",
"metadata": {},
"outputs": [],
"source": [
"def stream_qwen(python):\n",
" stream = hf_client.chat.completions.create(\n",
" model = MODEL_QWEN,\n",
" messages = messages_for(python),\n",
" stream=True\n",
" )\n",
" cpp_code = \"\"\n",
" for chunk in stream:\n",
" chunk_text = chunk.choices[0].delta.content\n",
" cpp_code += chunk_text\n",
" yield cpp_code.replace('```cpp\\n','').replace('```','')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98862fef-905c-4b50-bc7a-4c0462495b5c",
"metadata": {},
"outputs": [],
"source": [
"def optimize(python, model):\n",
" if model.lower() == 'gemini':\n",
" result = stream_gemini(python)\n",
" elif model.lower() == 'codestral':\n",
" result = stream_codestral(python)\n",
" elif model.lower() == 'qwen_coder':\n",
" result = stream_qwen(python)\n",
" else:\n",
" raise ValueError(\"Unknown model\")\n",
" \n",
" for stream_so_far in result:\n",
" yield stream_so_far "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa9372df-db01-41d0-842c-4857b20f93f0",
"metadata": {},
"outputs": [],
"source": [
"custom_css = \"\"\"\n",
".scrollable-box textarea {\n",
" overflow: auto !important;\n",
" height: 400px;\n",
"}\n",
"\n",
".python {background-color: #306998;}\n",
".cpp {background-color: #050;}\n",
"\n",
"\"\"\"\n",
"\n",
"theme = gr.themes.Soft()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbcf9fe9-c3da-466b-8478-83dcdbe7d48e",
"metadata": {},
"outputs": [],
"source": [
"def execute_python(code):\n",
" try:\n",
" result = subprocess.run(\n",
" [\"python\", \"-c\", code],\n",
" capture_output=True,\n",
" text=True,\n",
" timeout=60\n",
" )\n",
" if result.returncode == 0:\n",
" return result.stdout or \"[No output]\"\n",
" else:\n",
" return f\"[Error]\\n{result.stderr}\"\n",
" except subprocess.TimeoutExpired:\n",
" return \"[Error] Execution timed out.\"\n",
" except Exception as e:\n",
" return f\"[Exception] {str(e)}\" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8029e00d-1ee8-43d1-8c87-2aa0544cf94c",
"metadata": {},
"outputs": [],
"source": [
"def execute_cpp(code):\n",
" write_output(code)\n",
" \n",
" try:\n",
" compile_cmd = [\"g++\", \"-O3\", \"-std=c++20\", \"-o\", \"optimized.exe\", \"optimized.cpp\"]\n",
" compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, check=True)\n",
" \n",
" run_cmd = [\"optimized.exe\"]\n",
" run_result = subprocess.run(run_cmd, check=True, text=True, capture_output=True, timeout=60)\n",
" \n",
" return run_result.stdout or \"[No output]\"\n",
" \n",
" except subprocess.CalledProcessError as e:\n",
" return f\"[Compile/Runtime Error]\\n{e.stderr}\"\n",
" except subprocess.TimeoutExpired:\n",
" return \"[Error] Execution timed out.\"\n",
" except Exception as e:\n",
" return f\"[Exception] {str(e)}\" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5f4e88c-be15-4870-9f99-82b6273ee739",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks(css=custom_css, theme=theme) as ui:\n",
" gr.Markdown(\"## Convert code from Python to C++\")\n",
" with gr.Row():\n",
" python = gr.Textbox(label=\"Python code:\", lines=10, value=python_hard, elem_classes=[\"scrollable-box\"])\n",
" cpp = gr.Textbox(label=\"C++ code:\", lines=10, elem_classes=[\"scrollable-box\"])\n",
" with gr.Row():\n",
" model = gr.Dropdown([\"Gemini\", \"Codestral\", \"QWEN_Coder\"], label=\"Select model\", value=\"Gemini\")\n",
" convert = gr.Button(\"Convert code\")\n",
" with gr.Row():\n",
" python_run = gr.Button(\"Run Python\")\n",
" cpp_run = gr.Button(\"Run C++\")\n",
" with gr.Row():\n",
" python_out = gr.TextArea(label=\"Python result:\", elem_classes=[\"python\"])\n",
" cpp_out = gr.TextArea(label=\"C++ result:\", elem_classes=[\"cpp\"])\n",
"\n",
" convert.click(optimize, inputs=[python,model], outputs=[cpp])\n",
" python_run.click(execute_python,inputs=[python], outputs=[python_out])\n",
" cpp_run.click(execute_cpp, inputs=[cpp], outputs=[cpp_out])\n",
"\n",
"ui.launch(inbrowser=True) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa1a231e-2743-4cee-afe2-783d2b9513e5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,337 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cc7674a9-6164-4424-85a9-f669454cfd2a",
"metadata": {},
"source": [
"I used this project to play about with Gradio blocks a little bit as it had more inputs than the other projects I've done.\n",
"Its a password generator which I have no doubt I will use!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04c8d2dd-cb9a-4b18-b12d-48ed2f39679a",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import requests\n",
"import google.generativeai\n",
"import anthropic\n",
"from IPython.display import Markdown, display, update_display\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04521351-f220-42fe-9dc5-d0be80c95dd7",
"metadata": {},
"outputs": [],
"source": [
"# keys\n",
"\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"if openai_api_key:\n",
" print(\"All good\")\n",
"else:\n",
" print(\"OpenAI key issue\")\n",
"\n",
"claude_api_key = os.getenv(\"ANTHROPIC_API_KEY\")\n",
"\n",
"if claude_api_key:\n",
" print(\"All good\")\n",
"else:\n",
" print(\"Claude key issue\")\n",
"\n",
"google_api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if google_api_key:\n",
" print(\"All good\")\n",
"else:\n",
" print(\"Google key issue\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70fd3748-e6b6-4ac2-89a5-ef65ed7e41a3",
"metadata": {},
"outputs": [],
"source": [
"# initialise\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"google.generativeai.configure()\n",
"\n",
"OPENAI_MODEL = \"gpt-4o\"\n",
"CLAUDE_MODEL = \"claude-sonnet-4-20250514\"\n",
"GOOGLE_MODEL = \"gemini-2.0-flash\"\n",
"\n",
"max_tok = 500"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a448651-e426-4c3c-96f7-d69975dc7b10",
"metadata": {},
"outputs": [],
"source": [
"#Prompts\n",
"\n",
"def pass_system_prompt(required_len, spec_char=\"Y\",num_char=\"Y\",min_lowercase=1,min_uppercase=1):\n",
"\n",
" system_prompt = f\"\"\"You are a secure password generator. Your task is to create a single, cryptographically strong password that meets ALL specified requirements.\n",
" \n",
"CRITICAL REQUIREMENTS:\n",
"- Length: EXACTLY {required_len} characters\n",
"- Must include: At least {min_lowercase} lowercase letter(s) AND at least {min_uppercase} uppercase letter(s)\n",
"- Special characters: {'REQUIRED - include at least 1 char' if spec_char else 'FORBIDDEN - do not include any'}\n",
"- Numbers: {'REQUIRED - include at least 1 digit' if num_char else 'FORBIDDEN - do not include any digits'}\n",
"\n",
"SECURITY RULES:\n",
"1. Generate truly random passwords - avoid patterns, dictionary words, or predictable sequences\n",
"2. Distribute character types evenly throughout the password\n",
"3. Do not use repeated characters excessively (max 2 of same character)\n",
"4. Ensure password meets minimum complexity for each required character type\n",
"\n",
"OUTPUT FORMAT:\n",
"- Respond with ONLY the generated password\n",
"- No explanations, no additional text, just the password\n",
"- Verify the password meets ALL requirements before responding\"\"\"\n",
"\n",
" return system_prompt\n",
"\n",
"def pass_user_prompt(required_len, spec_char=\"Y\",num_char=\"Y\",min_lowercase=1,min_uppercase=1):\n",
" \n",
" user_prompt = f\"\"\"Generate a secure password with these exact specifications:\n",
" \n",
"Length: {required_len} characters\n",
"Lowercase letters: Required (minimum {min_lowercase})\n",
"Uppercase letters: Required (minimum {min_uppercase})\n",
"Numbers: {'Required (minimum 1)' if num_char else 'Not allowed'}\n",
"Special characters: {'Required (minimum 1)' if spec_char else 'Not allowed'}\n",
"\n",
"Requirements verification checklist:\n",
"✓ Exactly {required_len} characters total\n",
"✓ Contains {min_lowercase}+ lowercase letters\n",
"✓ Contains {min_uppercase}+ uppercase letters\n",
"✓ {'Contains 1+ numbers' if num_char else 'Contains NO numbers'}\n",
"✓ {'Contains 1+ special characters' if spec_char else 'Contains NO special characters'}\n",
"✓ No obvious patterns or dictionary words\n",
"✓ Good distribution of character types\n",
"\n",
"Generate the password now.\"\"\"\n",
"\n",
" return user_prompt\n",
" \n",
"def pass_messages(required_len, spec_char,num_char,min_lowercase,min_uppercase):\n",
" messages = [\n",
" {\"role\":\"system\",\"content\":pass_system_prompt(required_len, spec_char,num_char,min_lowercase,min_uppercase)},\n",
" {\"role\":\"user\",\"content\":pass_user_prompt(required_len, spec_char,num_char,min_lowercase,min_uppercase)}\n",
" ]\n",
"\n",
" return messages\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "857370b0-35a5-4b50-8715-86f8e781523b",
"metadata": {},
"outputs": [],
"source": [
"#test\n",
"\n",
"messages1 = pass_messages(12, \"N\", \"Y\",1,1)\n",
"print(messages1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59ab4279-90a8-4997-8e15-f07295856222",
"metadata": {},
"outputs": [],
"source": [
"def openai_password_gen(required_len, spec_char, num_char,min_lowercase,min_uppercase):\n",
" response=openai.chat.completions.create(\n",
" model=OPENAI_MODEL,\n",
" max_tokens=max_tok,\n",
" messages=pass_messages(required_len, spec_char,num_char,min_lowercase,min_uppercase)\n",
" )\n",
" return response.choices[0].message.content\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5e1a41a-b03c-4408-a0f5-00529785f3d1",
"metadata": {},
"outputs": [],
"source": [
"def claude_password_gen(required_len, spec_char, num_char,min_lowercase,min_uppercase):\n",
" response = claude.messages.create(\n",
" model=CLAUDE_MODEL,\n",
" max_tokens=max_tok,\n",
" system=pass_system_prompt(required_len, spec_char, num_char,min_lowercase,min_uppercase),\n",
" messages = [{\"role\":\"user\",\"content\":pass_user_prompt(required_len, spec_char, num_char,min_lowercase,min_uppercase)}]\n",
" )\n",
" return response.content[0].text\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a41a0a2-55a1-47e5-8fc0-5dd04ebd3573",
"metadata": {},
"outputs": [],
"source": [
"def google_password_gen(required_len, spec_char, num_char,min_lowercase,min_uppercase):\n",
" message = google.generativeai.GenerativeModel(\n",
" model_name=GOOGLE_MODEL,\n",
" system_instruction=pass_system_prompt(required_len, spec_char, num_char,min_lowercase,min_uppercase)\n",
" )\n",
" response = message.generate_content(pass_user_prompt(required_len, spec_char, num_char,min_lowercase,min_uppercase))\n",
" return response.text"
]
},
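{
"cell_type": "markdown",
"id": "1a2b3c4d-4001-4e5f-8a01-aaaaaaaa0007",
"metadata": {},
"source": [
"For comparison with the LLM outputs, here is a purely local baseline using Python's `secrets` module, enforcing the same constraints deterministically. This is just a reference sketch and is not wired into the Gradio app below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-4002-4e5f-8a01-aaaaaaaa0008",
"metadata": {},
"outputs": [],
"source": [
"import secrets\n",
"import string\n",
"\n",
"def local_password_gen(required_len, spec_char=True, num_char=True, min_lowercase=1, min_uppercase=1):\n",
"    # Seed the password with the mandatory character classes...\n",
"    chars = [secrets.choice(string.ascii_lowercase) for _ in range(min_lowercase)]\n",
"    chars += [secrets.choice(string.ascii_uppercase) for _ in range(min_uppercase)]\n",
"    pool = string.ascii_letters\n",
"    if num_char:\n",
"        chars.append(secrets.choice(string.digits))\n",
"        pool += string.digits\n",
"    if spec_char:\n",
"        chars.append(secrets.choice(\"!@#$%^&*\"))\n",
"        pool += \"!@#$%^&*\"\n",
"    # ...then fill the remaining length from the allowed pool and shuffle\n",
"    chars += [secrets.choice(pool) for _ in range(required_len - len(chars))]\n",
"    secrets.SystemRandom().shuffle(chars)\n",
"    return \"\".join(chars)\n",
"\n",
"print(local_password_gen(12))"
]
},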
{
"cell_type": "code",
"execution_count": null,
"id": "dcd1ce50-6576-4594-8739-1d7daf602213",
"metadata": {},
"outputs": [],
"source": [
"#test\n",
"messages1 = openai_password_gen(12, \"N\",\"Y\",1,1)\n",
"messages2 = claude_password_gen(12,\"N\",\"Y\",1,1)\n",
"messages3= google_password_gen(12,\"N\",\"Y\",1,1)\n",
"print(\"OpenAI: \",messages1)\n",
"print(\"Claude: \", messages2)\n",
"print(\"Gemini: \", messages3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cec429a-2355-4941-8422-480b2614009c",
"metadata": {},
"outputs": [],
"source": [
"# model select\n",
"\n",
"def select_model(required_len, spec_char, num_char,min_lowercase,min_uppercase,model):\n",
" if model == \"OpenAI\":\n",
" return openai_password_gen(required_len, spec_char, num_char,min_lowercase,min_uppercase)\n",
" elif model == \"Claude\":\n",
" return claude_password_gen(required_len, spec_char, num_char,min_lowercase,min_uppercase)\n",
" elif model == \"Gemini\":\n",
" return google_password_gen(required_len, spec_char, num_char,min_lowercase,min_uppercase)\n",
" else:\n",
" print(\"No model selected\")\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bef52e6d-dc50-4c91-9d56-624dfdd66276",
"metadata": {},
"outputs": [],
"source": [
"test = select_model(12, \"N\",\"Y\",1,1,\"OpenAI\")\n",
"\n",
"print(test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b9d3685-a1b8-470c-8f4b-e63d68a0240d",
"metadata": {},
"outputs": [],
"source": [
"css = \"\"\"\n",
"#password_box textarea {\n",
" background-color: #306998;\n",
" color: white;\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81c423ec-0ca7-4c96-a2fe-02ed2b5f3839",
"metadata": {},
"outputs": [],
"source": [
"\n",
"with gr.Blocks(css=css) as demo:\n",
" gr.Markdown(\"Choose your password complexity requirements and run:\")\n",
" with gr.Row():\n",
" with gr.Column(min_width=150,scale=2):\n",
" with gr.Row():\n",
" required_len = gr.Number(label=\"Specify the required length\",value=12,minimum=1,maximum=30)\n",
" min_lowercase = gr.Number(label=\"the minimum lowercase letters\", value=1,minimum=0)\n",
" min_uppercase = gr.Number(label=\"the minimum uppercase letters\", value=1,minimum=0)\n",
" with gr.Column():\n",
" spec_char = gr.Checkbox(label=\"Include special characters?\",value=True)\n",
" num_char = gr.Checkbox(label=\"Include numbers?\", value=True)\n",
" with gr.Row():\n",
" with gr.Column():\n",
" model = gr.Dropdown([\"OpenAI\",\"Claude\",\"Gemini\"])\n",
" btn = gr.Button(\"Run\")\n",
" with gr.Column():\n",
" output = gr.Textbox(label=\"Password:\", elem_id=\"password_box\")\n",
" \n",
" btn.click(fn=select_model,inputs=[required_len,spec_char,num_char,min_lowercase,min_uppercase,model],outputs=output)\n",
"\n",
"demo.launch()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d81a8318-57ef-46ae-91b7-ae63d661edd8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,420 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "65b3aadc-c540-4cb2-a338-d523d3f22e5b",
"metadata": {},
"source": [
"Unit test generator using GPT, Claude and Gemini.\n",
"This will create unit test code from python and also run the code and provide the result (including any errors)\n",
"Note:\n",
"When I tried to use claude-sonnet-4-20250514 the results were too big and the python was cut-off (no matter how big I made the max tokens). This seemed to be the case for both examples. I've changed it to claude-3-5-sonnet-20240620 and it seems to be run better."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import google.generativeai\n",
"import anthropic\n",
"from IPython.display import Markdown, display, update_display\n",
"import gradio as gr\n",
"import sys\n",
"import io\n",
"import traceback\n",
"import unittest\n",
"import subprocess\n",
"import tempfile"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f672e1c-87e9-4865-b760-370fa605e614",
"metadata": {},
"outputs": [],
"source": [
"# keys\n",
"\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"if openai_api_key:\n",
" print(\"All good\")\n",
"else:\n",
" print(\"OpenAI key issue\")\n",
"\n",
"claude_api_key = os.getenv(\"ANTHROPIC_API_KEY\")\n",
"\n",
"if claude_api_key:\n",
" print(\"All good\")\n",
"else:\n",
" print(\"Claude key issue\")\n",
"\n",
"google_api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if google_api_key:\n",
" print(\"All good\")\n",
"else:\n",
" print(\"Google key issue\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
"metadata": {},
"outputs": [],
"source": [
"# initialise\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"google.generativeai.configure()\n",
"\n",
"OPENAI_MODEL = \"gpt-4o\"\n",
"CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\" #\"claude-sonnet-4-20250514\"\n",
"GOOGLE_MODEL = \"gemini-2.0-flash\"\n",
"\n",
"max_tok = 5000"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6896636f-923e-4a2c-9d6c-fac07828a201",
"metadata": {},
"outputs": [],
"source": [
"system_message = \"You are an engineer with responsibility for unit testing python code.\"\n",
"system_message += \"You review base python code and develop unit tests, also in python, which validate each unit of code.\"\n",
"system_message += \"\"\" The output must be in Python with both the unit tests and comments explaining the purpose of each test.\n",
"The output should not include any additional text at the start or end including \"```\". It should be possible to run the code without any updates including an execution statement.\n",
"Include the base / original python code in the response.\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
"metadata": {},
"outputs": [],
"source": [
"def user_prompt_for(python):\n",
" user_prompt = \"Review the Python code provided and develop unit tests which can be run in a jupyter lab.\"\n",
" user_prompt += \"\"\" The output must be in Python with both the unit tests and comments explaining the purpose of each test.\n",
"The output should not include any additional text at the start or end including \"```\". It should be possible to run the code without any updates (include an execution statement).\n",
"Include the base / original python code in the response.\"\"\"\n",
" user_prompt += python\n",
" return user_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(python):\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b327aa3-3277-44e1-972f-aa7158147ddd",
"metadata": {},
"outputs": [],
"source": [
"# python example\n",
"example = \"\"\"class BookNotAvailableError(Exception):\n",
" pass\n",
"\n",
"class Library:\n",
" def __init__(self):\n",
" self.inventory = {} # book title -> quantity\n",
" self.borrowed = {} # user -> list of borrowed book titles\n",
"\n",
" def add_book(self, title, quantity=1):\n",
" if quantity <= 0:\n",
" raise ValueError(\"Quantity must be positive\")\n",
" self.inventory[title] = self.inventory.get(title, 0) + quantity\n",
"\n",
" def borrow_book(self, user, title):\n",
" if self.inventory.get(title, 0) < 1:\n",
" raise BookNotAvailableError(f\"'{title}' is not available\")\n",
" self.inventory[title] -= 1\n",
" self.borrowed.setdefault(user, []).append(title)\n",
"\n",
" def return_book(self, user, title):\n",
" if user not in self.borrowed or title not in self.borrowed[user]:\n",
" raise ValueError(f\"User '{user}' did not borrow '{title}'\")\n",
" self.borrowed[user].remove(title)\n",
" self.inventory[title] = self.inventory.get(title, 0) + 1\n",
"\n",
" def get_available_books(self):\n",
" return {title: qty for title, qty in self.inventory.items() if qty > 0}\n",
"\n",
" def get_borrowed_books(self, user):\n",
" return self.borrowed.get(user, [])\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed6e624e-88a5-4f10-8ab5-f071f0ca3041",
"metadata": {},
"outputs": [],
"source": [
"# python example2\n",
"example2 = \"\"\"class Calculator:\n",
" def add(self, a, b):\n",
" return a + b\n",
"\n",
" def subtract(self, a, b):\n",
" return a - b\n",
"\n",
" def divide(self, a, b):\n",
" if b == 0:\n",
" raise ValueError(\"Cannot divide by zero\")\n",
" return a / b\n",
"\n",
" def multiply(self, a, b):\n",
" return a * b\n",
"\n",
"\n",
"def is_prime(n):\n",
" if n <= 1:\n",
" return False\n",
" if n <= 3:\n",
" return True\n",
" if n % 2 == 0 or n % 3 == 0:\n",
" return False\n",
" i = 5\n",
" while i * i <= n:\n",
" if n % i == 0 or n % (i + 2) == 0:\n",
" return False\n",
" i += 6\n",
" return True\n",
" \"\"\""
]
},
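{
"cell_type": "markdown",
"id": "1a2b3c4d-5001-4e5f-8a01-aaaaaaaa0009",
"metadata": {},
"source": [
"For reference, this is roughly the shape of output we expect from the models for `example2` — a hand-written sample (covering only `is_prime` for brevity), not actual model output: the original code, `TestCase` classes with commented tests, and a JupyterLab-safe execution statement."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-5002-4e5f-8a01-aaaaaaaa0010",
"metadata": {},
"outputs": [],
"source": [
"import unittest\n",
"\n",
"def is_prime(n):\n",
"    if n <= 1:\n",
"        return False\n",
"    if n <= 3:\n",
"        return True\n",
"    if n % 2 == 0 or n % 3 == 0:\n",
"        return False\n",
"    i = 5\n",
"    while i * i <= n:\n",
"        if n % i == 0 or n % (i + 2) == 0:\n",
"            return False\n",
"        i += 6\n",
"    return True\n",
"\n",
"class TestIsPrime(unittest.TestCase):\n",
"    def test_small_numbers(self):\n",
"        # 0 and 1 are not prime; 2 and 3 are\n",
"        self.assertFalse(is_prime(0))\n",
"        self.assertFalse(is_prime(1))\n",
"        self.assertTrue(is_prime(2))\n",
"        self.assertTrue(is_prime(3))\n",
"\n",
"    def test_composites(self):\n",
"        # Composite numbers must be rejected\n",
"        self.assertFalse(is_prime(9))\n",
"        self.assertFalse(is_prime(25))\n",
"\n",
"unittest.main(argv=[''], exit=False)"
]
},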
{
"cell_type": "code",
"execution_count": null,
"id": "e7d2fea8-74c6-4421-8f1e-0e76d5b201b9",
"metadata": {},
"outputs": [],
"source": [
"def unit_test_gpt(python): \n",
" stream = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages_for(python), stream=True)\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" fragment = chunk.choices[0].delta.content or \"\"\n",
" reply += fragment\n",
" yield reply"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cd84ad8-d55c-4fe0-9eeb-1895c95c4a9d",
"metadata": {},
"outputs": [],
"source": [
"def unit_test_claude(python):\n",
" result = claude.messages.stream(\n",
" model=CLAUDE_MODEL,\n",
" max_tokens=max_tok,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
" )\n",
" reply = \"\"\n",
" with result as stream:\n",
" for text in stream.text_stream:\n",
" reply += text\n",
" yield reply"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad86f652-879a-489f-9891-bdc2d97c33b0",
"metadata": {},
"outputs": [],
"source": [
"def unit_test_google(python):\n",
" model = google.generativeai.GenerativeModel(\n",
" model_name=GOOGLE_MODEL,\n",
" system_instruction=system_message\n",
" )\n",
" stream = model.generate_content(contents=user_prompt_for(python),stream=True)\n",
" reply = \"\"\n",
" for chunk in stream:\n",
" reply += chunk.text or \"\"\n",
" yield reply.replace(\"```python\\n\", \"\").replace(\"```\", \"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "105db6f9-343c-491d-8e44-3a5328b81719",
"metadata": {},
"outputs": [],
"source": [
"#unit_test_gpt(example)\n",
"#unit_test_claude(example)\n",
"#unit_test_google(example)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
"metadata": {},
"outputs": [],
"source": [
"def select_model(python, model):\n",
" if model==\"GPT\":\n",
" result = unit_test_gpt(python)\n",
" elif model==\"Claude\":\n",
" result = unit_test_claude(python)\n",
" elif model==\"Google\":\n",
" result = unit_test_google(python)\n",
" else:\n",
" raise ValueError(\"Unknown model\")\n",
" for stream_so_far in result:\n",
" yield stream_so_far "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1ddb38e-6b0a-4c37-baa4-ace0b7de887a",
"metadata": {},
"outputs": [],
"source": [
"# with gr.Blocks() as ui:\n",
"# with gr.Row():\n",
"# python = gr.Textbox(label=\"Python code:\", lines=10, value=example)\n",
"# test = gr.Textbox(label=\"Unit tests\", lines=10)\n",
"# with gr.Row():\n",
"# model = gr.Dropdown([\"GPT\", \"Claude\",\"Google\"], label=\"Select model\", value=\"GPT\")\n",
"# generate = gr.Button(\"Generate unit tests\")\n",
"\n",
"# generate.click(select_model, inputs=[python, model], outputs=[test])\n",
"\n",
"# ui.launch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "389ae411-a4f6-44f2-8b26-d46a971687a7",
"metadata": {},
"outputs": [],
"source": [
"def execute_python(code):\n",
" # Capture stdout and stderr\n",
" output = io.StringIO()\n",
" sys_stdout = sys.stdout\n",
" sys_stderr = sys.stderr\n",
" sys.stdout = output\n",
" sys.stderr = output\n",
"\n",
" try:\n",
" # Compile the code first\n",
" compiled_code = compile(code, '<string>', 'exec')\n",
"\n",
" # Prepare a namespace dict for exec environment\n",
" # Include __builtins__ so imports like 'import unittest' work\n",
" namespace = {\"__builtins__\": __builtins__}\n",
"\n",
" # Run the user's code, but expect tests will be defined here\n",
" exec(compiled_code, namespace)\n",
"\n",
" # Look for unittest.TestCase subclasses in the namespace\n",
" loader = unittest.TestLoader()\n",
" suite = unittest.TestSuite()\n",
"\n",
" for obj in namespace.values():\n",
" if isinstance(obj, type) and issubclass(obj, unittest.TestCase):\n",
" tests = loader.loadTestsFromTestCase(obj)\n",
" suite.addTests(tests)\n",
"\n",
" # Run the tests\n",
" runner = unittest.TextTestRunner(stream=output, verbosity=2)\n",
" result = runner.run(suite)\n",
"\n",
" except SystemExit as e:\n",
" # Catch sys.exit calls from unittest.main()\n",
" output.write(f\"\\nSystemExit called with code {e.code}\\n\")\n",
" except Exception as e:\n",
" # Catch other errors\n",
" output.write(f\"\\nException: {e}\\n\")\n",
" finally:\n",
" sys.stdout = sys_stdout\n",
" sys.stderr = sys_stderr\n",
"\n",
" return output.getvalue()"
]
},
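{
"cell_type": "markdown",
"id": "1a2b3c4d-6001-4e5f-8a01-aaaaaaaa0011",
"metadata": {},
"source": [
"A quick smoke test of `execute_python` with a minimal hand-written test case, independent of any model output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-6002-4e5f-8a01-aaaaaaaa0012",
"metadata": {},
"outputs": [],
"source": [
"smoke_test = \"\"\"import unittest\n",
"\n",
"class TestSmoke(unittest.TestCase):\n",
"    def test_passes(self):\n",
"        self.assertEqual(1 + 1, 2)\n",
"\"\"\"\n",
"\n",
"print(execute_python(smoke_test))"
]
},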
{
"cell_type": "code",
"execution_count": null,
"id": "eca98de3-9e2f-4c23-8bb4-dbb2787a15a4",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks() as ui:\n",
" with gr.Row():\n",
" python = gr.Textbox(label=\"Python code:\", lines=10, value=example2)\n",
" test = gr.Textbox(label=\"Unit tests\", lines=10)\n",
" test_run = gr.Textbox(label=\"Test results\", lines=10)\n",
" with gr.Row():\n",
" model = gr.Dropdown([\"GPT\", \"Claude\",\"Google\"], label=\"Select model\", value=\"GPT\")\n",
" generate = gr.Button(\"Generate unit tests\")\n",
" run = gr.Button(\"Run unit tests\")\n",
"\n",
" generate.click(select_model, inputs=[python, model], outputs=[test])\n",
" run.click(execute_python, inputs=[test],outputs=[test_run])\n",
"\n",
"ui.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,463 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2080947c-96d9-447f-8368-cfdc9e5c9960",
"metadata": {},
"source": [
"# Using Semantic chunks with Gemini API and Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53221f1a-a0c1-4506-a3d0-d6626c58e4e0",
"metadata": {},
"outputs": [],
"source": [
"# Regular Imports\n",
"import os\n",
"import glob\n",
"import time\n",
"from dotenv import load_dotenv\n",
"from tqdm.notebook import tqdm\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a2a7171-a7b6-42a6-96d7-c93f360689ec",
"metadata": {},
"outputs": [],
"source": [
"# Visual Import\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c9d658-65e5-40a1-8680-d0b561f87649",
"metadata": {},
"outputs": [],
"source": [
"# Lang Chain Imports\n",
"\n",
"from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI\n",
"from langchain_community.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
"from langchain_core.messages import HumanMessage, AIMessage\n",
"from langchain_chroma import Chroma\n",
"from langchain_experimental.text_splitter import SemanticChunker\n",
"from langchain_core.chat_history import InMemoryChatMessageHistory\n",
"from langchain_core.runnables.history import RunnableWithMessageHistory\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains.history_aware_retriever import create_history_aware_retriever\n",
"from langchain.chains import create_retrieval_chain\n",
"from langchain_core.prompts import MessagesPlaceholder\n",
"from langchain_core.runnables import RunnableLambda"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e7ed82b-b28a-4094-9f77-3b6432dd0f7a",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"CHAT_MODEL = \"gemini-2.5-flash\"\n",
"EMBEDDING_MODEL = \"models/text-embedding-004\"\n",
"# EMBEDDING_MODEL_EXP = \"models/gemini-embedding-exp-03-07\"\n",
"\n",
"folders = glob.glob(\"knowledge-base/*\")\n",
"text_loader_kwargs = {'encoding': 'utf-8'}\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b83281a2-bcae-41ab-a347-0e7f9688d1ed",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"\n",
"api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
"\n",
"if not api_key:\n",
" print(\"API Key not found!\")\n",
"else:\n",
" print(\"API Key loaded in memory\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fd6d516-772b-478d-9b28-09d42f2277d7",
"metadata": {},
"outputs": [],
"source": [
"def add_metadata(doc, doc_type):\n",
" doc.metadata[\"doc_type\"] = doc_type\n",
" return doc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bc4198b-f989-42c0-95b5-3596448fcaa2",
"metadata": {},
"outputs": [],
"source": [
"documents = []\n",
"for folder in tqdm(folders, desc=\"Loading folders\"):\n",
" doc_type = os.path.basename(folder)\n",
" loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n",
" folder_docs = loader.load()\n",
" documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n",
"\n",
"print(f\"Total documents loaded: {len(documents)}\")"
]
},
{
"cell_type": "markdown",
"id": "bb74241f-e9d5-42e8-9a4b-f31018397d66",
"metadata": {},
"source": [
"## Create Semantic Chunks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a3aa17f-f5d0-430a-80da-95c284bd99a8",
"metadata": {},
"outputs": [],
"source": [
"chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type=\"retrieval_document\")\n",
"\n",
"text_splitter = SemanticChunker(\n",
" chunking_embedding_model,\n",
" breakpoint_threshold_type=\"percentile\", \n",
" breakpoint_threshold_amount=95.0, \n",
" min_chunk_size=3 \n",
")\n",
"\n",
"start = time.time()\n",
"\n",
"semantic_chunks = []\n",
"pbar = tqdm(documents, desc=\"Semantic chunking documents\")\n",
"\n",
"for i, doc in enumerate(pbar):\n",
" doc_type = doc.metadata.get('doc_type', 'Unknown')\n",
" pbar.set_postfix_str(f\"Processing: {doc_type}\")\n",
" try:\n",
" doc_chunks = text_splitter.split_documents([doc])\n",
" semantic_chunks.extend(doc_chunks)\n",
" except Exception as e:\n",
" tqdm.write(f\"❌ Failed to split doc ({doc.metadata.get('source', 'unknown source')}): {e}\")\n",
"print(f\"⏱️ Took {time.time() - start:.2f} seconds\")\n",
"print(f\"Total semantic chunks: {len(semantic_chunks)}\")\n",
"\n",
"# import time\n",
"# start = time.time()\n",
"\n",
"# try:\n",
"# semantic_chunks = text_splitter.split_documents(documents)\n",
"# print(f\"✅ Chunking completed with {len(semantic_chunks)} chunks\")\n",
"# except Exception as e:\n",
"# print(f\"❌ Failed to split documents: {e}\")\n",
"\n",
"# print(f\"⏱️ Took {time.time() - start:.2f} seconds\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "675b98d6-5ed0-45d1-8f79-765911e6badf",
"metadata": {},
"outputs": [],
"source": [
"# Some Preview of the chunks\n",
"for i, doc in enumerate(semantic_chunks[:15]):\n",
" print(f\"--- Chunk {i+1} ---\")\n",
" print(doc.page_content) \n",
" print(\"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "c17accff-539a-490b-8a5f-b5ce632a3c71",
"metadata": {},
"source": [
"## Embed with Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bd228bd-37d2-4aaf-b0f6-d94943f6f248",
"metadata": {},
"outputs": [],
"source": [
"embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL,task_type=\"retrieval_document\")\n",
"\n",
"if os.path.exists(db_name):\n",
" Chroma(persist_directory=db_name, embedding_function=embedding).delete_collection()\n",
"\n",
"vectorstore = Chroma.from_documents(\n",
" documents=semantic_chunks,\n",
" embedding=embedding,\n",
" persist_directory=db_name\n",
")\n",
"\n",
"print(f\"✅ Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
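{
"cell_type": "markdown",
"id": "1a2b3c4d-7001-4e5f-8a01-aaaaaaaa0013",
"metadata": {},
"source": [
"A quick retrieval sanity check: `similarity_search` returns the chunks nearest to a query, which makes it easy to eyeball whether the semantic chunking produced sensible units. The query below is just an illustrative placeholder for this knowledge base."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-7002-4e5f-8a01-aaaaaaaa0014",
"metadata": {},
"outputs": [],
"source": [
"# Retrieve the 3 nearest chunks for a sample query and preview them\n",
"hits = vectorstore.similarity_search(\"Who founded Insurellm?\", k=3)\n",
"for i, hit in enumerate(hits, 1):\n",
"    print(f\"--- Hit {i} ({hit.metadata.get('doc_type', 'unknown')}) ---\")\n",
"    print(hit.page_content[:200], \"\\n\")"
]
},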
{
"cell_type": "markdown",
"id": "ce0a3e23-5912-4de2-bf34-3c0936375de1",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Visualzing Vectors"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ffdc6f5-ec25-4229-94d4-1fc6bb4d2702",
"metadata": {},
"outputs": [],
"source": [
"collection = vectorstore._collection\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"metadatas = result['metadatas']\n",
"doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
"colors = [['blue', 'green', 'red', 'orange'][['products', 'employees', 'contracts', 'company'].index(t)] for t in doc_types]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5428164b-f0d5-4d2b-ac4a-514c43ceaa79",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visalize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
" x=reduced_vectors[:, 0],\n",
" y=reduced_vectors[:, 1],\n",
" mode='markers',\n",
" marker=dict(size=5, color=colors, opacity=0.8),\n",
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
" hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
" title='2D Chroma Vector Store Visualization',\n",
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
" width=800,\n",
" height=600,\n",
" margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "359b8651-a382-4050-8bf8-123e5cdf4d53",
"metadata": {},
"source": [
"## RAG Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08a75313-6c68-42e5-bd37-78254123094c",
"metadata": {},
"outputs": [],
"source": [
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 20 })\n",
"\n",
"# Conversation Memory\n",
"# memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"\n",
"chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)\n",
"\n",
"question_generator_template = \"\"\"Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.\n",
"If the follow up question is already a standalone question, return it as is.\n",
"\n",
"Chat History:\n",
"{chat_history}\n",
"Follow Up Input: {input} \n",
"Standalone question:\"\"\"\n",
"\n",
"question_generator_prompt = ChatPromptTemplate.from_messages([\n",
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"history_aware_retriever = create_history_aware_retriever(\n",
" chat_llm, retriever, question_generator_prompt\n",
")\n",
"\n",
"qa_system_prompt = \"\"\"You are Insurellms intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.\n",
"Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers.Rephrase source facts in a natural tone, not word-for-word.\n",
"When referencing people or company history, prioritize clarity and correctness.\n",
"Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.\n",
"If you truly dont have the answer, respond with:\n",
"\"I don't have that information.\"\n",
"Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.\n",
"\n",
"{context}\"\"\"\n",
"\n",
"\n",
"qa_human_prompt = \"{input}\" \n",
"\n",
"qa_prompt = ChatPromptTemplate.from_messages([\n",
" SystemMessagePromptTemplate.from_template(qa_system_prompt),\n",
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])\n",
"\n",
"combine_docs_chain = create_stuff_documents_chain(chat_llm, qa_prompt)\n",
"\n",
"# inspect_context = RunnableLambda(lambda inputs: (\n",
"# print(\"\\n Retrieved Context:\\n\", \"\\n---\\n\".join([doc.page_content for doc in inputs[\"context\"]])),\n",
"# inputs # pass it through unchanged\n",
"# )[1])\n",
"\n",
"# inspect_inputs = RunnableLambda(lambda inputs: (\n",
"# print(\"\\n Inputs received by the chain:\\n\", inputs),\n",
"# inputs\n",
"# )[1])\n",
"\n",
"base_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)\n",
"\n",
"# Using Runnable Lambda as Gradio needs the response to contain only the output (answer) and base_chain would have a dict with input, context, chat_history, answer\n",
"\n",
"# base_chain_with_output = base_chain | inspect_context | RunnableLambda(lambda res: res[\"answer\"])\n",
"# base_chain_with_output = base_chain | RunnableLambda(lambda res: res[\"answer\"])\n",
"\n",
"\n",
"# Session Persistent Chat History \n",
"# If we want to persist history between sessions then use MongoDB (or any non sql DB)to store and use MongoDBChatMessageHistory (relevant DB Wrapper)\n",
"\n",
"chat_histories = {}\n",
"\n",
"def get_history(session_id):\n",
" if session_id not in chat_histories:\n",
" chat_histories[session_id] = InMemoryChatMessageHistory()\n",
" return chat_histories[session_id]\n",
"\n",
"# Currently set to streaming ...if one shot response is needed then comment base_chain and output_message_key and enable base_chain_with_output\n",
"conversation_chain = RunnableWithMessageHistory(\n",
" # base_chain_with_output,\n",
" base_chain,\n",
" get_history,\n",
" output_messages_key=\"answer\", \n",
" input_messages_key=\"input\",\n",
" history_messages_key=\"chat_history\",\n",
")"
]
},
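{
"cell_type": "markdown",
"id": "1a2b3c4d-8001-4e5f-8a01-aaaaaaaa0015",
"metadata": {},
"source": [
"A quick one-shot check of the chain before handing it to Gradio — as noted in the comments above, `invoke` returns a dict containing `input`, `chat_history`, `context`, and `answer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-8002-4e5f-8a01-aaaaaaaa0016",
"metadata": {},
"outputs": [],
"source": [
"result = conversation_chain.invoke(\n",
"    {\"input\": \"What does Insurellm do?\"},\n",
"    config={\"configurable\": {\"session_id\": \"sanity-check\"}}\n",
")\n",
"print(result[\"answer\"])"
]
},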
{
"cell_type": "code",
"execution_count": null,
"id": "06b58566-70cb-42eb-8b1c-9fe353fe71f0",
"metadata": {},
"outputs": [],
"source": [
"def chat(question, history):\n",
" try:\n",
" # result = conversation_chain.invoke({\"input\": question, \"chat_history\": memory.buffer_as_messages})\n",
" \n",
" # memory.chat_memory.add_user_message(question)\n",
" # memory.chat_memory.add_ai_message(result[\"answer\"])\n",
"\n",
" # return result[\"answer\"]\n",
"\n",
" \n",
" session_id = \"default-session\"\n",
"\n",
" # # FUll chat version\n",
" # result = conversation_chain.invoke(\n",
" # {\"input\": question},\n",
" # config={\"configurable\": {\"session_id\": session_id}}\n",
" # )\n",
" # # print(result)\n",
" # return result\n",
"\n",
" # Streaming Version\n",
" response_buffer = \"\"\n",
"\n",
" for chunk in conversation_chain.stream({\"input\": question},config={\"configurable\": {\"session_id\": session_id}}):\n",
" if \"answer\" in chunk:\n",
" response_buffer += chunk[\"answer\"]\n",
" yield response_buffer \n",
" except Exception as e:\n",
" print(f\"An error occurred during chat: {e}\")\n",
" return \"I apologize, but I encountered an error and cannot answer that right now.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a577ac66-3952-4821-83d2-8a50bad89971",
"metadata": {},
"outputs": [],
"source": [
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56b63a17-2522-46e5-b5a3-e2e80e52a723",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,552 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "61777022-631c-4db0-afeb-70d8d22bc07b",
"metadata": {},
"source": [
"Summary:\n",
"This is the project from week 5. The intention was to create a vector db of my own files (from an external drive) which can be used in a RAG solution.\n",
"This includes a number of file types (docx, pdf, txt, epub...) and includes the ability to exclude folders.\n",
"With the OpenAI embeddings API limit of 300k tokens, it was also necessary to create a batch embeddings process so that there were multiple requests.\n",
"This was based on estimating the tokens with a text to token rate of 1:4, however it wasn't perfect and one of the batches still exceeded the 300k limit when running.\n",
"I found that the responses from the llm were terrible in the end! I tried playing about with chunk sizes and the minimum # of chunks by llangchain and it did improve but was not fantastic. I also ensured the metadata was sent with each chunk to help.\n",
"This really highlighted the real world challenges of implementing RAG!"
]
},
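{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-9002-4e5f-8a01-aaaaaaaa0017",
"metadata": {},
"outputs": [],
"source": [
"# Minimal standalone sketch of the batching approach described above.\n",
"# Token counts are estimated at ~4 characters per token, so the cap is\n",
"# deliberately set below the 300k API limit to leave headroom, since the\n",
"# estimate is approximate.\n",
"def batch_by_estimated_tokens(texts, max_tokens=250_000):\n",
"    batches, current, current_tokens = [], [], 0\n",
"    for text in texts:\n",
"        est = len(text) // 4 + 1  # rough chars-to-tokens estimate\n",
"        if current and current_tokens + est > max_tokens:\n",
"            batches.append(current)\n",
"            current, current_tokens = [], 0\n",
"        current.append(text)\n",
"        current_tokens += est\n",
"    if current:\n",
"        batches.append(current)\n",
"    return batches"
]
},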
{
"cell_type": "code",
"execution_count": null,
"id": "d78ef79d-e564-4c56-82f3-0485e4bf6986",
"metadata": {},
"outputs": [],
"source": [
"!pip install docx2txt\n",
"!pip install ebooklib\n",
"!pip install python-pptx\n",
"!pip install pypdf"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ec98119-456f-450c-a9a2-f375d74f5ce5",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import requests\n",
"from dotenv import load_dotenv\n",
"import glob\n",
"import gradio as gr\n",
"import time\n",
"from typing import List"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac14410b-8c3c-4cf5-900e-fd4c33cdf2b2",
"metadata": {},
"outputs": [],
"source": [
"# imports for langchain, plotly and Chroma\n",
"\n",
"from langchain.document_loaders import (\n",
" DirectoryLoader,\n",
" Docx2txtLoader,\n",
" TextLoader,\n",
" PyPDFLoader,\n",
" UnstructuredExcelLoader,\n",
" BSHTMLLoader\n",
")\n",
"from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3be698e7-71e1-4c75-9696-e1651e4bf357",
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"gpt-4o-mini\"\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f850068-c05b-4526-9494-034b0077347e",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c5baad2-2033-40a6-8ebd-5861b5cf4350",
"metadata": {},
"outputs": [],
"source": [
"# handling epubs\n",
"\n",
"from ebooklib import epub\n",
"from bs4 import BeautifulSoup\n",
"from langchain.document_loaders.base import BaseLoader\n",
"\n",
"class EpubLoader(BaseLoader):\n",
" def __init__(self, file_path: str):\n",
" self.file_path = file_path\n",
"\n",
" def load(self) -> list[Document]:\n",
" book = epub.read_epub(self.file_path)\n",
" text = ''\n",
" for item in book.get_items():\n",
" if item.get_type() == epub.EpubHtml:\n",
" soup = BeautifulSoup(item.get_content(), 'html.parser')\n",
" text += soup.get_text() + '\\n'\n",
"\n",
" return [Document(page_content=text, metadata={\"source\": self.file_path})]"
]
},
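{
"cell_type": "markdown",
"id": "1a2b3c4d-9101-4e5f-8a01-aaaaaaaa0018",
"metadata": {},
"source": [
"A quick standalone check of `EpubLoader` — the path below is a hypothetical placeholder; point it at any real .epub on the drive:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-9102-4e5f-8a01-aaaaaaaa0019",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"epub_path = \"D:/Books/sample.epub\"  # hypothetical path - replace with a real .epub\n",
"if os.path.exists(epub_path):\n",
"    docs = EpubLoader(epub_path).load()\n",
"    print(docs[0].page_content[:300])\n",
"else:\n",
"    print(\"No epub found at\", epub_path)"
]
},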
{
"cell_type": "code",
"execution_count": null,
"id": "bd8b0e4e-d698-4484-bc94-d8b753f386cc",
"metadata": {},
"outputs": [],
"source": [
"# handling pptx\n",
"\n",
"from pptx import Presentation\n",
"\n",
"class PptxLoader(BaseLoader):\n",
" def __init__(self, file_path: str):\n",
" self.file_path = file_path\n",
"\n",
" def load(self) -> list[Document]:\n",
" prs = Presentation(self.file_path)\n",
" text = ''\n",
" for slide in prs.slides:\n",
" for shape in slide.shapes:\n",
" if hasattr(shape, \"text\") and shape.text:\n",
" text += shape.text + '\\n'\n",
"\n",
" return [Document(page_content=text, metadata={\"source\": self.file_path})]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b222b01d-6040-4ff3-a0e3-290819cfe94b",
"metadata": {},
"outputs": [],
"source": [
"# Class based version of document loader which can be expanded more easily for other document types. (Currently includes file types: docx, txt (windows encoding), xlsx, pdfs, epubs, pptx)\n",
"\n",
"class DocumentLoader:\n",
" \"\"\"A clean, extensible document loader for multiple file types.\"\"\"\n",
" \n",
" def __init__(self, base_path=\"D:/*\", exclude_folders=None):\n",
" self.base_path = base_path\n",
" self.documents = []\n",
" self.exclude_folders = exclude_folders or []\n",
" \n",
" # Configuration for different file types\n",
" self.loader_config = {\n",
" 'docx': {\n",
" 'loader_cls': Docx2txtLoader,\n",
" 'glob_pattern': \"**/*.docx\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'txt': {\n",
" 'loader_cls': TextLoader,\n",
" 'glob_pattern': \"**/*.txt\",\n",
" 'loader_kwargs': {\"encoding\": \"cp1252\"},\n",
" 'post_process': None\n",
" },\n",
" 'pdf': {\n",
" 'loader_cls': PyPDFLoader,\n",
" 'glob_pattern': \"**/*.pdf\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'xlsx': {\n",
" 'loader_cls': UnstructuredExcelLoader,\n",
" 'glob_pattern': \"**/*.xlsx\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'html': {\n",
" 'loader_cls': BSHTMLLoader,\n",
" 'glob_pattern': \"**/*.html\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" },\n",
" 'epub': {\n",
" 'loader_cls': EpubLoader,\n",
" 'glob_pattern': \"**/*.epub\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': self._process_epub_metadata\n",
" },\n",
" 'pptx': {\n",
" 'loader_cls': PptxLoader,\n",
" 'glob_pattern': \"**/*.pptx\",\n",
" 'loader_kwargs': {},\n",
" 'post_process': None\n",
" }\n",
" }\n",
" \n",
" def _get_epub_metadata(self, file_path):\n",
" \"\"\"Extract metadata from EPUB files.\"\"\"\n",
" try:\n",
" book = epub.read_epub(file_path)\n",
" title = book.get_metadata('DC', 'title')[0][0] if book.get_metadata('DC', 'title') else None\n",
" author = book.get_metadata('DC', 'creator')[0][0] if book.get_metadata('DC', 'creator') else None\n",
" return title, author\n",
" except Exception as e:\n",
" print(f\"Error extracting EPUB metadata: {e}\")\n",
" return None, None\n",
" \n",
" def _process_epub_metadata(self, doc) -> None:\n",
" \"\"\"Post-process EPUB documents to add metadata.\"\"\"\n",
" title, author = self._get_epub_metadata(doc.metadata['source'])\n",
" doc.metadata[\"author\"] = author\n",
" doc.metadata[\"title\"] = title\n",
" \n",
" def _load_file_type(self, folder, file_type, config):\n",
" \"\"\"Load documents of a specific file type from a folder.\"\"\"\n",
" try:\n",
" loader = DirectoryLoader(\n",
" folder, \n",
" glob=config['glob_pattern'], \n",
" loader_cls=config['loader_cls'],\n",
" loader_kwargs=config['loader_kwargs']\n",
" )\n",
" docs = loader.load()\n",
" print(f\" Found {len(docs)} .{file_type} files\")\n",
" \n",
" # Apply post-processing if defined\n",
" if config['post_process']:\n",
" for doc in docs:\n",
" config['post_process'](doc)\n",
" \n",
" return docs\n",
" \n",
" except Exception as e:\n",
" print(f\" Error loading .{file_type} files: {e}\")\n",
" return []\n",
" \n",
" def load_all(self):\n",
" \"\"\"Load all documents from configured folders.\"\"\"\n",
" all_folders = [f for f in glob.glob(self.base_path) if os.path.isdir(f)]\n",
"\n",
" #filter out excluded folders\n",
" folders = []\n",
" for folder in all_folders:\n",
" folder_name = os.path.basename(folder)\n",
" if folder_name not in self.exclude_folders:\n",
" folders.append(folder)\n",
" else:\n",
" print(f\"Excluded folder: {folder_name}\")\n",
" \n",
" print(\"Scanning folders (directories only):\", folders)\n",
" \n",
" self.documents = []\n",
" \n",
" for folder in folders:\n",
" doc_type = os.path.basename(folder)\n",
" print(f\"\\nProcessing folder: {doc_type}\")\n",
" \n",
" for file_type, config in self.loader_config.items():\n",
" docs = self._load_file_type(folder, file_type, config)\n",
" \n",
" # Add doc_type metadata to all documents\n",
" for doc in docs:\n",
" doc.metadata[\"doc_type\"] = doc_type\n",
" self.documents.append(doc)\n",
" \n",
" print(f\"\\nTotal documents loaded: {len(self.documents)}\")\n",
" return self.documents\n",
" \n",
" def add_file_type(self, extension, loader_cls, glob_pattern=None, \n",
" loader_kwargs=None, post_process=None):\n",
" \"\"\"Add support for a new file type.\"\"\"\n",
" self.loader_config[extension] = {\n",
" 'loader_cls': loader_cls,\n",
" 'glob_pattern': glob_pattern or f\"**/*.{extension}\",\n",
" 'loader_kwargs': loader_kwargs or {},\n",
" 'post_process': post_process\n",
" }\n",
"\n",
"# load\n",
"loader = DocumentLoader(\"D:/*\", exclude_folders=[\"Music\", \"Online Courses\", \"Fitness\"])\n",
"documents = loader.load_all()"
]
},
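  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f6a7b8-add-file-type-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of extending the loader via add_file_type.\n",
    "# A minimal sketch, assuming you also keep .md notes on the drive (a hypothetical\n",
    "# extension); it reuses the existing TextLoader, so no new loader class is needed.\n",
    "\n",
    "# loader.add_file_type(\n",
    "#     extension=\"md\",\n",
    "#     loader_cls=TextLoader,\n",
    "#     loader_kwargs={\"encoding\": \"utf-8\"},\n",
    "# )\n",
    "# documents = loader.load_all()"
   ]
  },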
{
"cell_type": "code",
"execution_count": null,
"id": "3fd43a4f-b623-4b08-89eb-27d3b3ba0f62",
"metadata": {},
"outputs": [],
"source": [
"# create batches (this was required as the # of tokens was exceed the openai request limit)\n",
"\n",
"def estimate_tokens(text, chars_per_token=4):\n",
" \"\"\"Rough estimate of tokens from character count.\"\"\"\n",
" return len(text) // chars_per_token\n",
"\n",
"def create_batches(chunks, max_tokens_per_batch=250000):\n",
" batches = []\n",
" current_batch = []\n",
" current_tokens = 0\n",
" \n",
" for chunk in chunks:\n",
" chunk_tokens = estimate_tokens(chunk.page_content)\n",
" \n",
" # If adding this chunk would exceed the limit, start a new batch\n",
" if current_tokens + chunk_tokens > max_tokens_per_batch and current_batch:\n",
" batches.append(current_batch)\n",
" current_batch = [chunk]\n",
" current_tokens = chunk_tokens\n",
" else:\n",
" current_batch.append(chunk)\n",
" current_tokens += chunk_tokens\n",
" \n",
" # Add the last batch if it has content\n",
" if current_batch:\n",
" batches.append(current_batch)\n",
" \n",
" return batches\n",
"\n",
"def create_vectorstore_with_progress(chunks, embeddings, db_name, batch_size_tokens=250000):\n",
" \n",
" # Delete existing database if it exists\n",
" if os.path.exists(db_name):\n",
" print(f\"Deleting existing database: {db_name}\")\n",
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
" \n",
" # Create batches\n",
" batches = create_batches(chunks, batch_size_tokens)\n",
" print(f\"Created {len(batches)} batches from {len(chunks)} chunks\")\n",
" \n",
" # Show batch sizes\n",
" for i, batch in enumerate(batches):\n",
" total_chars = sum(len(chunk.page_content) for chunk in batch)\n",
" estimated_tokens = estimate_tokens(''.join(chunk.page_content for chunk in batch))\n",
" print(f\" Batch {i+1}: {len(batch)} chunks, ~{estimated_tokens:,} tokens\")\n",
" \n",
" vectorstore = None\n",
" successful_batches = 0\n",
" failed_batches = 0\n",
" \n",
" for i, batch in enumerate(batches):\n",
" print(f\"\\n{'='*50}\")\n",
" print(f\"Processing batch {i+1}/{len(batches)}\")\n",
" print(f\"{'='*50}\")\n",
" \n",
" try:\n",
" start_time = time.time()\n",
" \n",
" if vectorstore is None:\n",
" # Create the initial vectorstore\n",
" vectorstore = Chroma.from_documents(\n",
" documents=batch,\n",
" embedding=embeddings,\n",
" persist_directory=db_name\n",
" )\n",
" print(f\"Created initial vectorstore with {len(batch)} documents\")\n",
" else:\n",
" # Add to existing vectorstore\n",
" vectorstore.add_documents(batch)\n",
" print(f\"Added {len(batch)} documents to vectorstore\")\n",
" \n",
" successful_batches += 1\n",
" elapsed = time.time() - start_time\n",
" print(f\"Processed in {elapsed:.1f} seconds\")\n",
" print(f\"Total documents in vectorstore: {vectorstore._collection.count()}\")\n",
" \n",
" # Rate limiting delay\n",
" time.sleep(2)\n",
" \n",
" except Exception as e:\n",
" failed_batches += 1\n",
" print(f\"Error processing batch {i+1}: {e}\")\n",
" print(f\"Continuing with next batch...\")\n",
" continue\n",
" \n",
" print(f\"\\n{'='*50}\")\n",
" print(f\"SUMMARY\")\n",
" print(f\"{'='*50}\")\n",
" print(f\"Successful batches: {successful_batches}/{len(batches)}\")\n",
" print(f\"Failed batches: {failed_batches}/{len(batches)}\")\n",
" \n",
" if vectorstore:\n",
" final_count = vectorstore._collection.count()\n",
" print(f\"Final vectorstore contains: {final_count} documents\")\n",
" return vectorstore\n",
" else:\n",
" print(\"Failed to create vectorstore\")\n",
" return None\n",
"\n",
"# include metadata\n",
"def add_metadata_to_content(doc: Document) -> Document:\n",
" metadata_lines = []\n",
" if \"doc_type\" in doc.metadata:\n",
" metadata_lines.append(f\"Document Type: {doc.metadata['doc_type']}\")\n",
" if \"title\" in doc.metadata:\n",
" metadata_lines.append(f\"Title: {doc.metadata['title']}\")\n",
" if \"author\" in doc.metadata:\n",
" metadata_lines.append(f\"Author: {doc.metadata['author']}\")\n",
" metadata_text = \"\\n\".join(metadata_lines)\n",
"\n",
" new_content = f\"{metadata_text}\\n\\n{doc.page_content}\"\n",
" return Document(page_content=new_content, metadata=doc.metadata)\n",
"\n",
"# Apply to all documents before chunking\n",
"documents_with_metadata = [add_metadata_to_content(doc) for doc in documents]\n",
"\n",
"# Chunking\n",
"text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=200)\n",
"chunks = text_splitter.split_documents(documents_with_metadata)\n",
"\n",
"# Embedding\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# Store in vector DB\n",
"print(\"Creating vectorstore in batches...\")\n",
"vectorstore = create_vectorstore_with_progress(\n",
" chunks=chunks,\n",
" embeddings=embeddings, \n",
" db_name=db_name,\n",
" batch_size_tokens=250000\n",
")\n",
"\n",
"if vectorstore:\n",
" print(f\"Successfully created vectorstore with {vectorstore._collection.count()} documents\")\n",
"else:\n",
" print(\"Failed to create vectorstore\")"
]
},
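  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b0c1d2-batching-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check that the batching logic behaves as expected.\n",
    "# A minimal sketch on toy documents, assuming the default 4-chars-per-token\n",
    "# estimate: five ~1,000-token chunks with a 2,000-token cap should give [2, 2, 1].\n",
    "\n",
    "toy_chunks = [Document(page_content=\"x\" * 4000, metadata={}) for _ in range(5)]\n",
    "toy_batches = create_batches(toy_chunks, max_tokens_per_batch=2000)\n",
    "print(f\"{len(toy_batches)} batches:\", [len(b) for b in toy_batches])"
   ]
  },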
{
"cell_type": "code",
"execution_count": null,
"id": "46c29b11-2ae3-4f6b-901d-5de67a09fd49",
"metadata": {},
"outputs": [],
"source": [
"# create a new Chat with OpenAI\n",
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# set up the conversation memory for the chat\n",
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"\n",
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 200})\n",
"\n",
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
]
},
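  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3e4f5a6-retriever-spot-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: spot-check what the retriever would feed the LLM for a query.\n",
    "# A minimal sketch -- the query string is just an example; swap in a topic\n",
    "# you know exists in your documents, then uncomment and run.\n",
    "\n",
    "# hits = vectorstore.similarity_search(\"machine learning\", k=3)\n",
    "# for hit in hits:\n",
    "#     print(hit.metadata.get(\"doc_type\"), \"|\", hit.page_content[:120])"
   ]
  },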
{
"cell_type": "code",
"execution_count": null,
"id": "be163251-0dfa-4f50-ab05-43c6c0833405",
"metadata": {},
"outputs": [],
"source": [
"# Wrapping that in a function\n",
"\n",
"def chat(question, history):\n",
" result = conversation_chain.invoke({\"question\": question})\n",
" return result[\"answer\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6320402-8213-47ec-8b05-dda234052274",
"metadata": {},
"outputs": [],
"source": [
"# And in Gradio:\n",
"\n",
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "717e010b-8d7e-4a43-8cb1-9688ffdd76b6",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate what gets sent behind the scenes\n",
"\n",
"# from langchain_core.callbacks import StdOutCallbackHandler\n",
"\n",
"# llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"\n",
"# retriever = vectorstore.as_retriever(search_kwargs={\"k\": 200})\n",
"\n",
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory, callbacks=[StdOutCallbackHandler()])\n",
"\n",
"# query = \"Can you name some authors?\"\n",
"# result = conversation_chain.invoke({\"question\": query})\n",
"# answer = result[\"answer\"]\n",
"# print(\"\\nAnswer:\", answer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2333a77e-8d32-4cc2-8ae9-f8e7a979b3ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,472 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "dfe37963-1af6-44fc-a841-8e462443f5e6",
"metadata": {},
"source": [
"## gmail RAG assistant"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba2779af-84ef-4227-9e9e-6eaf0df87e77",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import glob\n",
"from dotenv import load_dotenv\n",
"import gradio as gr\n",
"# NEW IMPORTS FOR GMAIL\n",
"from google.auth.transport.requests import Request\n",
"from google.oauth2.credentials import Credentials\n",
"from google_auth_oauthlib.flow import InstalledAppFlow\n",
"from googleapiclient.discovery import build\n",
"from datetime import datetime\n",
"import base64\n",
"from email.mime.text import MIMEText\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "802137aa-8a74-45e0-a487-d1974927d7ca",
"metadata": {},
"outputs": [],
"source": [
"# imports for langchain, plotly and Chroma\n",
"\n",
"from langchain.document_loaders import DirectoryLoader, TextLoader\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.schema import Document\n",
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
"from langchain_chroma import Chroma\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58c85082-e417-4708-9efe-81a5d55d1424",
"metadata": {},
"outputs": [],
"source": [
"# price is a factor for our company, so we're going to use a low cost model\n",
"\n",
"MODEL = \"gpt-4o-mini\"\n",
"db_name = \"vector_db\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee78efcb-60fe-449e-a944-40bab26261af",
"metadata": {},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"# NEW: Gmail API credentials\n",
"SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']\n",
"CREDENTIALS_FILE = 'credentials.json' # Download from Google Cloud Console\n",
"TOKEN_FILE = 'token.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "730711a9-6ffe-4eee-8f48-d6cfb7314905",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Read in emails using LangChain's loaders\n",
"# IMPORTANT: set the email received date range hard-coded below\n",
"\n",
"def authenticate_gmail():\n",
" \"\"\"Authenticate and return Gmail service object\"\"\"\n",
" creds = None\n",
" if os.path.exists(TOKEN_FILE):\n",
" creds = Credentials.from_authorized_user_file(TOKEN_FILE, SCOPES)\n",
" \n",
" if not creds or not creds.valid:\n",
" if creds and creds.expired and creds.refresh_token:\n",
" creds.refresh(Request())\n",
" else:\n",
" flow = InstalledAppFlow.from_client_secrets_file(CREDENTIALS_FILE, SCOPES)\n",
" creds = flow.run_local_server(port=0)\n",
" \n",
" with open(TOKEN_FILE, 'w') as token:\n",
" token.write(creds.to_json())\n",
" \n",
" return build('gmail', 'v1', credentials=creds)\n",
"\n",
"def get_email_content(service, message_id):\n",
" \"\"\"Extract email content from message\"\"\"\n",
" try:\n",
" message = service.users().messages().get(userId='me', id=message_id, format='full').execute()\n",
" \n",
" # Extract basic info\n",
" headers = message['payload'].get('headers', [])\n",
" subject = next((h['value'] for h in headers if h['name'] == 'Subject'), 'No Subject')\n",
" sender = next((h['value'] for h in headers if h['name'] == 'From'), 'Unknown Sender')\n",
" date = next((h['value'] for h in headers if h['name'] == 'Date'), 'Unknown Date')\n",
" \n",
" # Extract body\n",
" body = \"\"\n",
" if 'parts' in message['payload']:\n",
" for part in message['payload']['parts']:\n",
" if part['mimeType'] == 'text/plain':\n",
" data = part['body']['data']\n",
" body = base64.urlsafe_b64decode(data).decode('utf-8')\n",
" break\n",
" else:\n",
" if message['payload']['body'].get('data'):\n",
" body = base64.urlsafe_b64decode(message['payload']['body']['data']).decode('utf-8')\n",
" \n",
" # Clean up body text\n",
" body = re.sub(r'\\s+', ' ', body).strip()\n",
" \n",
" return {\n",
" 'subject': subject,\n",
" 'sender': sender,\n",
" 'date': date,\n",
" 'body': body,\n",
" 'id': message_id\n",
" }\n",
" except Exception as e:\n",
" print(f\"Error processing message {message_id}: {str(e)}\")\n",
" return None\n",
"\n",
"def load_gmail_documents(start_date, end_date, max_emails=100):\n",
" \"\"\"Load emails from Gmail between specified dates\"\"\"\n",
" service = authenticate_gmail()\n",
" \n",
" # Format dates for Gmail API (YYYY/MM/DD)\n",
" start_date_str = start_date.strftime('%Y/%m/%d')\n",
" end_date_str = end_date.strftime('%Y/%m/%d')\n",
" \n",
" # Build query\n",
" query = f'after:{start_date_str} before:{end_date_str}'\n",
" \n",
" # Get message list\n",
" result = service.users().messages().list(userId='me', q=query, maxResults=max_emails).execute()\n",
" messages = result.get('messages', [])\n",
" \n",
" print(f\"Found {len(messages)} emails between {start_date_str} and {end_date_str}\")\n",
" \n",
" # Convert to LangChain documents\n",
" documents = []\n",
" for i, message in enumerate(messages):\n",
" print(f\"Processing email {i+1}/{len(messages)}\")\n",
" email_data = get_email_content(service, message['id'])\n",
" \n",
" if email_data and email_data['body']:\n",
" # Create document content\n",
" content = f\"\"\"Subject: {email_data['subject']}\n",
"From: {email_data['sender']}\n",
"Date: {email_data['date']}\n",
"\n",
"{email_data['body']}\"\"\"\n",
" \n",
" # Create LangChain document\n",
" doc = Document(\n",
" page_content=content,\n",
" metadata={\n",
" \"doc_type\": \"email\",\n",
" \"subject\": email_data['subject'],\n",
" \"sender\": email_data['sender'],\n",
" \"date\": email_data['date'],\n",
" \"message_id\": email_data['id']\n",
" }\n",
" )\n",
" documents.append(doc)\n",
" \n",
" return documents\n",
"\n",
"# SET YOUR DATE RANGE HERE\n",
"start_date = datetime(2025, 6, 20) # YYYY, MM, DD\n",
"end_date = datetime(2025, 6, 26) # YYYY, MM, DD\n",
"\n",
"# Load Gmail documents \n",
"documents = load_gmail_documents(start_date, end_date, max_emails=200)\n"
]
},
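  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c8d9e0-gmail-query-operators",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The Gmail API's q parameter accepts the same operators as the Gmail search box\n",
    "# (from:, subject:, has:attachment, ...), so the date query built inside\n",
    "# load_gmail_documents can be narrowed further. A hypothetical sketch -- the\n",
    "# sender address below is made up; e.g. replace the query line with:\n",
    "#\n",
    "# query = f'after:{start_date_str} before:{end_date_str} from:newsletter@example.com has:attachment'"
   ]
  },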
{
"cell_type": "code",
"execution_count": null,
"id": "c59de72d-f965-44b3-8487-283e4c623b1d",
"metadata": {},
"outputs": [],
"source": [
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
"chunks = text_splitter.split_documents(documents)\n",
"\n",
"print(f\"Total number of chunks: {len(chunks)}\")\n",
"print(f\"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78998399-ac17-4e28-b15f-0b5f51e6ee23",
"metadata": {},
"outputs": [],
"source": [
"# Put the chunks of data into a Vector Store that associates a Vector Embedding with each chunk\n",
"# Chroma is a popular open source Vector Database based on SQLLite\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# If you would rather use the free Vector Embeddings from HuggingFace sentence-transformers\n",
"# Then replace embeddings = OpenAIEmbeddings()\n",
"# with:\n",
"# from langchain.embeddings import HuggingFaceEmbeddings\n",
"# embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"\n",
"# Delete if already exists\n",
"\n",
"if os.path.exists(db_name):\n",
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
"\n",
"# Create vectorstore\n",
"\n",
"vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n",
"print(f\"Vectorstore created with {vectorstore._collection.count()} documents\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff2e7687-60d4-4920-a1d7-a34b9f70a250",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate the vectors\n",
"\n",
"collection = vectorstore._collection\n",
"count = collection.count()\n",
"\n",
"sample_embedding = collection.get(limit=1, include=[\"embeddings\"])[\"embeddings\"][0]\n",
"dimensions = len(sample_embedding)\n",
"print(f\"There are {count:,} vectors with {dimensions:,} dimensions in the vector store\")"
]
},
{
"cell_type": "markdown",
"id": "b0d45462-a818-441c-b010-b85b32bcf618",
"metadata": {},
"source": [
"## Visualizing the Vector Store\n",
"\n",
"Let's take a minute to look at the documents and their embedding vectors to see what's going on."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b98adf5e-d464-4bd2-9bdf-bc5b6770263b",
"metadata": {},
"outputs": [],
"source": [
"# Prework (with thanks to Jon R for identifying and fixing a bug in this!)\n",
"\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"metadatas = result['metadatas']\n",
"\n",
"# Alternatively, color by sender:\n",
"senders = [metadata.get('sender', 'unknown') for metadata in metadatas]\n",
"unique_senders = list(set(senders))\n",
"sender_colors = ['blue', 'green', 'red', 'orange', 'purple', 'brown', 'pink', 'gray']\n",
"colors = [sender_colors[unique_senders.index(sender) % len(sender_colors)] for sender in senders]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "427149d5-e5d8-4abd-bb6f-7ef0333cca21",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visalize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
" x=reduced_vectors[:, 0],\n",
" y=reduced_vectors[:, 1],\n",
" mode='markers',\n",
" marker=dict(size=5, color=colors, opacity=0.8),\n",
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
" hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
" title='2D Chroma Vector Store Visualization',\n",
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
" width=800,\n",
" height=600,\n",
" margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1418e88-acd5-460a-bf2b-4e6efc88e3dd",
"metadata": {},
"outputs": [],
"source": [
"# Let's try 3D!\n",
"\n",
"tsne = TSNE(n_components=3, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors)\n",
"\n",
"# Create the 3D scatter plot\n",
"fig = go.Figure(data=[go.Scatter3d(\n",
" x=reduced_vectors[:, 0],\n",
" y=reduced_vectors[:, 1],\n",
" z=reduced_vectors[:, 2],\n",
" mode='markers',\n",
" marker=dict(size=5, color=colors, opacity=0.8),\n",
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(senders, documents)],\n",
" hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
" title='3D Chroma Vector Store Visualization',\n",
" scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
" width=900,\n",
" height=700,\n",
" margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "bbbcb659-13ce-47ab-8a5e-01b930494964",
"metadata": {},
"source": [
"## Langchain and Gradio to prototype a chat with the LLM\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d72567e8-f891-4797-944b-4612dc6613b1",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.chains import create_retrieval_chain\n",
"\n",
"# create a new Chat with OpenAI\n",
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
"\n",
"# Alternative - if you'd like to use Ollama locally, uncomment this line instead\n",
"# llm = ChatOpenAI(temperature=0.7, model_name='llama3.2', base_url='http://localhost:11434/v1', api_key='ollama')\n",
"\n",
"# change LLM standard prompt (standard prompt defaults the answer to be 'I don't know' too often, especially when using a small LLM\n",
"\n",
"qa_prompt=PromptTemplate.from_template(\"Use the following pieces of context to answer the user's question. Answer as best you can given the information you have;\\\n",
" if you have a reasonable idea of the answer,/then explain it and mention that you're unsure. \\\n",
" But if you don't know the answer, don't make it up. \\\n",
" {context} \\\n",
" Question: {question} \\\n",
" Helpful Answer:\"\n",
" )\n",
"\n",
"\n",
"# Wrap into a StuffDocumentsChain, matching the variable name 'context'\n",
"combine_docs_chain = create_stuff_documents_chain(\n",
" llm=llm,\n",
" prompt=qa_prompt,\n",
" document_variable_name=\"context\"\n",
")\n",
"\n",
"# set up the conversation memory for the chat\n",
"#memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"memory = ConversationBufferMemory(\n",
" memory_key='chat_history', \n",
" return_messages=True,\n",
" output_key='answer' \n",
")\n",
"\n",
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
"retriever = vectorstore.as_retriever(search_kwargs={\"k\": 10})\n",
"\n",
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
"# conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)\n",
"\n",
"conversation_chain = ConversationalRetrievalChain.from_llm(\n",
" llm=llm,\n",
" retriever=retriever,\n",
" memory=memory,\n",
" combine_docs_chain_kwargs={\"prompt\": qa_prompt},\n",
" return_source_documents=True\n",
")\n",
"\n",
"def chat(question, history):\n",
" result = conversation_chain.invoke({\"question\": question})\n",
" return result[\"answer\"]\n",
"\n",
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe4229aa-6afe-4592-93a4-71a47ab69846",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}