Merge branch 'ed-donner:main' into llm-engineering-contributions-omar
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"from IPython.display import Markdown, display"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# Load environment variables in a file called .env\n",
"\n",
"load_dotenv()\n",
"api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"# Check the key\n",
"\n",
"if not api_key:\n",
"    print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not api_key.startswith(\"sk-proj-\"):\n",
"    print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n",
"elif api_key.strip() != api_key:\n",
"    print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
"    print(\"API key found and looks good so far!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"openai = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"# Let's just make sure the model is loaded\n",
"!ollama pull llama3.2\n",
"import ollama\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# System prompt - defines the AI's behavior\n",
"SYSTEM_PROMPT = \"\"\"You are a helpful cooking assistant that provides ingredient lists for recipes.\n",
"Format your response as clean markdown with this structure:\n",
"\n",
"# [Dish Name]\n",
"**Serves:** [number] people \n",
"**Cook Time:** [estimated time]\n",
"\n",
"## Shopping List\n",
"- [ ] [amount] [unit] [ingredient]\n",
"- [ ] [amount] [unit] [ingredient]\n",
"\n",
"Guidelines:\n",
"- Use common grocery store measurements (cups, lbs, oz, pieces, cans, etc.)\n",
"- Round to practical shopping amounts (1.5 lbs instead of 1.47 lbs)\n",
"- Group similar items when logical (all spices together)\n",
"- Include pantry staples only if they're essential (salt, oil, etc.)\n",
"- Assume basic seasonings are available unless recipe-specific\n",
"- For produce, specify size when important (large onion, medium tomatoes)\n",
"- Keep optional items at the end of similar item groups or end of the list\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"def get_recipe_openai(dish_name: str, num_people: int):\n",
"    \"\"\"Get scaled recipe ingredients using system and user prompts\"\"\"\n",
"\n",
"    user_prompt = f\"Give me the ingredients needed to make {dish_name} for {num_people} people.\"\n",
"\n",
"    try:\n",
"        response = openai.chat.completions.create(\n",
"            model=\"gpt-4o-mini\",\n",
"            messages=[\n",
"                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"                {\"role\": \"user\", \"content\": user_prompt}\n",
"            ],\n",
"            max_tokens=400\n",
"        )\n",
"\n",
"        return response.choices[0].message.content\n",
"\n",
"    except Exception as e:\n",
"        return f\"❌ Error: Failed to get recipe - {str(e)}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"OLLAMA_MODEL = \"llama3.2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"def get_recipe_ollama(dish_name: str, num_people: int):\n",
"    \"\"\"Get recipe using Ollama API\"\"\"\n",
"    user_prompt = f\"Give me the ingredients needed to make {dish_name} for {num_people} people.\"\n",
"\n",
"    messages = [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": user_prompt}\n",
"    ]\n",
"\n",
"    try:\n",
"        response = ollama.chat(model=OLLAMA_MODEL, messages=messages)\n",
"        return response['message']['content']\n",
"    except Exception as e:\n",
"        return f\"❌ Ollama Error: {str(e)}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"def print_shopping_list(recipe_markdown):\n",
"    \"\"\"Print the markdown response\"\"\"\n",
"    display(Markdown(recipe_markdown))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"print(\"🍳 Recipe Scaler & Grocery List Maker\")\n",
"print(\"=\" * 40)\n",
"\n",
"ai_service_choice = input(\"\\nChoose AI service (1 for OpenAI, 2 for Ollama): \").strip()\n",
"\n",
"dish = input(\"What dish do you want to make? \")\n",
"num_people = int(input(\"How many people? \"))\n",
"\n",
"print(f\"\\n🔍 Getting recipe for {dish}...\")\n",
"\n",
"# Get and display recipe\n",
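"# Any choice other than '1' falls back to the local Ollama model\n",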
"if ai_service_choice == '1':\n",
"    print(\"Using OpenAI API...\")\n",
"    recipe_markdown = get_recipe_openai(dish, num_people)\n",
"else:\n",
"    print(\"Using Ollama (local)...\")\n",
"    recipe_markdown = get_recipe_ollama(dish, num_people)\n",
"\n",
"print_shopping_list(recipe_markdown)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5",
"id": "0",
"metadata": {},
"source": [
"# End of week 1 exercise\n",
@@ -13,22 +13,30 @@
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c1070317-3ed9-4659-abe3-828943230e03",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"from IPython.display import Markdown, display, update_display\n",
"from dotenv import load_dotenv\n",
"import os\n",
"import openai\n",
"from openai import OpenAI\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
"metadata": {},
"execution_count": null,
"id": "2",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# constants\n",
@@ -37,6 +45,9 @@
"    'MODEL_LLAMA': 'llama3.2'\n",
"}\n",
"\n",
"load_dotenv(override=True)\n",
"api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"# To use ollama using openai API (ensure that ollama is running on localhost)\n",
"ollama_via_openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n",
"\n",
@@ -57,9 +68,15 @@
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
"metadata": {},
"execution_count": null,
"id": "3",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# set up environment\n",
@@ -89,8 +106,8 @@
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3f0d0137-52b0-47a8-81a8-11a90a010798",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
@@ -105,67 +122,9 @@
{
"cell_type": "code",
"execution_count": null,
"id": "60ce7000-a4a5-4cce-a261-e75ef45063b4",
"id": "5",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"**Understanding the Code Snippet**\n",
"\n",
"This Python code snippet uses a combination of built-in functions, dictionary iteration, and generator expressions to extract and yield author names from a list of `Book` objects.\n",
"\n",
"Here's a breakdown:\n",
"\n",
"1. **Dictionary Iteration**: The expression `for book in books if book.get(\"author\")`\n",
"   - Iterates over each element (`book`) in the container `books`.\n",
"   - Filters out elements whose `'author'` key does not have a value (i.e., `None`, `False`, or an empty string). This leaves only dictionaries with author information.\n",
"\n",
"2. **Dictionary Access**: The expression `{book.get(\"author\") for book in books if book.get(\"author\")}`\n",
"   - Uses dictionary membership testing to access only the values associated with the `'author'` key.\n",
"   - If the value is not found or is considered false, it's skipped in this particular case.\n",
"\n",
"3. **Generator Expression**: This generates an iterator that iterates over the filtered author names.\n",
"   - Yields each author name (i.e., a single `'name'` from the book dictionary) on demand.\n",
"   - Since these are generator expressions, they use memory less than equivalent Python lists and also create results on-demand.\n",
"\n",
"4. **`yield from`**: This statement takes the generator expression as an argument and uses it to generate a nested iterator structure.\n",
"   - It essentially \"decompresses\" the single level of nested iterator created by `list(iter(x))`, allowing for simpler use cases and potentially significant efficiency improvements for more complex structures where every value must be iterated, while in the latter case just the first item per iterable in the outer expression's sequence needs to actually be yielded into result stream.\n",
"   - By \"yielding\" a nested iterator (the generator expression), we can simplify code by avoiding repetitive structure like `for book, book_author in zip(iterating over), ...` or list creation.\n",
"\n",
"**Example Use Case**\n",
"\n",
"In this hypothetical example:\n",
"\n",
"# Example Book objects\n",
"class Book:\n",
"    def __init__(self, author, title):\n",
"        self.author = author  # str\n",
"        self.title = title\n",
"\n",
"books = [\n",
"    {\"author\": \"John Doe\", \"title\": f\"Book 1 by John Doe\"},\n",
"    {\"author\": None, \"title\": f\"Book 2 without Author\"},\n",
"    {\"author\": \"Jane Smith\", \"title\": f\"Book 3 by Jane Smith\"}\n",
"]\n",
"\n",
"# The given expression to extract and yield author names\n",
"for author in yield from {book.get(\"author\") for book in books if book.get(\"author\")}:\n",
"\n",
"    print(author) \n",
"\n",
"In this code snippet, printing the extracted authors would output `John Doe`, `Jane Smith` (since only dictionaries with author information pass the filtering test).\n",
"\n",
"Please modify it like as you wish and use `yield from` along with dictionary iteration, list comprehension or generator expression if needed, and explain what purpose your version has."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"# Get the model of your choice (choices appeared below) to answer, with streaming \n",
"\n",
@@ -174,13 +133,21 @@
"    'MODEL_LLAMA': 'llama3.2'\n",
"}\"\"\"\n",
"\n",
"stream_brochure(question,'MODEL_LLAMA')"
"stream_brochure(question,'MODEL_GPT')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "llms",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -194,7 +161,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
"version": "3.11.13"
}
},
"nbformat": 4,
64 week1/community-contributions/week1-jedi-master.py Normal file
@@ -0,0 +1,64 @@
#!/usr/bin/python3

import os
import argparse
from dotenv import load_dotenv
from openai import OpenAI


def load_openai_key():
    # Load environment variables in a file called .env
    load_dotenv(override=True)
    api_key = os.getenv('OPENAI_API_KEY')

    # Check the key
    if not api_key:
        return "Error: No API key was found!"
    elif not api_key.startswith("sk-proj-"):
        return "Error: An API key was found, but it doesn't start sk-proj-; please check you're using the right key"
    elif api_key.strip() != api_key:
        return "Error: An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them!"
    else:
        return "API key found and looks good so far!"


def ask_llm(client, model, user_prompt):
    system_prompt = """
    You are a wise Jedi Master and an excellent teacher.
    You will answer any question you are given by breaking it down into small steps
    that even a complete beginner will understand.
    When answering, speak as if you are Yoda from the Star Wars universe.
    Also, refer to the user as "My young Padawan"
    End every answer with "May the force be with you, always."
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "system", "content": system_prompt},
                  {"role": "user", "content": user_prompt}]
    )
    return response.choices[0].message.content


def main():
    parser = argparse.ArgumentParser(description="JedAI Master instructor")
    parser.add_argument("provider", choices=["openai", "ollama"], help="AI provider to use")
    parser.add_argument("--model", help="Model to use for Ollama (required if provider is 'ollama')")
    parser.add_argument("question", help="What knowledge do you seek, my young Padawan?")

    args = parser.parse_args()
    if args.provider == "ollama" and not args.model:
        parser.error("--model is required when provider is 'ollama'")

    if args.provider == "openai":
        print(load_openai_key())
        client = OpenAI()
        model = "gpt-4o-mini"
    else:
        client = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')
        model = args.model

    user_prompt = args.question

    result = ask_llm(client, model, user_prompt)
    print("AI Response:", result)


if __name__ == "__main__":
    main()
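
# Example invocations (illustrative):
#   python week1-jedi-master.py openai "How do Python loops work?"
#   python week1-jedi-master.py ollama --model llama3.2 "How do Python loops work?"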
450 week3/community-contributions/06_meeting_minute_assistant.ipynb Normal file
File diff suppressed because one or more lines are too long
381 week3/community-contributions/Week3-Dataset_Generator-DP.ipynb Normal file
@@ -0,0 +1,381 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c08309b8-13f0-45bb-a3ea-7b01f05a7346",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import pandas as pd\n",
"import random\n",
"import re\n",
"import subprocess\n",
"import pyarrow as pa\n",
"from typing import List\n",
"import openai\n",
"import anthropic\n",
"from dotenv import load_dotenv\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5efd903-e683-4e7f-8747-2998e23a0751",
"metadata": {},
"outputs": [],
"source": [
"# load API\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce49b86a-53f4-4d4f-a721-0d66d9c1b070",
"metadata": {},
"outputs": [],
"source": [
"# --- Schema Definition ---\n",
"SCHEMA = [\n",
"    (\"Team\", \"TEXT\", '\"Toronto Raptors\"'),\n",
"    (\"NAME\", \"TEXT\", '\"Otto Porter Jr.\"'),\n",
"    (\"Jersey\", \"TEXT\", '\"10\", or \"NA\" if null'),\n",
"    (\"POS\", \"TEXT\", 'One of [\"PF\",\"SF\",\"G\",\"C\",\"SG\",\"F\",\"PG\"]'),\n",
"    (\"AGE\", \"INT\", 'integer age in years, e.g., 22'),\n",
"    (\"HT\", \"TEXT\", '`6\\' 7\"` or `6\\' 10\"`'),\n",
"    (\"WT\", \"TEXT\", '\"232 lbs\"'),\n",
"    (\"COLLEGE\", \"TEXT\", '\"Michigan\", or \"--\" if null'),\n",
"    (\"SALARY\", \"TEXT\", '\"$9,945,830\", or \"--\" if null')\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93743e57-c2c5-43e5-8fa1-2e242085db07",
"metadata": {},
"outputs": [],
"source": [
"# Default schema text for the textbox\n",
"DEFAULT_SCHEMA_TEXT = \"\\n\".join([f\"{i+1}. {col[0]} ({col[1]}) Example: {col[2]}\" for i, col in enumerate(SCHEMA)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87c58595-6fdd-48f5-a253-ccba352cb385",
"metadata": {},
"outputs": [],
"source": [
"# Available models\n",
"MODELS = [\n",
"    \"gpt-4o\",\n",
"    \"claude-3-5-haiku-20241022\", \n",
"    \"ollama:llama3.2:latest\"\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08cd9ce2-8685-46b5-95d0-811b8025696f",
"metadata": {},
"outputs": [],
"source": [
"# Available file formats\n",
"FILE_FORMATS = [\".csv\", \".tsv\", \".jsonl\", \".parquet\", \".arrow\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13d68c7f-6f49-4efa-b075-f1e7db2ab527",
"metadata": {},
"outputs": [],
"source": [
"def get_prompt(n: int, schema_text: str, system_prompt: str) -> str:\n",
"    prompt = f\"\"\"\n",
"{system_prompt}\n",
"\n",
"Generate {n} rows of realistic basketball player data in JSONL format, each line a JSON object with the following fields:\n",
"\n",
"{schema_text}\n",
"\n",
"Do NOT repeat column values from one row to another.\n",
"\n",
"Only output valid JSONL.\n",
"\"\"\"\n",
"    return prompt.strip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdc68f1e-4fbe-45dc-aa36-ce5f718ef6ca",
"metadata": {},
"outputs": [],
"source": [
"# --- LLM Interface ---\n",
"def query_model(prompt: str, model: str = \"gpt-4o\") -> List[dict]:\n",
"    \"\"\"Call OpenAI, Claude, or Ollama\"\"\"\n",
"    try:\n",
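"        # Route by model-name prefix: gpt -> OpenAI, claude -> Anthropic, ollama: -> local Ollama CLI\n",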
" if model.lower().startswith(\"gpt\"):\n",
|
||||
" client = openai.OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))\n",
|
||||
" response = client.chat.completions.create(\n",
|
||||
" model=model,\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" temperature=0.7\n",
|
||||
" )\n",
|
||||
" content = response.choices[0].message.content\n",
|
||||
"\n",
|
||||
" elif model.lower().startswith(\"claude\"):\n",
|
||||
" client = anthropic.Anthropic(api_key=os.getenv(\"ANTHROPIC_API_KEY\"))\n",
|
||||
" response = client.messages.create(\n",
|
||||
" model=model,\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" max_tokens=4000,\n",
|
||||
" temperature=0.7\n",
|
||||
" )\n",
|
||||
" content = response.content[0].text\n",
|
||||
"\n",
|
||||
" elif model.lower().startswith(\"ollama:\"):\n",
|
||||
" ollama_model = model.split(\":\")[1]\n",
|
||||
" result = subprocess.run(\n",
|
||||
" [\"ollama\", \"run\", ollama_model],\n",
|
||||
" input=prompt,\n",
|
||||
" text=True,\n",
|
||||
" capture_output=True\n",
|
||||
" )\n",
|
||||
" if result.returncode != 0:\n",
|
||||
" raise Exception(f\"Ollama error: {result.stderr}\")\n",
|
||||
" content = result.stdout\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Unsupported model. Use 'gpt-4.1-mini', 'claude-3-5-haiku-20241022', or 'ollama:llama3.2:latest'\")\n",
|
||||
"\n",
|
||||
" # Parse JSONL output\n",
|
||||
" lines = [line.strip() for line in content.strip().splitlines() if line.strip().startswith(\"{\")]\n",
|
||||
" return [json.loads(line) for line in lines]\n",
|
||||
" \n",
|
||||
" except Exception as e:\n",
|
||||
" raise Exception(f\"Model query failed: {str(e)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29e3f5f5-e99c-429c-bea9-69d554c58c9c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# --- Output Formatter ---\n",
|
||||
"def save_dataset(records: List[dict], file_format: str, filename: str):\n",
|
||||
" df = pd.DataFrame(records)\n",
|
||||
" if file_format == \".csv\":\n",
|
||||
" df.to_csv(filename, index=False)\n",
|
||||
" elif file_format == \".tsv\":\n",
|
||||
" df.to_csv(filename, sep=\"\\t\", index=False)\n",
|
||||
" elif file_format == \".jsonl\":\n",
|
||||
" with open(filename, \"w\") as f:\n",
|
||||
" for record in records:\n",
|
||||
" f.write(json.dumps(record) + \"\\n\")\n",
|
||||
" elif file_format == \".parquet\":\n",
|
||||
" df.to_parquet(filename, engine=\"pyarrow\", index=False)\n",
|
||||
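"    # Arrow IPC file format; it can be read back with pa.ipc.open_file\n",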
" elif file_format == \".arrow\":\n",
|
||||
" table = pa.Table.from_pandas(df)\n",
|
||||
" with pa.OSFile(filename, \"wb\") as sink:\n",
|
||||
" with pa.ipc.new_file(sink, table.schema) as writer:\n",
|
||||
" writer.write(table)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Unsupported file format\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fe258e84-66f4-4fe7-99c0-75b24148e147",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# --- Main Generation Function ---\n",
|
||||
"def generate_dataset(schema_text, system_prompt, model, nr_records, file_format, save_as):\n",
|
||||
" try:\n",
|
||||
" # Validation\n",
|
||||
" if nr_records <= 10:\n",
|
||||
" return \"❌ Error: Nr_records must be greater than 10.\", None\n",
|
||||
" \n",
|
||||
" if file_format not in FILE_FORMATS:\n",
|
||||
" return \"❌ Error: Invalid file format specified.\", None\n",
|
||||
" \n",
|
||||
" if not save_as or save_as.strip() == \"\":\n",
|
||||
" save_as = f\"basketball_dataset{file_format}\"\n",
|
||||
" elif not save_as.endswith(file_format):\n",
|
||||
" save_as = save_as + file_format\n",
|
||||
" \n",
|
||||
" # Generate prompt\n",
|
||||
" prompt = get_prompt(nr_records, schema_text, system_prompt)\n",
|
||||
" \n",
|
||||
" # Query model\n",
|
||||
" records = query_model(prompt, model=model)\n",
|
||||
" \n",
|
||||
" if not records:\n",
|
||||
" return \"❌ Error: No valid records generated from the model.\", None\n",
|
||||
" \n",
|
||||
" # Save dataset\n",
|
||||
" save_dataset(records, file_format, save_as)\n",
|
||||
" \n",
|
||||
" # Create preview\n",
|
||||
" df = pd.DataFrame(records)\n",
|
||||
" preview = df.head(10) # Show first 10 rows\n",
|
||||
" \n",
|
||||
" success_message = f\"✅ Dataset generated successfully!\\n📁 Saved to: {save_as}\\n📊 Generated {len(records)} records\"\n",
|
||||
" \n",
|
||||
" return success_message, preview\n",
|
||||
" \n",
|
||||
" except Exception as e:\n",
|
||||
" return f\"❌ Error: {str(e)}\", None"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c2405a9d-b4cd-43d9-82f6-ff3512b4541f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# --- Gradio Interface ---\n",
|
||||
"def create_interface():\n",
|
||||
" with gr.Blocks(title=\"Dataset Generator\", theme=gr.themes.Soft()) as interface:\n",
|
||||
" gr.Markdown(\"# Dataset Generator\")\n",
|
||||
" gr.Markdown(\"Generate realistic datasets using AI models\")\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column(scale=2):\n",
|
||||
" schema_input = gr.Textbox(\n",
|
||||
" label=\"Schema\",\n",
|
||||
" value=DEFAULT_SCHEMA_TEXT,\n",
|
||||
" lines=15,\n",
|
||||
" placeholder=\"Define your dataset schema here...\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" system_prompt_input = gr.Textbox(\n",
|
||||
" label=\"Prompt\",\n",
|
||||
" value=\"You are a helpful assistant that generates realistic basketball player data.\",\n",
|
||||
" lines=1,\n",
|
||||
" placeholder=\"Enter system prompt for the model...\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" model_dropdown = gr.Dropdown(\n",
|
||||
" label=\"Model\",\n",
|
||||
" choices=MODELS,\n",
|
||||
" value=MODELS[1], # Default to Claude\n",
|
||||
" interactive=True\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" nr_records_input = gr.Number(\n",
|
||||
" label=\"Nr. records\",\n",
|
||||
" value=25,\n",
|
||||
" minimum=11,\n",
|
||||
" maximum=1000,\n",
|
||||
" step=1\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" file_format_dropdown = gr.Dropdown(\n",
|
||||
" label=\"File format\",\n",
|
||||
" choices=FILE_FORMATS,\n",
|
||||
" value=\".csv\",\n",
|
||||
" interactive=True\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" save_as_input = gr.Textbox(\n",
|
||||
" label=\"Save as\",\n",
|
||||
" value=\"basketball_dataset\",\n",
|
||||
" placeholder=\"Enter filename (extension will be added automatically)\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" generate_btn = gr.Button(\"🚀 Generate\", variant=\"primary\", size=\"lg\")\n",
|
||||
" \n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" output_status = gr.Textbox(\n",
|
||||
" label=\"Status\",\n",
|
||||
" lines=4,\n",
|
||||
" interactive=False\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" output_preview = gr.Dataframe(\n",
|
||||
" label=\"Preview (First 10 rows)\",\n",
|
||||
" interactive=False,\n",
|
||||
" wrap=True\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Connect the generate button\n",
|
||||
" generate_btn.click(\n",
|
||||
" fn=generate_dataset,\n",
|
||||
" inputs=[\n",
|
||||
" schema_input,\n",
|
||||
" system_prompt_input, \n",
|
||||
" model_dropdown,\n",
|
||||
" nr_records_input,\n",
|
||||
" file_format_dropdown,\n",
|
||||
" save_as_input\n",
|
||||
" ],\n",
|
||||
" outputs=[output_status, output_preview]\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" gr.Markdown(\"\"\"\n",
|
||||
" ### 📝 Instructions:\n",
|
||||
" 1. **Schema**: Define the structure of your dataset (pre-filled with basketball player schema)\n",
|
||||
" 2. **Prompt**: System prompt to guide the AI model\n",
|
||||
" 3. **Model**: Choose between GPT, Claude, or Ollama models\n",
|
||||
" 4. **Nr. records**: Number of records to generate (minimum 11)\n",
|
||||
" 5. **File format**: Choose output format (.csv, .tsv, .jsonl, .parquet, .arrow)\n",
|
||||
" 6. **Save as**: Filename (extension added automatically)\n",
|
||||
" 7. Click **Generate** to create your dataset\n",
|
||||
" \n",
|
||||
" ### 🔧 Requirements:\n",
|
||||
" - Set up your API keys in `.env` file (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`)\n",
|
||||
" - For Ollama models, ensure Ollama is installed and running locally\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" return interface"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "50fd2b91-2578-4224-b9dd-e28caf6a0a85",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"interface = create_interface()\n",
|
||||
"interface.launch(inbrowser=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
102 week3/community-contributions/muawiya/README.md Normal file
@@ -0,0 +1,102 @@
# 🧠 Synthetic Data Generator

A Python-based tool to generate structured, synthetic job postings using open-source LLMs from Hugging Face.
This project supports both **script-based execution** and an **interactive Colab notebook**, making it ideal for rapid prototyping, dataset bootstrapping, or demonstrating prompt engineering techniques.

> Note: The original repo can be found at: https://github.com/moawiah/synthetic_data_generator



This tool helps:
- Researchers create labeled training data for NLP classification or QA
- HR tech startups prototype recommendation models
- AI instructors demonstrate few-shot prompting in class

---

## ✨ Features

- 🔗 Integrates Hugging Face Transformer models
- 📄 Generates realistic job postings in structured JSON format
- 🧪 Supports prompt engineering with control over output length and variability
- 🧠 Minimal Gradio UI for non-technical users
- 📓 Jupyter/Colab support for experimentation and reproducibility

## 📂 Project Structure

```
.
├── app/
│   ├── app.py            # Main script entry point
│   ├── consts.py         # Configuration and constants
│   └── requirements.txt  # Python dependencies
├── data/
│   └── software_engineer_jobs.json    # Sample input data (JSON format)
├── notebooks/
│   └── synthetic_data_generator.ipynb # Interactive Colab notebook
├── .env.example          # Sample environment variable config
├── .gitignore            # Git ignored files list
└── README.md
```

## 🚀 Getting Started

### 1. Clone the repository
```bash
git clone https://github.com/moawiah/synthetic_data_generator.git
cd synthetic_data_generator
```

### 2. Install dependencies
```bash
pip install -r app/requirements.txt
```

### 3. Hugging Face token
You need to create a `.env` file with your Hugging Face token, like `HF_TOKEN=your-token-here`.

### 4. Run
Run the app using:

`python app/app.py`
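
This launches a Gradio UI where you pick a job role and a number of samples; the generated postings are shown as JSON and also written to `data/<role>_jobs.json` for download.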

## Example Output - 1 Job

```JSON
{
  "title": "Software Engineer",
  "description": "We are seeking a highly skilled software engineer to join our team and contribute to the development of innovative software solutions. The ideal candidate will have experience in designing, coding, and testing software systems, and will be able to work collaboratively with cross-functional teams. Responsibilities include writing clean, maintainable, and efficient code, as well as actively participating in code reviews and continuous integration processes. This is an excellent opportunity for a self-starter with a passion for technology and a desire to grow in their career.",
  "requirements": [
    "Bachelor's degree in Computer Science or related field",
    "Minimum of 2 years experience in software development",
    "Strong proficiency in Java or C++",
    "Experience with agile development methodologies",
    "Good understanding of data structures and algorithms",
    "Excellent problem-solving and analytical skills"
  ],
  "location": "New York, NY",
  "company_name": "ABC Technologies"
}
```

## Future Improvements

- 🔁 Add support for more job roles and industries
- 🧠 Model selector from UI
- 💾 Export dataset as CSV
- ☁️ Optional integration with LangChain or RAG workflows
156 week3/community-contributions/muawiya/app/app.py Normal file
@@ -0,0 +1,156 @@
import os
import requests
from openai import OpenAI
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig, pipeline, TextGenerationPipeline
import torch
from consts import FALCON, MISTRAL, Databricks
from dotenv import load_dotenv
import json
import gradio as gr
import re


# Sign in to HuggingFace Hub
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(hf_token)


# Main Prompt
prompt = """
Generate one fake job posting for a {role}.

Return only a single JSON object with:
- title
- description (5-10 sentences)
- requirements (array of 4-6 strings)
- location
- company_name

No explanations, no extra text.
Only the JSON object.
"""

# Main Conf
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4"
)

def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(MISTRAL, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(
        MISTRAL,
        device_map={"": "cuda"},
        trust_remote_code=True,
        offload_folder="/tmp/dolly_offload",
        quantization_config=bnb_config
    )

    return model, tokenizer


def generate_job(role="Software Engineer", model=None, tokenizer=None):
    # prompt = prompt.format(role=role, n=n)
    # outputs = generator(prompt, max_new_tokens=500, do_sample=True, temperature=0.9)
    # return outputs[0]['generated_text']

    # Apply chat template formatting
    # inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
    inputs = tokenizer(prompt.format(role=role), return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate output
    outputs = model.generate(
        **inputs,
        max_new_tokens=600,
        do_sample=True,
        temperature=0.2,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode and return
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result

def generate_jobs(role="Software Engineer", n=5):
    model, tokenizer = load_model_and_tokenizer()
    fake_jobs = []
    for i in range(n):
        fake_jobs.append(generate_job(role=role, model=model, tokenizer=tokenizer))
    return fake_jobs

def extract_json_objects_from_text_block(texts):
    """
    Accepts either a single string or a list of strings.
    Extracts all valid JSON objects from messy text blocks.
    """
    if isinstance(texts, str):
        texts = [texts]  # wrap in list if single string

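    # Non-greedy match over brace-delimited spans; adequate here since the job objects contain no nested {...}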
    pattern = r"\{[\s\S]*?\}"
    results = []

    for raw_text in texts:
        matches = re.findall(pattern, raw_text)
        for match in matches:
            try:
                obj = json.loads(match)
                results.append(obj)
            except json.JSONDecodeError:
                continue

    return results

def generate_ui(role, n):
    try:
        raw_jobs = generate_jobs(role, n)
        parsed_jobs = extract_json_objects_from_text_block(raw_jobs)

        if not isinstance(parsed_jobs, list) or not all(isinstance(item, dict) for item in parsed_jobs):
            print("[ERROR] Parsed result is not a list of dicts")
            return gr.update(value=[], visible=True), None

        os.makedirs("data", exist_ok=True)
        filename = f"data/{role.replace(' ', '_').lower()}_jobs.json"
        with open(filename, "w") as f:
            json.dump(parsed_jobs, f, indent=2)

        print(f"[INFO] Returning {len(parsed_jobs)} jobs -> {filename}")
        return parsed_jobs, filename

    except Exception as e:
        print(f"[FATAL ERROR] {e}")
        return gr.update(value=[], visible=True), None


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("# 🧠 Synthetic Job Dataset Generator")
        gr.Markdown("Generate a structured dataset of job postings for a specific role.")

        with gr.Row():
            role_input = gr.Textbox(label="Job Role", placeholder="e.g. Software Engineer", value="Software Engineer")
            n_input = gr.Number(label="Number of Samples", value=5, precision=0)

        generate_button = gr.Button("🚀 Generate")
        output_table = gr.JSON(label="Generated Dataset")
        download_button = gr.File(label="Download JSON")

        generate_button.click(
            generate_ui,
            inputs=[role_input, n_input],
            outputs=[output_table, download_button]
        )

    demo.launch(debug=True, share=True)
5 week3/community-contributions/muawiya/app/consts.py Normal file
@@ -0,0 +1,5 @@
# Models
GPT = 'gpt2'
FALCON = "tiiuae/falcon-rw-1b"
MISTRAL = "mistralai/Mistral-7B-Instruct-v0.1"
Databricks = "databricks/dolly-v2-3b"
7 week3/community-contributions/muawiya/app/requirements.txt Normal file
@@ -0,0 +1,7 @@
huggingface_hub==0.30.2
ipython==8.12.3
openai==1.76.2
protobuf==6.30.2
Requests==2.32.3
torch==2.6.0+cu124
transformers==4.51.3
71 week3/community-contributions/muawiya/data/software_engineer_jobs.json Normal file
@@ -0,0 +1,71 @@
[
  {
    "title": "Software Engineer",
    "description": "We are seeking a highly skilled software engineer to join our team in developing and maintaining complex software systems. The ideal candidate will have a strong background in computer science and experience with multiple programming languages. Responsibilities include writing clean and efficient code, collaborating with cross-functional teams, and actively participating in code reviews. This is an excellent opportunity for a self-starter with a passion for technology and a desire to grow in their career.",
    "requirements": [
      "Bachelor's degree in Computer Science or related field",
      "3+ years of experience in software development",
      "Strong proficiency in Java or C++",
      "Experience with agile development methodologies",
      "Excellent problem-solving and analytical skills"
    ],
    "location": "New York, NY",
    "company_name": "ABC Technologies"
  },
  {
    "title": "Software Engineer",
    "description": "We are looking for a highly skilled software engineer to join our team and contribute to the development of innovative software solutions. The ideal candidate will have experience in designing, developing, and testing software systems, and be able to work independently or as part of a team. Responsibilities include writing clean and efficient code, collaborating with cross-functional teams, and actively participating in code reviews. Must have a strong understanding of computer science principles and be able to learn quickly. This is a full-time position located in San Francisco, CA.",
    "requirements": [
      "Bachelor's degree in Computer Science or related field",
      "3+ years of experience in software development",
      "Strong proficiency in Java or C++",
      "Experience with agile development methodologies",
      "Excellent problem-solving skills",
      "Ability to work in a fast-paced environment"
    ],
    "location": "San Francisco, CA",
    "company_name": "Acme Inc."
  },
  {
    "title": "Software Engineer",
    "description": "We are seeking a highly skilled software engineer to join our team in developing and maintaining our cutting-edge software applications. The ideal candidate will have a strong background in computer science and software engineering, with experience in designing, coding, and testing software systems. Responsibilities include collaborating with cross-functional teams, writing clean and efficient code, and ensuring the timely delivery of high-quality software products. This is an excellent opportunity for a self-starter with a passion for technology and a desire to work in a dynamic and fast-paced environment.",
    "requirements": [
      "Bachelor's degree in Computer Science or related field",
      "3+ years of experience in software engineering",
      "Strong proficiency in Java, Python, or C++",
      "Experience with agile development methodologies",
      "Excellent problem-solving and analytical skills",
      "Strong communication and interpersonal skills"
    ],
    "location": "New York, NY",
    "company_name": "ABC Tech"
  },
  {
    "title": "Software Engineer",
    "description": "We are seeking a highly skilled software engineer to join our team and contribute to the development of innovative software solutions. The ideal candidate will have a strong background in computer science and experience with various programming languages and technologies. Responsibilities include designing, coding, testing, and maintaining software systems, as well as collaborating with cross-functional teams. This is an excellent opportunity for a creative and motivated individual to make a significant impact in the tech industry.",
    "requirements": [
      "Bachelor's degree in Computer Science or related field",
      "Minimum of 2 years experience in software development",
      "Strong proficiency in Java, Python, or C++",
      "Experience with agile development methodologies",
      "Excellent problem-solving and analytical skills",
      "Ability to work independently and as part of a team",
      "Strong communication and interpersonal skills"
    ],
    "location": "New York, NY",
    "company_name": "ABC Tech Inc."
  },
  {
    "title": "Software Engineer",
    "description": "We are looking for a skilled software engineer to join our team and contribute to the development of innovative software solutions. Responsibilities include designing, coding, testing and maintaining software systems, as well as collaborating with cross-functional teams. The ideal candidate will have a strong background in computer science or a related field, and at least 3 years of experience in software development. Must be proficient in multiple programming languages, including Java, Python, and C++. Strong problem-solving skills and the ability to work independently or as part of a team are required. This is a full-time position located in San Francisco, CA.",
    "requirements": [
      "Bachelor's degree in Computer Science or related field",
      "At least 3 years of experience in software development",
      "Proficiency in Java, Python, and C++",
      "Strong problem-solving skills",
      "Ability to work independently or as part of a team"
    ],
    "location": "San Francisco, CA",
    "company_name": "Innovative Solutions Inc."
  }
]
File diff suppressed because one or more lines are too long
400 week4/community-contributions/Week4-Comments-Generator-DP.ipynb Normal file
@@ -0,0 +1,400 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "3e473bbd-a0c2-43bd-bf99-c749784d00c3",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"import openai\n",
"import anthropic\n",
"import google.generativeai as genai\n",
"import requests\n",
"import json\n",
"import os\n",
"from typing import Dict, Any, Optional\n",
"import asyncio\n",
"from dotenv import load_dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16210512-41f1-4de3-8348-2cd7129e023f",
"metadata": {},
"outputs": [],
"source": [
"# load API\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6747e275-91eb-4d2b-90b6-805f2bd9b6b7",
"metadata": {},
"outputs": [],
"source": [
"class CodeCommenter:\n",
"    def __init__(self):\n",
"        # Initialize API clients\n",
"        self.openai_client = None\n",
"        self.anthropic_client = None\n",
"        self.gemini_client = None\n",
"\n",
"        # Load API keys from environment variables\n",
"        self.setup_clients()\n",
"\n",
"    def setup_clients(self):\n",
"        \"\"\"Initialize API clients with keys from environment variables\"\"\"\n",
"        try:\n",
"            # OpenAI\n",
"            openai_key = os.getenv('OPENAI_API_KEY')\n",
"            if openai_key:\n",
"                self.openai_client = openai.OpenAI(api_key=openai_key)\n",
"\n",
"            # Anthropic\n",
"            anthropic_key = os.getenv('ANTHROPIC_API_KEY')\n",
"            if anthropic_key:\n",
"                self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)\n",
"\n",
"            # Google Gemini\n",
"            gemini_key = os.getenv('GOOGLE_API_KEY')\n",
"            if gemini_key:\n",
"                genai.configure(api_key=gemini_key)\n",
"                self.gemini_client = genai.GenerativeModel('gemini-2.0-flash-exp')\n",
"\n",
"        except Exception as e:\n",
"            print(f\"Warning: Error setting up API clients: {e}\")\n",
"\n",
"    def create_prompt(self, code: str, language: str) -> str:\n",
"        \"\"\"Create a prompt for the LLM to add comments and docstrings\"\"\"\n",
"        return f\"\"\"Please add detailed and helpful comments and docstrings to the following {language} code. \n",
"\n",
"Guidelines:\n",
"1. Add comprehensive docstrings for functions, classes, and modules\n",
"2. Add inline comments explaining complex logic\n",
"3. Follow the commenting conventions for {language}\n",
"4. Maintain the original code structure and functionality\n",
"5. Make comments clear and professional\n",
"6. Don't change the actual code logic, only add comments\n",
"7. Do not add code markdown delimiters like ```python\n",
"\n",
"Here's the code to comment:\n",
"\n",
"{code}\n",
"\n",
"Please return only the commented code without any additional explanation or markdown formatting.\"\"\"\n",
"\n",
"    def call_openai(self, prompt: str, model: str = \"gpt-4o-mini\") -> str:\n",
"        \"\"\"Make API call to OpenAI\"\"\"\n",
"        if not self.openai_client:\n",
"            return \"Error: OpenAI API key not configured. Please set OPENAI_API_KEY environment variable.\"\n",
"\n",
"        try:\n",
"            response = self.openai_client.chat.completions.create(\n",
"                model=model,\n",
"                messages=[\n",
"                    {\"role\": \"system\", \"content\": \"You are a helpful coding assistant that adds detailed comments and docstrings to code.\"},\n",
"                    {\"role\": \"user\", \"content\": prompt}\n",
"                ],\n",
"                max_tokens=4000,\n",
"                temperature=0.1\n",
"            )\n",
"            return response.choices[0].message.content.strip()\n",
"        except Exception as e:\n",
"            return f\"Error calling OpenAI API: {str(e)}\"\n",
"\n",
"    def call_anthropic(self, prompt: str, model: str = \"claude-3-5-haiku-20241022\") -> str:\n",
"        \"\"\"Make API call to Anthropic Claude\"\"\"\n",
"        if not self.anthropic_client:\n",
"            return \"Error: Anthropic API key not configured. Please set ANTHROPIC_API_KEY environment variable.\"\n",
"\n",
"        try:\n",
"            response = self.anthropic_client.messages.create(\n",
"                model=model,\n",
"                max_tokens=4000,\n",
"                temperature=0.1,\n",
"                messages=[\n",
"                    {\"role\": \"user\", \"content\": prompt}\n",
"                ]\n",
"            )\n",
"            return response.content[0].text.strip()\n",
"        except Exception as e:\n",
"            return f\"Error calling Anthropic API: {str(e)}\"\n",
"\n",
"    def call_gemini(self, prompt: str) -> str:\n",
"        \"\"\"Make API call to Google Gemini\"\"\"\n",
"        if not self.gemini_client:\n",
"            return \"Error: Google API key not configured. Please set GOOGLE_API_KEY environment variable.\"\n",
"\n",
"        try:\n",
"            response = self.gemini_client.generate_content(\n",
"                prompt,\n",
"                generation_config=genai.types.GenerationConfig(\n",
"                    max_output_tokens=4000,\n",
"                    temperature=0.1,\n",
"                )\n",
"            )\n",
"            return response.text.strip()\n",
"        except Exception as e:\n",
"            return f\"Error calling Gemini API: {str(e)}\"\n",
"\n",
"    def call_ollama(self, prompt: str, model: str = \"llama3.2:latest\") -> str:\n",
"        \"\"\"Make API call to Ollama (local)\"\"\"\n",
"        try:\n",
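"            # Ollama's native /api/generate endpoint; stream=False returns a single JSON object\n",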
" url = \"http://localhost:11434/api/generate\"\n",
|
||||
" data = {\n",
|
||||
" \"model\": model,\n",
|
||||
" \"prompt\": prompt,\n",
|
||||
" \"stream\": False,\n",
|
||||
" \"options\": {\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"num_predict\": 4000\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" response = requests.post(url, json=data, timeout=60)\n",
|
||||
" if response.status_code == 200:\n",
|
||||
" result = response.json()\n",
|
||||
" return result.get('response', '').strip()\n",
|
||||
" else:\n",
|
||||
" return f\"Error calling Ollama API: HTTP {response.status_code}\"\n",
|
||||
" except requests.exceptions.ConnectionError:\n",
|
||||
" return \"Error: Could not connect to Ollama. Make sure Ollama is running locally on port 11434.\"\n",
|
||||
" except Exception as e:\n",
|
||||
" return f\"Error calling Ollama API: {str(e)}\"\n",
|
||||
"\n",
|
||||
" def generate_comments(self, language: str, code: str, llm: str) -> str:\n",
|
||||
" \"\"\"Generate comments for the given code using the specified LLM\"\"\"\n",
|
||||
" if not code.strip():\n",
|
||||
" return \"Error: Please provide code to comment.\"\n",
|
||||
" \n",
|
||||
" prompt = self.create_prompt(code, language)\n",
|
||||
" \n",
|
||||
" # Route to appropriate LLM\n",
|
||||
" if llm == \"gpt-4o-mini\":\n",
|
||||
" return self.call_openai(prompt, \"gpt-4o-mini\")\n",
|
||||
" elif llm == \"claude-3-5-haiku-20241022\":\n",
|
||||
" return self.call_anthropic(prompt, \"claude-3-5-haiku-20241022\")\n",
|
||||
" elif llm == \"gemini-2.0-flash\":\n",
|
||||
" return self.call_gemini(prompt)\n",
|
||||
" elif llm == \"ollama:llama3.2:latest\":\n",
|
||||
" return self.call_ollama(prompt, \"llama3.2:latest\")\n",
|
||||
" else:\n",
|
||||
" return f\"Error: Unsupported LLM: {llm}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "813f0911-d53f-4887-9341-656712e32d8f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_gradio_interface():\n",
|
||||
" \"\"\"Create and configure the Gradio interface\"\"\"\n",
|
||||
" commenter = CodeCommenter()\n",
|
||||
" \n",
|
||||
" # Define the main function for the interface\n",
|
||||
" def process_code(language, code, llm):\n",
|
||||
" \"\"\"Process the code and return commented version\"\"\"\n",
|
||||
" if not code.strip():\n",
|
||||
" return \"Please enter some code to comment.\"\n",
|
||||
" \n",
|
||||
" # Show processing message\n",
|
||||
" processing_msg = f\"Processing {language} code with {llm}...\"\n",
|
||||
" print(processing_msg)\n",
|
||||
" \n",
|
||||
" # Generate comments\n",
|
||||
" result = commenter.generate_comments(language, code, llm)\n",
|
||||
" return result\n",
|
||||
" \n",
|
||||
" # Define default code\n",
|
||||
" default_code = \"\"\"import pyodbc\n",
|
||||
"from tabulate import tabulate\n",
|
||||
"def connect_to_sql_server(server_name, database, username=None, password=None):\n",
|
||||
" try:\n",
|
||||
" if username and password:\n",
|
||||
" connection_string = f\"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={server_name};DATABASE={database};UID={username};PWD={password}\"\n",
|
||||
" else:\n",
|
||||
" connection_string = f\"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={server_name};DATABASE={database};Trusted_Connection=yes\"\n",
|
||||
" connection = pyodbc.connect(connection_string)\n",
|
||||
" print(f\"Successfully connected to {server_name}/{database}\")\n",
|
||||
" return connection\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failed to connect to {server_name}/{database}: {str(e)}\")\n",
|
||||
" return None\n",
|
||||
"def get_record_count(connection, table_name):\n",
|
||||
" try:\n",
|
||||
" cursor = connection.cursor()\n",
|
||||
" query = f\"SELECT COUNT(*) FROM {table_name}\"\n",
|
||||
" cursor.execute(query)\n",
|
||||
" count = cursor.fetchone()[0]\n",
|
||||
" cursor.close()\n",
|
||||
" print(f\"Record count for {table_name}: {count}\")\n",
|
||||
" return count\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failed to get record count for {table_name}: {str(e)}\")\n",
|
||||
" return None\n",
|
||||
"def select_top_records(connection, table_name, n):\n",
|
||||
" try:\n",
|
||||
" cursor = connection.cursor()\n",
|
||||
" query = f\"SELECT TOP {n} * FROM {table_name}\"\n",
|
||||
" cursor.execute(query)\n",
|
||||
" records = cursor.fetchall()\n",
|
||||
" columns = [column[0] for column in cursor.description]\n",
|
||||
" cursor.close()\n",
|
||||
" print(f\"Top {n} records from {table_name}\")\n",
|
||||
" if records:\n",
|
||||
" print(tabulate(records, headers=columns, tablefmt=\"grid\"))\n",
|
||||
" return records\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failed to retrieve top {n} records from {table_name}: {str(e)}\")\n",
|
||||
" return None\n",
|
||||
"conn = connect_to_sql_server(\"localhost\", \"AdventureWorks_lite\")\n",
|
||||
"if conn:\n",
|
||||
" total_records = get_record_count(conn, \"Sales.SalesOrderDetail\")\n",
|
||||
" top_records = select_top_records(conn, \"Production.Product\", 10)\n",
|
||||
" conn.close()\n",
|
||||
" print(\"Connection closed successfully\")\"\"\"\n",
|
||||
"\n",
|
||||
" css = \"\"\"\n",
"textarea[rows]:not([rows=\"1\"]) {\n",
|
||||
" overflow-y: auto !important;\n",
|
||||
" scrollbar-width: thin !important;\n",
|
||||
"}\n",
|
||||
"textarea[rows]:not([rows=\"1\"])::-webkit-scrollbar {\n",
|
||||
" all: initial !important;\n",
|
||||
" background: #f1f1f1 !important;\n",
|
||||
"}\n",
|
||||
"textarea[rows]:not([rows=\"1\"])::-webkit-scrollbar-thumb {\n",
|
||||
" all: initial !important;\n",
|
||||
" background: #a8a8a8 !important;\n",
|
||||
"}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
" # Create the interface\n",
|
||||
" with gr.Blocks(title=\"Code Commenter\", theme=gr.themes.Base(), css=css) as interface:\n",
|
||||
" gr.Markdown(\"# 🔧 Code Commenter\")\n",
|
||||
" gr.Markdown(\"Add detailed comments and docstrings to your code using various LLM models.\")\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column():\n",
|
||||
" code_input = gr.Textbox(\n",
|
||||
" label=\"Input Code\",\n",
|
||||
" value=default_code,\n",
|
||||
" lines=15,\n",
|
||||
" max_lines=20,\n",
|
||||
" info=\"Enter the code you want to add comments to\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Column():\n",
|
||||
" code_output = gr.Textbox(\n",
|
||||
" label=\"Commented Code\",\n",
|
||||
" lines=20,\n",
|
||||
" max_lines=20,\n",
|
||||
" info=\"Your code with added comments and docstrings\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" language_dropdown = gr.Dropdown(\n",
|
||||
" choices=[\"Python\", \"Ruby\", \"Rust\", \"C++\", \"Java\"],\n",
|
||||
" value=\"Python\",\n",
|
||||
" label=\"Programming Language\",\n",
|
||||
" info=\"Select the programming language of your code\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" llm_dropdown = gr.Dropdown(\n",
|
||||
" choices=[\n",
|
||||
" \"gpt-4o-mini\",\n",
|
||||
" \"claude-3-5-haiku-20241022\", \n",
|
||||
" \"gemini-2.0-flash\",\n",
|
||||
" \"ollama:llama3.2:latest\"\n",
|
||||
" ],\n",
|
||||
" value=\"gpt-4o-mini\",\n",
|
||||
" label=\"LLM Model\",\n",
|
||||
" info=\"Choose the language model to use\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" generate_btn = gr.Button(\n",
|
||||
" \"🚀 Generate Comments\", \n",
|
||||
" variant=\"primary\",\n",
|
||||
" size=\"lg\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Add some API setup information\n",
|
||||
" gr.Markdown(\"## 📝 API Setup Instructions\")\n",
|
||||
" gr.Markdown(\"\"\"\n",
|
||||
" To use this tool, you need to set up API keys as environment variables:\n",
|
||||
" \n",
|
||||
" - **OpenAI**: Set `OPENAI_API_KEY`\n",
|
||||
" - **Anthropic**: Set `ANTHROPIC_API_KEY` \n",
|
||||
" - **Google Gemini**: Set `GOOGLE_API_KEY`\n",
|
||||
" - **Ollama**: Make sure Ollama is running locally on port 11434\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" # Connect the button to the processing function\n",
|
||||
" generate_btn.click(\n",
|
||||
" fn=process_code,\n",
|
||||
" inputs=[language_dropdown, code_input, llm_dropdown],\n",
|
||||
" outputs=code_output,\n",
|
||||
" show_progress=True\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" return interface"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ef461e08-c1d5-406d-b7d2-a4329f16486e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"🚀 Starting Code Commenter...\")\n",
|
||||
"print(\"📋 Setting up Gradio interface...\")\n",
|
||||
"\n",
|
||||
"# Create and launch the interface\n",
|
||||
"interface = create_gradio_interface()\n",
|
||||
"\n",
|
||||
"print(\"🌐 Launching interface...\")\n",
|
||||
"print(\"💡 The interface will open in your default browser\")\n",
|
||||
"print(\"🔧 Make sure to set up your API keys as environment variables\")\n",
|
||||
"\n",
|
||||
"# Launch with auto-opening in browser\n",
|
||||
"interface.launch(\n",
|
||||
" server_name=\"127.0.0.1\",\n",
|
||||
" server_port=7860,\n",
|
||||
" share=False,\n",
|
||||
" inbrowser=True,\n",
|
||||
" show_error=True\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,538 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "3e473bbd-a0c2-43bd-bf99-c749784d00c3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gradio as gr\n",
|
||||
"import openai\n",
|
||||
"import anthropic\n",
|
||||
"import google.generativeai as genai\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from typing import Dict, Any, Optional\n",
|
||||
"import asyncio\n",
|
||||
"from dotenv import load_dotenv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "16210512-41f1-4de3-8348-2cd7129e023f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# load API\n",
|
||||
"load_dotenv(override=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6747e275-91eb-4d2b-90b6-805f2bd9b6b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class CodeCommenter:\n",
|
||||
" def __init__(self):\n",
|
||||
" # Initialize API clients\n",
|
||||
" self.openai_client = None\n",
|
||||
" self.anthropic_client = None\n",
|
||||
" self.gemini_client = None\n",
|
||||
" \n",
|
||||
" # Load API keys from environment variables\n",
|
||||
" self.setup_clients()\n",
|
||||
" \n",
|
||||
" def setup_clients(self):\n",
|
||||
" \"\"\"Initialize API clients with keys from environment variables\"\"\"\n",
|
||||
" try:\n",
|
||||
" # OpenAI\n",
|
||||
" openai_key = os.getenv('OPENAI_API_KEY')\n",
|
||||
" if openai_key:\n",
|
||||
" self.openai_client = openai.OpenAI(api_key=openai_key)\n",
|
||||
" \n",
|
||||
" # Anthropic\n",
|
||||
" anthropic_key = os.getenv('ANTHROPIC_API_KEY')\n",
|
||||
" if anthropic_key:\n",
|
||||
" self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)\n",
|
||||
" \n",
|
||||
" # Google Gemini\n",
|
||||
" gemini_key = os.getenv('GOOGLE_API_KEY')\n",
|
||||
" if gemini_key:\n",
|
||||
" genai.configure(api_key=gemini_key)\n",
|
||||
" self.gemini_client = genai.GenerativeModel('gemini-2.0-flash-exp')\n",
|
||||
" \n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Warning: Error setting up API clients: {e}\")\n",
|
||||
" \n",
|
||||
" def create_comments_prompt(self, code: str, language: str) -> str:\n",
|
||||
" \"\"\"Create a prompt for the LLM to add comments and docstrings\"\"\"\n",
|
||||
" return f\"\"\"Please add detailed and helpful comments and docstrings to the following {language} code. \n",
|
||||
" \n",
|
||||
"Guidelines:\n",
|
||||
"1. Add comprehensive docstrings for functions, classes, and modules\n",
|
||||
"2. Add inline comments explaining complex logic\n",
|
||||
"3. Follow the commenting conventions for {language}\n",
|
||||
"4. Maintain the original code structure and functionality\n",
|
||||
"5. Make comments clear and professional\n",
|
||||
"6. Don't change the actual code logic, only add comments\n",
|
||||
"7. Do not add code markdown delimiters like ```python\n",
|
||||
"\n",
|
||||
"Here's the code to comment:\n",
|
||||
"\n",
|
||||
"{code}\n",
|
||||
"\n",
|
||||
"Please return only the commented code without any additional explanation or markdown formatting.\"\"\"\n",
|
||||
"\n",
|
||||
" def create_tests_prompt(self, code: str, language: str) -> str:\n",
|
||||
" \"\"\"Create a prompt for the LLM to generate unit tests\"\"\"\n",
|
||||
" return f\"\"\"Please generate comprehensive unit tests for the following {language} code.\n",
|
||||
" \n",
|
||||
"Guidelines:\n",
|
||||
"1. Use appropriate testing framework for {language} (pytest for Python, JUnit for Java, etc.)\n",
|
||||
"2. Create tests for all functions and methods\n",
|
||||
"3. Include both positive and negative test cases\n",
|
||||
"4. Test edge cases and error conditions\n",
|
||||
"5. Use meaningful test names that describe what is being tested\n",
|
||||
"6. Include setup and teardown methods if needed\n",
|
||||
"7. Add mock objects for external dependencies (like database connections)\n",
|
||||
"8. Do not add code markdown delimiters like ```python\n",
|
||||
"9. Follow testing best practices for {language}\n",
|
||||
"\n",
|
||||
"Here's the code to test:\n",
|
||||
"\n",
|
||||
"{code}\n",
|
||||
"\n",
|
||||
"Please return only the unit test code without any additional explanation or markdown formatting.\"\"\"\n",
|
||||
"\n",
|
||||
" def create_combined_prompt(self, code: str, language: str) -> str:\n",
|
||||
" \"\"\"Create a prompt for the LLM to add both comments and unit tests\"\"\"\n",
|
||||
" return f\"\"\"Please add detailed comments and docstrings to the following {language} code AND generate comprehensive unit tests for it.\n",
|
||||
" \n",
|
||||
"For Comments:\n",
|
||||
"1. Add comprehensive docstrings for functions, classes, and modules\n",
|
||||
"2. Add inline comments explaining complex logic\n",
|
||||
"3. Follow the commenting conventions for {language}\n",
|
||||
"4. Don't change the actual code logic, only add comments\n",
|
||||
"\n",
|
||||
"For Unit Tests:\n",
|
||||
"1. Use appropriate testing framework for {language} (pytest for Python, JUnit for Java, etc.)\n",
|
||||
"2. Create tests for all functions and methods\n",
|
||||
"3. Include both positive and negative test cases\n",
|
||||
"4. Test edge cases and error conditions\n",
|
||||
"5. Add mock objects for external dependencies (like database connections)\n",
|
||||
"6. Follow testing best practices for {language}\n",
|
||||
"\n",
|
||||
"Structure your response as:\n",
|
||||
"1. First, provide the original code with added comments and docstrings \n",
|
||||
"2. Then, provide the unit tests as a separate section\n",
|
||||
"3. Do not add code markdown delimiters like ```python\n",
"4. The two sections, the commented code and the unit tests, should be clearly demarcated by comments stating the purpose of each section\n",
"\n",
|
||||
"Here's the code:\n",
|
||||
"\n",
|
||||
"{code}\n",
|
||||
"\n",
|
||||
"Please return the commented code followed by the unit tests, clearly separated.\"\"\"\n",
|
||||
"\n",
|
||||
" def call_openai(self, prompt: str, model: str = \"gpt-4o-mini\") -> str:\n",
|
||||
" \"\"\"Make API call to OpenAI\"\"\"\n",
|
||||
" if not self.openai_client:\n",
|
||||
" return \"Error: OpenAI API key not configured. Please set OPENAI_API_KEY environment variable.\"\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" response = self.openai_client.chat.completions.create(\n",
|
||||
" model=model,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are a helpful coding assistant that adds detailed comments, docstrings, and generates unit tests for code.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": prompt}\n",
|
||||
" ],\n",
|
||||
" max_tokens=4000,\n",
|
||||
" temperature=0.1\n",
|
||||
" )\n",
|
||||
" return response.choices[0].message.content.strip()\n",
|
||||
" except Exception as e:\n",
|
||||
" return f\"Error calling OpenAI API: {str(e)}\"\n",
|
||||
" \n",
|
||||
" def call_anthropic(self, prompt: str, model: str = \"claude-3-5-haiku-20241022\") -> str:\n",
|
||||
" \"\"\"Make API call to Anthropic Claude\"\"\"\n",
|
||||
" if not self.anthropic_client:\n",
|
||||
" return \"Error: Anthropic API key not configured. Please set ANTHROPIC_API_KEY environment variable.\"\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" response = self.anthropic_client.messages.create(\n",
|
||||
" model=model,\n",
|
||||
" max_tokens=4000,\n",
|
||||
" temperature=0.1,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"user\", \"content\": prompt}\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" return response.content[0].text.strip()\n",
|
||||
" except Exception as e:\n",
|
||||
" return f\"Error calling Anthropic API: {str(e)}\"\n",
|
||||
" \n",
|
||||
" def call_gemini(self, prompt: str) -> str:\n",
|
||||
" \"\"\"Make API call to Google Gemini\"\"\"\n",
|
||||
" if not self.gemini_client:\n",
|
||||
" return \"Error: Google API key not configured. Please set GOOGLE_API_KEY environment variable.\"\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" response = self.gemini_client.generate_content(\n",
|
||||
" prompt,\n",
|
||||
" generation_config=genai.types.GenerationConfig(\n",
|
||||
" max_output_tokens=4000,\n",
|
||||
" temperature=0.1,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" return response.text.strip()\n",
|
||||
" except Exception as e:\n",
|
||||
" return f\"Error calling Gemini API: {str(e)}\"\n",
|
||||
" \n",
|
||||
" def call_ollama(self, prompt: str, model: str = \"llama3.2:latest\") -> str:\n",
|
||||
" \"\"\"Make API call to Ollama (local)\"\"\"\n",
|
||||
" try:\n",
|
||||
" url = \"http://localhost:11434/api/generate\"\n",
|
||||
" data = {\n",
|
||||
" \"model\": model,\n",
|
||||
" \"prompt\": prompt,\n",
|
||||
" \"stream\": False,\n",
|
||||
" \"options\": {\n",
" \"temperature\": 0.1,\n",
|
||||
" \"num_predict\": 4000\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" response = requests.post(url, json=data, timeout=60)\n",
|
||||
" if response.status_code == 200:\n",
|
||||
" result = response.json()\n",
|
||||
" return result.get('response', '').strip()\n",
|
||||
" else:\n",
|
||||
" return f\"Error calling Ollama API: HTTP {response.status_code}\"\n",
|
||||
" except requests.exceptions.ConnectionError:\n",
|
||||
" return \"Error: Could not connect to Ollama. Make sure Ollama is running locally on port 11434.\"\n",
|
||||
" except Exception as e:\n",
|
||||
" return f\"Error calling Ollama API: {str(e)}\"\n",
|
||||
"\n",
|
||||
" def process_code(self, language: str, code: str, llm: str, generate_comments: bool, generate_tests: bool) -> str:\n",
|
||||
" \"\"\"Process the given code based on selected options\"\"\"\n",
|
||||
" if not code.strip():\n",
|
||||
" return \"Error: Please provide code to process.\"\n",
|
||||
" \n",
|
||||
" if not generate_comments and not generate_tests:\n",
|
||||
" return \"Error: Please select at least one option (Generate comments or Generate test units).\"\n",
|
||||
" \n",
|
||||
" # Determine which prompt to use\n",
|
||||
" if generate_comments and generate_tests:\n",
|
||||
" prompt = self.create_combined_prompt(code, language)\n",
|
||||
" elif generate_comments:\n",
|
||||
" prompt = self.create_comments_prompt(code, language)\n",
|
||||
" else: # generate_tests only\n",
|
||||
" prompt = self.create_tests_prompt(code, language)\n",
|
||||
" \n",
|
||||
" # Route to appropriate LLM\n",
|
||||
" if llm == \"gpt-4o-mini\":\n",
|
||||
" return self.call_openai(prompt, \"gpt-4o-mini\")\n",
|
||||
" elif llm == \"claude-3-5-haiku-20241022\":\n",
|
||||
" return self.call_anthropic(prompt, \"claude-3-5-haiku-20241022\")\n",
|
||||
" elif llm == \"gemini-2.0-flash\":\n",
|
||||
" return self.call_gemini(prompt)\n",
|
||||
" elif llm == \"ollama:llama3.2:latest\":\n",
|
||||
" return self.call_ollama(prompt, \"llama3.2:latest\")\n",
|
||||
" else:\n",
|
||||
" return f\"Error: Unsupported LLM: {llm}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "813f0911-d53f-4887-9341-656712e32d8f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_gradio_interface():\n",
|
||||
" \"\"\"Create and configure the Gradio interface\"\"\"\n",
|
||||
" commenter = CodeCommenter()\n",
|
||||
" \n",
|
||||
" # Define the main function for the interface\n",
|
||||
" def process_code_interface(language, code, llm, generate_comments, generate_tests):\n",
|
||||
" \"\"\"Process the code and return processed version based on selected options\"\"\"\n",
|
||||
" if not code.strip():\n",
|
||||
" return \"Please enter some code to process.\"\n",
|
||||
" \n",
|
||||
" if not generate_comments and not generate_tests:\n",
|
||||
" return \"Please select at least one option: Generate comments or Generate test units.\"\n",
|
||||
" \n",
|
||||
" # Show processing message\n",
|
||||
" options = []\n",
|
||||
" if generate_comments:\n",
|
||||
" options.append(\"comments\")\n",
|
||||
" if generate_tests:\n",
|
||||
" options.append(\"unit tests\")\n",
|
||||
" \n",
|
||||
" processing_msg = f\"Processing {language} code with {llm} to generate {' and '.join(options)}...\"\n",
|
||||
" print(processing_msg)\n",
|
||||
" \n",
|
||||
" # Process the code\n",
|
||||
" result = commenter.process_code(language, code, llm, generate_comments, generate_tests)\n",
|
||||
" return result\n",
|
||||
" \n",
|
||||
" # Define default code\n",
|
||||
" default_code = \"\"\"import pyodbc\n",
|
||||
"from tabulate import tabulate\n",
|
||||
"def connect_to_sql_server(server_name, database, username=None, password=None):\n",
|
||||
" try:\n",
|
||||
" if username and password:\n",
|
||||
" connection_string = f\"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={server_name};DATABASE={database};UID={username};PWD={password}\"\n",
|
||||
" else:\n",
|
||||
" connection_string = f\"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={server_name};DATABASE={database};Trusted_Connection=yes\"\n",
|
||||
" connection = pyodbc.connect(connection_string)\n",
|
||||
" print(f\"Successfully connected to {server_name}/{database}\")\n",
|
||||
" return connection\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failed to connect to {server_name}/{database}: {str(e)}\")\n",
|
||||
" return None\n",
|
||||
"def get_record_count(connection, table_name):\n",
|
||||
" try:\n",
|
||||
" cursor = connection.cursor()\n",
|
||||
" query = f\"SELECT COUNT(*) FROM {table_name}\"\n",
|
||||
" cursor.execute(query)\n",
|
||||
" count = cursor.fetchone()[0]\n",
|
||||
" cursor.close()\n",
|
||||
" print(f\"Record count for {table_name}: {count}\")\n",
|
||||
" return count\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failed to get record count for {table_name}: {str(e)}\")\n",
|
||||
" return None\n",
|
||||
"def select_top_records(connection, table_name, n):\n",
|
||||
" try:\n",
|
||||
" cursor = connection.cursor()\n",
|
||||
" query = f\"SELECT TOP {n} * FROM {table_name}\"\n",
|
||||
" cursor.execute(query)\n",
|
||||
" records = cursor.fetchall()\n",
|
||||
" columns = [column[0] for column in cursor.description]\n",
|
||||
" cursor.close()\n",
|
||||
" print(f\"Top {n} records from {table_name}\")\n",
|
||||
" if records:\n",
|
||||
" print(tabulate(records, headers=columns, tablefmt=\"grid\"))\n",
|
||||
" return records\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Failed to retrieve top {n} records from {table_name}: {str(e)}\")\n",
|
||||
" return None\n",
|
||||
"conn = connect_to_sql_server(\"localhost\", \"AdventureWorks_lite\")\n",
|
||||
"if conn:\n",
|
||||
" total_records = get_record_count(conn, \"Sales.SalesOrderDetail\")\n",
|
||||
" top_records = select_top_records(conn, \"Production.Product\", 10)\n",
|
||||
" conn.close()\n",
|
||||
" print(\"Connection closed successfully\")\"\"\"\n",
|
||||
"\n",
|
||||
" css = \"\"\"\n",
|
||||
"textarea[rows]:not([rows=\"1\"]) {\n",
|
||||
" overflow-y: auto !important;\n",
|
||||
" scrollbar-width: thin !important;\n",
|
||||
"}\n",
|
||||
"textarea[rows]:not([rows=\"1\"])::-webkit-scrollbar {\n",
|
||||
" all: initial !important;\n",
|
||||
" background: #f1f1f1 !important;\n",
|
||||
"}\n",
|
||||
"textarea[rows]:not([rows=\"1\"])::-webkit-scrollbar-thumb {\n",
|
||||
" all: initial !important;\n",
|
||||
" background: #a8a8a8 !important;\n",
|
||||
"}\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
" # Create the interface\n",
|
||||
" with gr.Blocks(title=\"Code Commenter & Test Generator\", theme=gr.themes.Base(), css=css) as interface:\n",
|
||||
" gr.Markdown(\"# 🔧 Code Commenter & Test Generator\")\n",
|
||||
" gr.Markdown(\"Add detailed comments, docstrings, and/or generate unit tests for your code using various LLM models.\")\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column():\n",
|
||||
" code_input = gr.Textbox(\n",
|
||||
" label=\"Input Code\",\n",
|
||||
" value=default_code,\n",
|
||||
" lines=15,\n",
|
||||
" max_lines=20,\n",
|
||||
" info=\"Enter the code you want to process\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Column():\n",
|
||||
" code_output = gr.Textbox(\n",
|
||||
" label=\"Processed Code\",\n",
|
||||
" lines=20,\n",
|
||||
" max_lines=20,\n",
|
||||
" info=\"Your code with added comments, docstrings, and/or unit tests\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Add checkboxes below the textboxes\n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column():\n",
|
||||
" generate_comments_checkbox = gr.Checkbox(\n",
|
||||
" label=\"Generate comments\",\n",
|
||||
" value=True,\n",
|
||||
" info=\"Add detailed comments and docstrings to the code\"\n",
|
||||
" )\n",
|
||||
" generate_tests_checkbox = gr.Checkbox(\n",
|
||||
" label=\"Generate test units\",\n",
|
||||
" value=False,\n",
|
||||
" info=\"Generate comprehensive unit tests for the code\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" with gr.Row():\n",
|
||||
" with gr.Column(scale=1):\n",
|
||||
" language_dropdown = gr.Dropdown(\n",
|
||||
" choices=[\"Python\", \"Ruby\", \"Rust\", \"C++\", \"Java\"],\n",
|
||||
" value=\"Python\",\n",
|
||||
" label=\"Programming Language\",\n",
|
||||
" info=\"Select the programming language of your code\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" llm_dropdown = gr.Dropdown(\n",
|
||||
" choices=[\n",
|
||||
" \"gpt-4o-mini\",\n",
|
||||
" \"claude-3-5-haiku-20241022\", \n",
|
||||
" \"gemini-2.0-flash\",\n",
|
||||
" \"ollama:llama3.2:latest\"\n",
|
||||
" ],\n",
|
||||
" value=\"gpt-4o-mini\",\n",
|
||||
" label=\"LLM Model\",\n",
|
||||
" info=\"Choose the language model to use\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" generate_btn = gr.Button(\n",
|
||||
" \"🚀 Process Code\", \n",
|
||||
" variant=\"primary\",\n",
|
||||
" size=\"lg\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Add some API setup information\n",
|
||||
" gr.Markdown(\"## 📝 API Setup Instructions\")\n",
|
||||
" gr.Markdown(\"\"\"\n",
|
||||
" To use this tool, you need to set up API keys as environment variables:\n",
|
||||
" \n",
|
||||
" - **OpenAI**: Set `OPENAI_API_KEY`\n",
|
||||
" - **Anthropic**: Set `ANTHROPIC_API_KEY` \n",
|
||||
" - **Google Gemini**: Set `GOOGLE_API_KEY`\n",
|
||||
" - **Ollama**: Make sure Ollama is running locally on port 11434\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" gr.Markdown(\"## ✨ Features\")\n",
|
||||
" gr.Markdown(\"\"\"\n",
|
||||
" - **Generate Comments**: Add detailed docstrings and inline comments\n",
|
||||
" - **Generate Unit Tests**: Create comprehensive test suites with mocking for external dependencies\n",
|
||||
" - **Combined Mode**: Generate both comments and unit tests in one go\n",
|
||||
" - **Multiple LLMs**: Choose from OpenAI, Anthropic, Google Gemini, or local Ollama models\n",
|
||||
" - **Multiple Languages**: Support for Python, Ruby, Rust, C++, and Java\n",
|
||||
" \"\"\")\n",
|
||||
" \n",
|
||||
" # Connect the button to the processing function\n",
|
||||
" generate_btn.click(\n",
|
||||
" fn=process_code_interface,\n",
|
||||
" inputs=[language_dropdown, code_input, llm_dropdown, generate_comments_checkbox, generate_tests_checkbox],\n",
|
||||
" outputs=code_output,\n",
|
||||
" show_progress=True\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" return interface"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "ef461e08-c1d5-406d-b7d2-a4329f16486e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🚀 Starting Code Commenter & Test Generator...\n",
|
||||
"📋 Setting up Gradio interface...\n",
|
||||
"🌐 Launching interface...\n",
|
||||
"💡 The interface will open in your default browser\n",
|
||||
"🔧 Make sure to set up your API keys as environment variables\n",
|
||||
"* Running on local URL: http://127.0.0.1:7860\n",
|
||||
"\n",
|
||||
"To create a public link, set `share=True` in `launch()`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": []
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"🚀 Starting Code Commenter & Test Generator...\")\n",
|
||||
"print(\"📋 Setting up Gradio interface...\")\n",
|
||||
"\n",
|
||||
"# Create and launch the interface\n",
|
||||
"interface = create_gradio_interface()\n",
|
||||
"\n",
|
||||
"print(\"🌐 Launching interface...\")\n",
|
||||
"print(\"💡 The interface will open in your default browser\")\n",
|
||||
"print(\"🔧 Make sure to set up your API keys as environment variables\")\n",
|
||||
"\n",
|
||||
"# Launch with auto-opening in browser\n",
|
||||
"interface.launch(\n",
|
||||
" server_name=\"127.0.0.1\",\n",
|
||||
" server_port=7860,\n",
|
||||
" share=False,\n",
|
||||
" inbrowser=True,\n",
|
||||
" show_error=True\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
week4/community-contributions/day5_java_code_commenter.ipynb (new file)
@@ -0,0 +1,300 @@
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "45ca91c2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AI tool to add comments to the provided Java code\n",
|
||||
"\n",
|
||||
"Here we build a Gradio App that uses the frontier models to add comments to a java code. For testing purposes I have used the *cheaper* versions of the models, not the ones the leaderboards indicate as the best ones."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f44901f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from openai import OpenAI\n",
|
||||
"import google.generativeai as genai\n",
|
||||
"import anthropic\n",
|
||||
"import gradio as gr"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c47706b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# environment\n",
|
||||
"\n",
|
||||
"load_dotenv(override=True)\n",
|
||||
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
||||
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
|
||||
"google_api_key = os.getenv('GOOGLE_API_KEY')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "35446b9a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"openai = OpenAI()\n",
|
||||
"claude = anthropic.Anthropic()\n",
|
||||
"genai.configure()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e899efd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"OPENAI_MODEL = \"gpt-4o-mini\"\n",
|
||||
"CLAUDE_MODEL = \"claude-3-haiku-20240307\"\n",
|
||||
"GEMINI_MODEL = 'gemini-2.0-flash-lite'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47640f53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"system_message = \"You are an assistant that adds comments to java code. \"\n",
|
||||
"system_message += \"Do not make any changes to the code itself.\"\n",
|
||||
"system_message += \"Use comments sparingly. Only add them in places where they help to undestand how the code works. Do not comment every single line of the code.\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f41ccbf0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def user_prompt_for(code):\n",
|
||||
" user_prompt = \"Add helpful comments to this java code. \"\n",
|
||||
" user_prompt += \"Do not change the code itself.\\n\\n\"\n",
|
||||
" user_prompt += code\n",
|
||||
" return user_prompt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c57c0000",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_code = \"\"\"\n",
|
||||
"package com.hma.kafkaproducertest.producer;\n",
|
||||
"\n",
|
||||
"import com.hma.kafkaproducertest.model.TestDTO;\n",
|
||||
"import org.springframework.cloud.stream.function.StreamBridge;\n",
|
||||
"import org.springframework.messaging.Message;\n",
|
||||
"import org.springframework.messaging.support.MessageBuilder;\n",
|
||||
"import org.springframework.stereotype.Component;\n",
|
||||
"\n",
|
||||
"import java.util.Arrays;\n",
|
||||
"import java.util.Comparator;\n",
|
||||
"import java.util.StringJoiner;\n",
|
||||
"import java.util.stream.Collectors;\n",
|
||||
"import java.util.stream.IntStream;\n",
|
||||
"\n",
|
||||
"@Component\n",
|
||||
"public class TestProducer {\n",
|
||||
"\n",
|
||||
" public static final String EVENT_TYPE_HEADER = \"event-type\";\n",
|
||||
" private static final String BINDING_NAME = \"testProducer-out-0\";\n",
|
||||
"\n",
|
||||
" private final StreamBridge streamBridge;\n",
|
||||
"\n",
|
||||
" public TestProducer(StreamBridge streamBridge) {\n",
|
||||
" this.streamBridge = streamBridge;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" public void sendMessage(TestDTO payload, String eventType){\n",
|
||||
" Message<TestDTO> message = MessageBuilder\n",
|
||||
" .withPayload(payload)\n",
|
||||
" .setHeader(EVENT_TYPE_HEADER, eventType)\n",
|
||||
" .build();\n",
|
||||
"\n",
|
||||
" streamBridge.send(BINDING_NAME, message);\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" public void test(String t1, String t2) {\n",
|
||||
" var s = t1.length() > t2.length() ? t2 : t1;\n",
|
||||
" var l = t1.length() > t2.length() ? t1 : t2;\n",
|
||||
" var res = true;\n",
|
||||
" for (int i = 0; i < s.length(); i++) {\n",
|
||||
" if (s.charAt(i) == l.charAt(i)) {\n",
|
||||
" res = false;\n",
|
||||
" break;\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" System.out.println(res);\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "00c71128",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_gpt(code):\n",
|
||||
" messages = [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt_for(code)}\n",
|
||||
" ]\n",
|
||||
" stream = openai.chat.completions.create(\n",
|
||||
" model=OPENAI_MODEL,\n",
|
||||
" messages=messages,\n",
|
||||
" stream=True\n",
|
||||
" )\n",
|
||||
" result = \"\"\n",
|
||||
" for chunk in stream:\n",
" result += chunk.choices[0].delta.content or \"\"\n",
|
||||
" yield result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ca92f8a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_claude(code):\n",
" result = claude.messages.stream(\n",
|
||||
" model=CLAUDE_MODEL,\n",
|
||||
" max_tokens=2000,\n",
|
||||
" system=system_message,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt_for(code)},\n",
|
||||
" ],\n",
|
||||
" )\n",
|
||||
" response = \"\"\n",
|
||||
" with result as stream:\n",
|
||||
" for text in stream.text_stream:\n",
|
||||
" response += text or \"\"\n",
|
||||
" yield response"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9dffed4b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_gemini(code):\n",
|
||||
" gemini = genai.GenerativeModel(\n",
|
||||
" model_name=GEMINI_MODEL,\n",
|
||||
" system_instruction=system_message\n",
|
||||
" )\n",
|
||||
" stream = gemini.generate_content(user_prompt_for(code), stream=True)\n",
|
||||
" result = \"\"\n",
|
||||
" for chunk in stream:\n",
|
||||
" result += chunk.text or \"\"\n",
|
||||
" yield result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "31f9c267",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def comment_code(code, model):\n",
|
||||
" if model==\"GPT\":\n",
|
||||
" result = stream_gpt(code)\n",
|
||||
" elif model==\"Claude\":\n",
|
||||
" result = stream_claude(code)\n",
|
||||
" elif model==\"Gemini\":\n",
|
||||
" result = stream_gemini(code)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Unknown model\")\n",
|
||||
" yield from result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c04c0a1b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gr.Blocks() as ui:\n",
|
||||
" with gr.Row():\n",
|
||||
" original_code = gr.Textbox(label=\"Java code:\", lines=10, value=test_code)\n",
|
||||
" commented_code = gr.Markdown(label=\"Commented code:\")\n",
|
||||
" with gr.Row():\n",
|
||||
" model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
|
||||
" comment = gr.Button(\"Comment code\")\n",
|
||||
"\n",
|
||||
" comment.click(comment_code, inputs=[original_code, model], outputs=[commented_code])\n",
|
||||
"\n",
|
||||
"ui.launch(inbrowser=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "84d33a5f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ui.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bbd50bf7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Conclusion\n",
|
||||
"\n",
|
||||
"In my personal opinion, at least when using these *cheaper* versions of the models, the result provided by Claude is the best. ChatGPT adds way too many comments even if the system message discourages that. Gemini provides a good result also, but maybe adds a tad too few comments -- although that certainly depends on your personal preferences."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "llms",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "45ca91c2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AI tool to generate unit tests for the provided Java code\n",
|
||||
"\n",
|
||||
"Here we build a Gradio App that uses the frontier models to generate unit tests for a java code. For testing purposes I have used the *cheaper* versions of the models, not the ones the leaderboards indicate as the best ones."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f44901f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from openai import OpenAI\n",
|
||||
"import google.generativeai as genai\n",
|
||||
"import anthropic\n",
|
||||
"import gradio as gr"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c47706b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# environment\n",
|
||||
"\n",
|
||||
"load_dotenv(override=True)\n",
|
||||
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
||||
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
|
||||
"google_api_key = os.getenv('GOOGLE_API_KEY')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "35446b9a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"openai = OpenAI()\n",
|
||||
"claude = anthropic.Anthropic()\n",
|
||||
"genai.configure()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e899efd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"OPENAI_MODEL = \"gpt-4o-mini\"\n",
|
||||
"CLAUDE_MODEL = \"claude-3-haiku-20240307\"\n",
|
||||
"GEMINI_MODEL = 'gemini-2.0-flash-lite'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47640f53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"system_message = \"You are an assistant that generates unit test for java code. \"\n",
|
||||
"system_message += \"Generate one JUnit5 test class with all the relevant test cases in it.\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f41ccbf0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def user_prompt_for(code):\n",
|
||||
" user_prompt = \"Generate unit tests for this java code.\\n\\n\"\n",
|
||||
" user_prompt += code\n",
|
||||
" return user_prompt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c57c0000",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_code = \"\"\"\n",
|
||||
"package com.hma.kafkaproducertest.rest;\n",
|
||||
"\n",
|
||||
"import com.hma.kafkaproducertest.model.TestDTO;\n",
|
||||
"import com.hma.kafkaproducertest.producer.TestProducer;\n",
|
||||
"import org.springframework.web.bind.annotation.*;\n",
|
||||
"\n",
|
||||
"@RestController\n",
|
||||
"@RequestMapping(\"/api\")\n",
|
||||
"public class TestController {\n",
|
||||
"\n",
|
||||
" private final TestProducer producer;\n",
|
||||
"\n",
|
||||
" public TestController(TestProducer producer) {\n",
|
||||
" this.producer = producer;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" @PostMapping(\"/event\")\n",
|
||||
" public TestDTO triggerKafkaEvent(@RequestBody TestDTO payload) {\n",
|
||||
" producer.sendMessage(payload, \"test\");\n",
|
||||
" return payload;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "00c71128",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_gpt(code):\n",
|
||||
" messages = [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt_for(code)}\n",
|
||||
" ]\n",
|
||||
" stream = openai.chat.completions.create(\n",
|
||||
" model=OPENAI_MODEL,\n",
|
||||
" messages=messages,\n",
|
||||
" stream=True\n",
|
||||
" )\n",
|
||||
" result = \"\"\n",
|
||||
" for chunk in stream:\n",
|
||||
" result += chunk.choices[0].delta.content or \"\"\n",
|
||||
" yield result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ca92f8a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_claude(code):\n",
|
||||
" result = claude.messages.stream(\n",
|
||||
" model=CLAUDE_MODEL,\n",
|
||||
" max_tokens=2000,\n",
|
||||
" system=system_message,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt_for(code)},\n",
|
||||
" ],\n",
|
||||
" )\n",
|
||||
" response = \"\"\n",
|
||||
" with result as stream:\n",
|
||||
" for text in stream.text_stream:\n",
|
||||
" response += text or \"\"\n",
|
||||
" yield response"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9dffed4b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def stream_gemini(code):\n",
|
||||
" gemini = genai.GenerativeModel(\n",
|
||||
" model_name=GEMINI_MODEL,\n",
|
||||
" system_instruction=system_message\n",
|
||||
" )\n",
|
||||
" stream = gemini.generate_content(user_prompt_for(code), stream=True)\n",
|
||||
" result = \"\"\n",
|
||||
" for chunk in stream:\n",
|
||||
" result += chunk.text or \"\"\n",
|
||||
" yield result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "31f9c267",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def generate_tests(code, model):\n",
|
||||
" if model==\"GPT\":\n",
|
||||
" result = stream_gpt(code)\n",
|
||||
" elif model==\"Claude\":\n",
|
||||
" result = stream_claude(code)\n",
|
||||
" elif model==\"Gemini\":\n",
|
||||
" result = stream_gemini(code)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Unknown model\")\n",
|
||||
" yield from result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c04c0a1b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gr.Blocks() as ui:\n",
|
||||
" with gr.Row():\n",
|
||||
" original_code = gr.Textbox(label=\"Java code:\", lines=10, value=test_code)\n",
|
||||
" generated_code = gr.Markdown(label=\"Unit tests:\")\n",
|
||||
" with gr.Row():\n",
|
||||
" model = gr.Dropdown([\"GPT\", \"Claude\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n",
|
||||
" generate = gr.Button(\"Generate tests\")\n",
|
||||
"\n",
|
||||
" generate.click(generate_tests, inputs=[original_code, model], outputs=[generated_code])\n",
|
||||
"\n",
|
||||
"ui.launch(inbrowser=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "84d33a5f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ui.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bbd50bf7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Conclusion\n",
|
||||
"\n",
|
||||
"The models are missing some information as the `TestDTO` is not defined in the code provided as an input.\n",
|
||||
"\n",
|
||||
"Results:\n",
|
||||
"- Gemini: Generates a well constructed test class with multiple test cases covering scenarios with valid and invalid inputs. It makes assumptions about the content of `TestDTO` and adds a note about those as a comment.\n",
|
||||
"- Claude: Similar approach to unknown format of `TestDTO`, although no comment added about the assumptions made. The test cases are strutured differently, and they don't cover any case of invalid input, which in my opinion is an important test for a REST endpoint.\n",
|
||||
"- GPT: While the other two generated *real* unit tests using the mockito extension, GPT generated a *webMVC* test. The other two relied on the equality impelemntation of `TestDTO`, while GPT checks separately each field in the response. As this type of test spins up the application context, the test won't run without additional configuration. In addition, some imports are missing from the test file.\n",
|
||||
"\n",
|
||||
"It comes down to personal preferences, but I would give the point to Gemini for this one."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "llms",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
week5/community-contributions/08_rag_qa_assistant.ipynb (new file; diff suppressed because one or more lines are too long)
@@ -0,0 +1,388 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dfe37963-1af6-44fc-a841-8e462443f5e6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Personal Knowledge Worker for Sameer Khadatkar\n",
|
||||
"\n",
|
||||
"This project will use RAG (Retrieval Augmented Generation) to ensure our question/answering assistant has high accuracy.\n",
|
||||
"\n",
|
||||
"This first implementation will use a simple, brute-force type of RAG.."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ba2779af-84ef-4227-9e9e-6eaf0df87e77",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import glob\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"import gradio as gr"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "802137aa-8a74-45e0-a487-d1974927d7ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports for langchain, plotly and Chroma\n",
|
||||
"\n",
|
||||
"from langchain.document_loaders import DirectoryLoader, TextLoader\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.schema import Document\n",
|
||||
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
|
||||
"from langchain_chroma import Chroma\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from sklearn.manifold import TSNE\n",
|
||||
"import numpy as np\n",
|
||||
"import plotly.graph_objects as go\n",
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"from langchain.chains import ConversationalRetrievalChain\n",
|
||||
"from langchain.embeddings import HuggingFaceEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "58c85082-e417-4708-9efe-81a5d55d1424",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# price is a factor, so we're going to use a low cost model\n",
|
||||
"\n",
|
||||
"MODEL = \"gpt-4o-mini\"\n",
|
||||
"db_name = \"vector_db\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ee78efcb-60fe-449e-a944-40bab26261af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load environment variables in a file called .env\n",
|
||||
"\n",
|
||||
"load_dotenv(override=True)\n",
|
||||
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "730711a9-6ffe-4eee-8f48-d6cfb7314905",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Read in documents using LangChain's loaders\n",
|
||||
"# Take everything in all the sub-folders of our knowledgebase\n",
|
||||
"\n",
|
||||
"folders = glob.glob(\"sameer-db/*\")\n",
|
||||
"\n",
|
||||
"def add_metadata(doc, doc_type):\n",
|
||||
" doc.metadata[\"doc_type\"] = doc_type\n",
|
||||
" return doc\n",
|
||||
"\n",
|
||||
"text_loader_kwargs = {'encoding': 'utf-8'}\n",
|
||||
"\n",
|
||||
"documents = []\n",
|
||||
"for folder in folders:\n",
|
||||
" doc_type = os.path.basename(folder)\n",
" loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n",
|
||||
" folder_docs = loader.load()\n",
|
||||
" documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n",
|
||||
"\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
|
||||
"chunks = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"print(f\"Total number of chunks: {len(chunks)}\")\n",
|
||||
"print(f\"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "78998399-ac17-4e28-b15f-0b5f51e6ee23",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Put the chunks of data into a Vector Store that associates a Vector Embedding with each chunk\n",
|
||||
"# Chroma is a popular open source Vector Database based on SQLLite\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"\n",
|
||||
"if os.path.exists(db_name):\n",
|
||||
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n",
|
||||
"\n",
|
||||
"# Create vectorstore\n",
|
||||
"vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n",
|
||||
"print(f\"Vectorstore created with {vectorstore._collection.count()} documents\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff2e7687-60d4-4920-a1d7-a34b9f70a250",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's investigate the vectors\n",
|
||||
"\n",
|
||||
"collection = vectorstore._collection\n",
|
||||
"count = collection.count()\n",
|
||||
"\n",
|
||||
"sample_embedding = collection.get(limit=1, include=[\"embeddings\"])[\"embeddings\"][0]\n",
|
||||
"dimensions = len(sample_embedding)\n",
|
||||
"print(f\"There are {count:,} vectors with {dimensions:,} dimensions in the vector store\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0d45462-a818-441c-b010-b85b32bcf618",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visualizing the Vector Store\n",
|
||||
"\n",
|
||||
"Let's take a minute to look at the documents and their embedding vectors to see what's going on."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b98adf5e-d464-4bd2-9bdf-bc5b6770263b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
|
||||
"vectors = np.array(result['embeddings'])\n",
|
||||
"documents = result['documents']\n",
|
||||
"metadatas = result['metadatas']\n",
|
||||
"doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
"colors = [['green', 'red'][['personal', 'profile'].index(t)] for t in doc_types]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "427149d5-e5d8-4abd-bb6f-7ef0333cca21",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We humans find it easier to visalize things in 2D!\n",
|
||||
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
|
||||
"# (t-distributed stochastic neighbor embedding)\n",
|
||||
"\n",
"tsne = TSNE(n_components=2, random_state=42, perplexity=5)\n",
|
||||
"reduced_vectors = tsne.fit_transform(vectors)\n",
|
||||
"\n",
|
||||
"# Create the 2D scatter plot\n",
|
||||
"fig = go.Figure(data=[go.Scatter(\n",
|
||||
" x=reduced_vectors[:, 0],\n",
|
||||
" y=reduced_vectors[:, 1],\n",
|
||||
" mode='markers',\n",
|
||||
" marker=dict(size=5, color=colors, opacity=0.8),\n",
|
||||
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
|
||||
" hoverinfo='text'\n",
|
||||
")])\n",
|
||||
"\n",
|
||||
"fig.update_layout(\n",
|
||||
" title='2D Chroma Vector Store Visualization',\n",
|
||||
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
|
||||
" width=800,\n",
|
||||
" height=600,\n",
|
||||
" margin=dict(r=20, b=10, l=10, t=40)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e1418e88-acd5-460a-bf2b-4e6efc88e3dd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's try 3D!\n",
|
||||
"\n",
|
||||
"tsne = TSNE(n_components=3, random_state=42,perplexity=5)\n",
|
||||
"reduced_vectors = tsne.fit_transform(vectors)\n",
|
||||
"\n",
|
||||
"# Create the 3D scatter plot\n",
|
||||
"fig = go.Figure(data=[go.Scatter3d(\n",
|
||||
" x=reduced_vectors[:, 0],\n",
|
||||
" y=reduced_vectors[:, 1],\n",
|
||||
" z=reduced_vectors[:, 2],\n",
|
||||
" mode='markers',\n",
|
||||
" marker=dict(size=5, color=colors, opacity=0.8),\n",
|
||||
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
|
||||
" hoverinfo='text'\n",
|
||||
")])\n",
|
||||
"\n",
|
||||
"fig.update_layout(\n",
|
||||
" title='3D Chroma Vector Store Visualization',\n",
|
||||
" scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
|
||||
" width=900,\n",
|
||||
" height=700,\n",
|
||||
" margin=dict(r=20, b=10, l=10, t=40)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9468860b-86a2-41df-af01-b2400cc985be",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Time to use LangChain to bring it all together"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b3942a10-9977-4ae7-9acf-968c43ad0d4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.schema import SystemMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "45c0fb93-0a16-4e55-857b-1f9fd61ec24c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create a new Chat with OpenAI\n",
|
||||
"llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
|
||||
"\n",
|
||||
"# set up the conversation memory for the chat\n",
|
||||
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
"memory.chat_memory.messages.insert(0, SystemMessage(\n",
|
||||
" content=\"\"\"You are an AI Assistant specialized in providing accurate information about Sameer Khadatkar. Only respond when the question explicitly asks for information. \n",
|
||||
" Keep your answers brief, factual, and based solely on the information provided. Do not speculate or fabricate details. \n",
|
||||
" For example, if the user simply says \"hi,\" respond with: \"How can I help you?\"\n",
|
||||
" \"\"\"\n",
|
||||
"))\n",
|
||||
"\n",
|
||||
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
|
||||
"retriever = vectorstore.as_retriever(k=4)\n",
|
||||
"\n",
|
||||
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
|
||||
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "968e7bf2-e862-4679-a11f-6c1efb6ec8ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's try a simple question\n",
|
||||
"\n",
|
||||
"query = \"Who are you?\"\n",
|
||||
"result = conversation_chain.invoke({\"question\": query})\n",
|
||||
"print(result[\"answer\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5b5a9013-d5d4-4e25-9e7c-cdbb4f33e319",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# set up a new conversation memory for the chat\n",
|
||||
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
|
||||
"\n",
|
||||
"# putting it together: set up the conversation chain with the GPT 4o-mini LLM, the vector store and memory\n",
|
||||
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bbbcb659-13ce-47ab-8a5e-01b930494964",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Now we will bring this up in Gradio using the Chat interface -\n",
|
||||
"\n",
|
||||
"A quick and easy way to prototype a chat with an LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c3536590-85c7-4155-bd87-ae78a1467670",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Wrapping that in a function\n",
|
||||
"\n",
|
||||
"def chat(question, history):\n",
|
||||
" result = conversation_chain.invoke({\"question\": question})\n",
|
||||
" return result[\"answer\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b252d8c1-61a8-406d-b57a-8f708a62b014",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# And in Gradio:\n",
|
||||
"\n",
|
||||
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e23270cf-2d46-4f9e-aeb3-de1673900d2f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3476931e-7d94-4b4d-8cc6-67a1bd5fa79c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,927 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "fOxyiqtzKqLg",
|
||||
"outputId": "714d12c5-775e-42c8-b51c-979a9112b808"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q datasets requests torch peft bitsandbytes transformers trl accelerate sentencepiece tiktoken matplotlib gradio modal ollama langchain langchain-core langchain-text-splitters langchain-openai langchain-chroma langchain-community faiss-cpu feedparser"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "zyxwwUw6LWXK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import glob\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"import gradio as gr"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Zzqc9nk1L_5w",
|
||||
"outputId": "0af5e1bb-2ccb-4838-b7a5-76c19285d094"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredPDFLoader\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.schema import Document\n",
|
||||
"from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n",
|
||||
"from langchain_chroma import Chroma\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from sklearn.manifold import TSNE\n",
|
||||
"import numpy as np\n",
|
||||
"import plotly.graph_objects as go\n",
|
||||
"from langchain.memory import ConversationBufferMemory\n",
|
||||
"from langchain.chains import ConversationalRetrievalChain\n",
|
||||
"from langchain.embeddings import HuggingFaceEmbeddings\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"import torch\n",
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, set_seed\n",
|
||||
"from google.colab import userdata\n",
|
||||
"from google.colab import drive\n",
|
||||
"drive.mount('/content/drive')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "u_vbe1itNZ2n"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"base_path = \"/content/drive/MyDrive/sameer-db\"\n",
|
||||
"folders = glob.glob(os.path.join(base_path, \"*\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "f0lJBMjhMrLO",
|
||||
"outputId": "5cdc6327-3a3a-4d5b-ca05-4c1383c020e2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def add_metadata(doc, doc_type):\n",
|
||||
" doc.metadata[\"doc_type\"] = doc_type\n",
|
||||
" return doc\n",
|
||||
"\n",
|
||||
"# With thanks to CG and Jon R, students on the course, for this fix needed for some users\n",
|
||||
"text_loader_kwargs = {'encoding': 'utf-8'}\n",
|
||||
"# If that doesn't work, some Windows users might need to uncomment the next line instead\n",
|
||||
"# text_loader_kwargs={'autodetect_encoding': True}\n",
|
||||
"\n",
|
||||
"documents = []\n",
|
||||
"for folder in folders:\n",
|
||||
" doc_type = os.path.basename(folder)\n",
|
||||
" loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n",
|
||||
" folder_docs = loader.load()\n",
|
||||
" documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])\n",
|
||||
"\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
|
||||
"chunks = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"print(f\"Total number of chunks: {len(chunks)}\")\n",
|
||||
"print(f\"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "zSjwqZ3YNBLp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hf_token = userdata.get('HF_TOKEN')\n",
|
||||
"login(hf_token, add_to_git_credential=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "t7rraUyHNkdP"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Phi_4 = \"microsoft/Phi-4-mini-instruct\"\n",
|
||||
"db_name = \"/content/drive/MyDrive/phi_vector_db\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "pDjj2S5ZPzF1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"quant_config = BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True,\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\"\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 66,
|
||||
"referenced_widgets": [
|
||||
"2a0377fc1e0c4c08944be1857c4e2409",
|
||||
"7c8335e0c3f8459d89f3b9815a896e39",
|
||||
"0fcb91f0551a4871b747f82e5fa6ff38",
|
||||
"fa5c6cf8395840e08e2743d6e88190be",
|
||||
"8613224ada934e7ba57fd5184ea61044",
|
||||
"1180c8fe49e94873a024d38d33649852",
|
||||
"4395c417cc854fc48da18d0ddd62671e",
|
||||
"d678106a6601478cb5712991604788f0",
|
||||
"5c4a8d25dbc942d5a596c8fa8580a785",
|
||||
"c1b076c063e04536831d68e5e48f1692",
|
||||
"9bcee7f185434cd0b1a998448236548c"
|
||||
]
|
||||
},
|
||||
"id": "qzQzgir5VUBF",
|
||||
"outputId": "1e7198a3-4857-49ab-f368-d430beddbf42"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = AutoTokenizer.from_pretrained(Phi_4, trust_remote_code=True)\n",
|
||||
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||
"tokenizer.padding_side = \"right\"\n",
|
||||
"\n",
|
||||
"base_model = AutoModelForCausalLM.from_pretrained(\n",
|
||||
" Phi_4,\n",
|
||||
" quantization_config=quant_config,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
")\n",
|
||||
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
|
||||
"\n",
|
||||
"print(f\"Memory footprint: {base_model.get_memory_footprint() / 1e9:.1f} GB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "MjK3mBKHQBra"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.base import Embeddings\n",
|
||||
"from typing import List\n",
|
||||
"import torch.nn.functional as F"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Q1BIMVW4Pf0A"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class PHI4Embeddings(Embeddings):\n",
|
||||
" def __init__(self, tokenizer, model):\n",
|
||||
" self.tokenizer = tokenizer\n",
|
||||
" self.model = model\n",
|
||||
" self.model.eval()\n",
|
||||
"\n",
|
||||
" def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
|
||||
" embeddings = []\n",
|
||||
" for text in texts:\n",
|
||||
" with torch.no_grad():\n",
|
||||
" inputs = self.tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=512).to(self.model.device)\n",
|
||||
" outputs = self.model(**inputs, output_hidden_states=True)\n",
|
||||
" hidden_states = outputs.hidden_states[-1] # Last layer\n",
|
||||
" attention_mask = inputs[\"attention_mask\"].unsqueeze(-1)\n",
|
||||
" pooled = (hidden_states * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)\n",
|
||||
" normalized = F.normalize(pooled, p=2, dim=1)\n",
|
||||
" embeddings.append(normalized[0].cpu().tolist())\n",
|
||||
" return embeddings\n",
|
||||
"\n",
|
||||
" def embed_query(self, text: str) -> List[float]:\n",
|
||||
" return self.embed_documents([text])[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "7aUTue_mMxof"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Put the chunks of data into a Vector Store that associates a Vector Embedding with each chunk\n",
|
||||
"\n",
|
||||
"embeddings = PHI4Embeddings(tokenizer, base_model)\n",
|
||||
"\n",
|
||||
"# Delete if already exists\n",
|
||||
"\n",
|
||||
"if os.path.exists(db_name):\n",
|
||||
" Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "uWSe-8mATUag",
|
||||
"outputId": "296804af-2283-435a-908c-48adaa6b4fd9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create vectorstore\n",
|
||||
"vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n",
|
||||
"print(f\"Vectorstore created with {vectorstore._collection.count()} documents\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "1ZQ6agxtSLp5",
|
||||
"outputId": "8e5bf8a7-fbaf-427b-9a67-369945aba80e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's investigate the vectors\n",
|
||||
"\n",
|
||||
"collection = vectorstore._collection\n",
|
||||
"count = collection.count()\n",
|
||||
"\n",
|
||||
"sample_embedding = collection.get(limit=1, include=[\"embeddings\"])[\"embeddings\"][0]\n",
|
||||
"dimensions = len(sample_embedding)\n",
|
||||
"print(f\"There are {count:,} vectors with {dimensions:,} dimensions in the vector store\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "qBIOPr2YT5FM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Prework\n",
|
||||
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
|
||||
"vectors = np.array(result['embeddings'])\n",
|
||||
"documents = result['documents']\n",
|
||||
"metadatas = result['metadatas']\n",
|
||||
"doc_types = [metadata['doc_type'] for metadata in metadatas]\n",
|
||||
"colors = [['blue', 'red'][['personal', 'profile'].index(t)] for t in doc_types]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 617
|
||||
},
|
||||
"id": "fnuul36bUB3h",
|
||||
"outputId": "f6cf1650-910a-4a03-f92d-9c200fb37de7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We humans find it easier to visalize things in 2D!\n",
|
||||
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
|
||||
"# (t-distributed stochastic neighbor embedding)\n",
|
||||
"\n",
|
||||
"tsne = TSNE(n_components=2, random_state=42, perplexity=4)\n",
|
||||
"reduced_vectors = tsne.fit_transform(vectors)\n",
|
||||
"\n",
|
||||
"# Create the 2D scatter plot\n",
|
||||
"fig = go.Figure(data=[go.Scatter(\n",
|
||||
" x=reduced_vectors[:, 0],\n",
|
||||
" y=reduced_vectors[:, 1],\n",
|
||||
" mode='markers',\n",
|
||||
" marker=dict(size=5, color=colors, opacity=0.8),\n",
|
||||
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
|
||||
" hoverinfo='text'\n",
|
||||
")])\n",
|
||||
"\n",
|
||||
"fig.update_layout(\n",
|
||||
" title='2D Chroma Vector Store Visualization',\n",
|
||||
" scene=dict(xaxis_title='x',yaxis_title='y'),\n",
|
||||
" width=800,\n",
|
||||
" height=600,\n",
|
||||
" margin=dict(r=20, b=10, l=10, t=40)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 717
|
||||
},
|
||||
"id": "Dgaeb7aRUF5d",
|
||||
"outputId": "47546459-e169-4d2b-d0d7-4ebd135556e0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's try 3D!\n",
|
||||
"\n",
|
||||
"tsne = TSNE(n_components=3, random_state=42, perplexity=4)\n",
|
||||
"reduced_vectors = tsne.fit_transform(vectors)\n",
|
||||
"\n",
|
||||
"# Create the 3D scatter plot\n",
|
||||
"fig = go.Figure(data=[go.Scatter3d(\n",
|
||||
" x=reduced_vectors[:, 0],\n",
|
||||
" y=reduced_vectors[:, 1],\n",
|
||||
" z=reduced_vectors[:, 2],\n",
|
||||
" mode='markers',\n",
|
||||
" marker=dict(size=5, color=colors, opacity=0.8),\n",
|
||||
" text=[f\"Type: {t}<br>Text: {d[:100]}...\" for t, d in zip(doc_types, documents)],\n",
|
||||
" hoverinfo='text'\n",
|
||||
")])\n",
|
||||
"\n",
|
||||
"fig.update_layout(\n",
|
||||
" title='3D Chroma Vector Store Visualization',\n",
|
||||
" scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
|
||||
" width=900,\n",
|
||||
" height=700,\n",
|
||||
" margin=dict(r=20, b=10, l=10, t=40)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"fig.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "BZcCyGI3YEwJ",
|
||||
"outputId": "fd03e6ee-2ec1-4c6b-c14b-986255ca070c"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import HuggingFacePipeline\n",
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\n",
|
||||
" \"text-generation\",\n",
|
||||
" model=base_model,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_new_tokens=4069,\n",
|
||||
" return_full_text=False,\n",
|
||||
" temperature=0.7\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"llm = HuggingFacePipeline(pipeline=pipe)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "WDY8-1gJUM1v"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# set up the conversation memory for the chat\n",
|
||||
"from langchain.schema import SystemMessage\n",
|
||||
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
|
||||
"# memory.chat_memory.add_message(SystemMessage(content='''You are a helpful assistant that answers questions about Sameer Khadatkar **in English only**, based only on the retrieved documents.\n",
|
||||
"# Do not respond in any other language.'''))\n",
|
||||
"\n",
|
||||
"# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
|
||||
"retriever = vectorstore.as_retriever(k=2)\n",
|
||||
"\n",
|
||||
"# putting it together: set up the conversation chain with the GPT 3.5 LLM, the vector store and memory\n",
|
||||
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dkuv5wD6jCrX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_first_helpful_answer(output: str) -> str:\n",
|
||||
" if \"Helpful Answer:\" in output:\n",
|
||||
" parts = output.split(\"Helpful Answer:\")\n",
|
||||
" return parts[0].strip().split(\"\\n\")[0].strip() # Take only the first line after it\n",
|
||||
" return output.strip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZY5BH4C3UY1E"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"Who is Sameer\"\n",
|
||||
"result = conversation_chain.invoke({\"question\": query})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "7n5PcQw0iRjO",
|
||||
"outputId": "794c4dad-efde-4220-a9bd-50a1ae156229"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(extract_first_helpful_answer(result[\"answer\"]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "vW025q5Tkwc3",
|
||||
"outputId": "e57d34e5-a64c-4e0b-e29b-d887214331c4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "JIev764VkCht"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# set up a new conversation memory for the chat\n",
|
||||
"memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
|
||||
"\n",
|
||||
"# putting it together: set up the conversation chain with the GPT 4o-mini LLM, the vector store and memory\n",
|
||||
"conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OO9o_VBholCx"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Wrapping that in a function\n",
|
||||
"\n",
|
||||
"def chat(question, history):\n",
|
||||
" result = conversation_chain.invoke({\"question\": question})\n",
|
||||
" return extract_first_helpful_answer(result[\"answer\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 646
|
||||
},
|
||||
"id": "zOqiuWqCo04a",
|
||||
"outputId": "fcb89961-1687-4d54-fcdd-ca5c590d69de"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# And in Gradio:\n",
|
||||
"\n",
|
||||
"view = gr.ChatInterface(chat, type=\"messages\").launch(inbrowser=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "qIYSDiQUo5WX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"0fcb91f0551a4871b747f82e5fa6ff38": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "FloatProgressModel",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "FloatProgressModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "ProgressView",
|
||||
"bar_style": "success",
|
||||
"description": "",
|
||||
"description_tooltip": null,
|
||||
"layout": "IPY_MODEL_d678106a6601478cb5712991604788f0",
|
||||
"max": 2,
|
||||
"min": 0,
|
||||
"orientation": "horizontal",
|
||||
"style": "IPY_MODEL_5c4a8d25dbc942d5a596c8fa8580a785",
|
||||
"value": 2
|
||||
}
|
||||
},
|
||||
"1180c8fe49e94873a024d38d33649852": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_module_version": "1.2.0",
|
||||
"model_name": "LayoutModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"2a0377fc1e0c4c08944be1857c4e2409": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "HBoxModel",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "HBoxModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "HBoxView",
|
||||
"box_style": "",
|
||||
"children": [
|
||||
"IPY_MODEL_7c8335e0c3f8459d89f3b9815a896e39",
|
||||
"IPY_MODEL_0fcb91f0551a4871b747f82e5fa6ff38",
|
||||
"IPY_MODEL_fa5c6cf8395840e08e2743d6e88190be"
|
||||
],
|
||||
"layout": "IPY_MODEL_8613224ada934e7ba57fd5184ea61044"
|
||||
}
|
||||
},
|
||||
"4395c417cc854fc48da18d0ddd62671e": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "DescriptionStyleModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "DescriptionStyleModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "StyleView",
|
||||
"description_width": ""
|
||||
}
|
||||
},
|
||||
"5c4a8d25dbc942d5a596c8fa8580a785": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "ProgressStyleModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "ProgressStyleModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "StyleView",
|
||||
"bar_color": null,
|
||||
"description_width": ""
|
||||
}
|
||||
},
|
||||
"7c8335e0c3f8459d89f3b9815a896e39": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "HTMLModel",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "HTMLModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "HTMLView",
|
||||
"description": "",
|
||||
"description_tooltip": null,
|
||||
"layout": "IPY_MODEL_1180c8fe49e94873a024d38d33649852",
|
||||
"placeholder": "",
|
||||
"style": "IPY_MODEL_4395c417cc854fc48da18d0ddd62671e",
|
||||
"value": "Loading checkpoint shards: 100%"
|
||||
}
|
||||
},
|
||||
"8613224ada934e7ba57fd5184ea61044": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_module_version": "1.2.0",
|
||||
"model_name": "LayoutModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"9bcee7f185434cd0b1a998448236548c": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "DescriptionStyleModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "DescriptionStyleModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "StyleView",
|
||||
"description_width": ""
|
||||
}
|
||||
},
|
||||
"c1b076c063e04536831d68e5e48f1692": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_module_version": "1.2.0",
|
||||
"model_name": "LayoutModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"d678106a6601478cb5712991604788f0": {
|
||||
"model_module": "@jupyter-widgets/base",
|
||||
"model_module_version": "1.2.0",
|
||||
"model_name": "LayoutModel",
|
||||
"state": {
|
||||
"_model_module": "@jupyter-widgets/base",
|
||||
"_model_module_version": "1.2.0",
|
||||
"_model_name": "LayoutModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/base",
|
||||
"_view_module_version": "1.2.0",
|
||||
"_view_name": "LayoutView",
|
||||
"align_content": null,
|
||||
"align_items": null,
|
||||
"align_self": null,
|
||||
"border": null,
|
||||
"bottom": null,
|
||||
"display": null,
|
||||
"flex": null,
|
||||
"flex_flow": null,
|
||||
"grid_area": null,
|
||||
"grid_auto_columns": null,
|
||||
"grid_auto_flow": null,
|
||||
"grid_auto_rows": null,
|
||||
"grid_column": null,
|
||||
"grid_gap": null,
|
||||
"grid_row": null,
|
||||
"grid_template_areas": null,
|
||||
"grid_template_columns": null,
|
||||
"grid_template_rows": null,
|
||||
"height": null,
|
||||
"justify_content": null,
|
||||
"justify_items": null,
|
||||
"left": null,
|
||||
"margin": null,
|
||||
"max_height": null,
|
||||
"max_width": null,
|
||||
"min_height": null,
|
||||
"min_width": null,
|
||||
"object_fit": null,
|
||||
"object_position": null,
|
||||
"order": null,
|
||||
"overflow": null,
|
||||
"overflow_x": null,
|
||||
"overflow_y": null,
|
||||
"padding": null,
|
||||
"right": null,
|
||||
"top": null,
|
||||
"visibility": null,
|
||||
"width": null
|
||||
}
|
||||
},
|
||||
"fa5c6cf8395840e08e2743d6e88190be": {
|
||||
"model_module": "@jupyter-widgets/controls",
|
||||
"model_module_version": "1.5.0",
|
||||
"model_name": "HTMLModel",
|
||||
"state": {
|
||||
"_dom_classes": [],
|
||||
"_model_module": "@jupyter-widgets/controls",
|
||||
"_model_module_version": "1.5.0",
|
||||
"_model_name": "HTMLModel",
|
||||
"_view_count": null,
|
||||
"_view_module": "@jupyter-widgets/controls",
|
||||
"_view_module_version": "1.5.0",
|
||||
"_view_name": "HTMLView",
|
||||
"description": "",
|
||||
"description_tooltip": null,
|
||||
"layout": "IPY_MODEL_c1b076c063e04536831d68e5e48f1692",
|
||||
"placeholder": "",
|
||||
"style": "IPY_MODEL_9bcee7f185434cd0b1a998448236548c",
|
||||
"value": " 2/2 [00:41<00:00, 19.69s/it]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
# Sameer Khadatkar
|
||||
|
||||
Hi, I am **Sameer Khadatkar**, born and brought up in **Nagpur**.
|
||||
|
||||
I completed my schooling from **Dinanath Junior College and High School, Nagpur** up to 12th standard. After that, I moved to **Amravati** for my Bachelor's degree.
|
||||
|
||||
### Academic Journey
|
||||
I prepared for the **GATE Mechanical Engineering (ME)** exam:
|
||||
- **2020**: Rank **377**
|
||||
|
||||
With this rank, I secured admission to the prestigious **Indian Institute of Science (IISc), Bangalore**.
|
||||
|
||||
### Career
|
||||
I later got placed at **Wells Fargo**, Hyderabad.
|
||||
|
||||
### Personal Life
|
||||
- I got married to my batchmate from Government College of Engineering Amravati.
|
||||
|
||||
### Hobbies & Interests
|
||||
I played **Cycle Polo** up to my 8th standard and even competed at the **national level**.
|
||||
|
||||
### Family
|
||||
- Parents, elder sister and wife.
|
||||
@@ -0,0 +1,145 @@
|
||||
# Sameer Raju Khadatkar
|
||||
|
||||
**Quant AI/ML @ Wells Fargo | M.Tech. (CDS) @ IISc, Bangalore | B.Tech. (Mechanical) @ GCOE, Amravati**
|
||||
📍 Hyderabad, Telangana, India
|
||||
📧 sameer123khadatkar@gmail.com
|
||||
🔗 [LinkedIn](https://www.linkedin.com/in/sameer-khadatkar/)
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
I currently serve as a Quantitative Analytics Specialist within Wells Fargo's Model Risk Management (MRM) team at India and Philippines. My primary responsibility involves validating AI/ML models, with a focus on fraud detection, as well as models used in marketing, credit scoring, and natural language processing (NLP). In this role, I ensure the conceptual soundness of models, conduct performance testing, conduct explainability analysis and rigorously challenge models by developing challenger models to detect weaknesses.
|
||||
|
||||
Additionally, I ensure compliance with regulatory standards set by Wells Fargo, in alignment with guidelines from the Federal Reserve and the OCC. I work closely with model development and risk management teams, providing validation feedback and recommending improvements. I also contribute to documentation and reporting, preparing validation reports, and ensuring the ongoing monitoring of model performance.
|
||||
|
||||
With a strong foundation in Machine Learning, Deep Learning, and High-Performance Computing gained during my graduate studies at the Indian Institute of Science, Bangalore, and a Bachelor's degree in Mechanical Engineering, I bring a unique blend of skills at the intersection of advanced technology and engineering. My expertise allows me to tackle complex challenges, drive innovation, and contribute to cutting-edge solutions in diverse industries.
|
||||
|
||||
---
|
||||
|
||||
## Professional Experience
|
||||
|
||||
### Wells Fargo International Solutions Private Ltd
|
||||
**Quantitative Analytics Specialist – AVP**
|
||||
📍 Hyderabad, Telangana, India
|
||||
📅 August 2022 – September 2023
|
||||
|
||||
- Collaborating with a team overseeing an inventory of ∼300 models focused on Fraud Detection, primarily utilizing Logistic Regression, Extreme Gradient Boosting (XGBoost), and Neural Network models.
|
||||
- Conduct validation of AI/ML models by ensuring conceptual soundness, performing performance testing, carrying out explainability analysis, and developing surrogate, challenger, and offset models to uncover potential weaknesses.
|
||||
- Joined the team during its expansion in India, playing a key role in building trust with US stakeholders. Recognized with the **Manager’s Spotlight Award** for outstanding dedication and contributions.
|
||||
- Developing a module to assist Validators in benchmarking anomaly detection models (Isolation Forest, Extended Isolation Forest, Autoencoders, Histogram-Based Outlier Score (HBOS), etc.) and assessing them using clustering performance metrics.
|
||||
- Created a validation playbook for fraud detection vendor models and developed an Excel-based policy library to facilitate quick reference for team members.
|
||||
|
||||
---
|
||||
|
||||
## Highlighted Projects at Wells Fargo
|
||||
|
||||
### ✅ Check Authorization Model | Validation
|
||||
|
||||
- Validated a high-impact machine learning model for check authorization, ensuring compliance with regulatory and bank's MRM standards.
|
||||
- Reviewed model objectives, assumptions, architecture, and data pipeline.
|
||||
- Assessed performance using AUC, recall, KS statistic, and PSI across time.
|
||||
- Performed explainability analysis using multicollinearity checks, surrogate models (overall and segment level), SHAP, PDP, H-Statistic, 2D-PDPs, and sensitivity analysis.
|
||||
- Identified local weaknesses through segmentation and built offset models to detect missed signals.
|
||||
- Developed challenger models using YOLOv5, SigNet, TrOCR (Transformer-based OCR), XGBoost model, and pixel-based feature engineering.
|
||||
|
||||
### 🧠 Word Embedding Explainability Research
|
||||
|
||||
- Collaborated with the Bank’s Chief Model Risk Officer on a research project focused on the explainability of word embeddings using clustering techniques such as Spectral Clustering, HDBSCAN, and analysis of ReLU neural network activation patterns.
|
||||
- Utilized Sentence Transformer embeddings (SBERT) and applied dimensionality reduction methods including PCA, UMAP, and t-SNE for cluster interpretation and visualization.
|
||||
- Extended the research by developing a Mixture of Experts model leveraging XGBoost.
|
||||
|
||||
---
|
||||
|
||||
## Education
|
||||
|
||||
**Indian Institute of Science (IISc), Bangalore**
|
||||
📅 2020 – 2022
|
||||
🎓 Master of Technology (M.Tech.), Computational and Data Sciences
|
||||
📍 Bengaluru, Karnataka
|
||||
**CGPA:** 9.1 / 10.0
|
||||
|
||||
**Government College of Engineering, Amravati (GCoEA)**
|
||||
📅 2015 – 2019
|
||||
🎓 Bachelor of Technology (B.Tech.), Mechanical Engineering
|
||||
📍 Amravati, Maharashtra
|
||||
**CGPA:** 8.29 / 10.0
|
||||
|
||||
---
|
||||
|
||||
## Certifications
|
||||
|
||||
- Advanced Data Science with IBM (Coursera)
|
||||
- HYPERMESH (SHELL MESH AND SOLID MESH)
|
||||
- Introduction to Big Data (Coursera)
|
||||
- MASTERCAM (Design, Turning and Milling)
|
||||
- CREO PARAMETRIC
|
||||
|
||||
---
|
||||
|
||||
## Research Publication
|
||||
|
||||
**Subspace Recursive Fermi-Operator Expansion Strategies for Large-Scale DFT Eigenvalue Problems on HPC Architectures**
|
||||
📝 Sameer Khadatkar, Phani Motamarri (MATRIX Lab)
|
||||
📅 July 20, 2023
|
||||
📚 *Journal of Chemical Physics, 159, 031102 (2023)*
|
||||
🔗 [Publication Link](https://pubs.aip.org/aip/jcp/article/159/3/031102/2903241/Subspace-recursive-Fermi-operator-expansion)
|
||||
|
||||
- Implemented recursive Fermi-operator expansion methods on multi-node CPU (PARAM Pravega) and GPU (ORNL Summit) systems for large-scale DFT problems.
|
||||
- Applied mixed-precision strategies achieving 2× to 4× speedup over diagonalization.
|
||||
- Benchmarked using MPI and SLATE for distributed dense linear algebra.
|
||||
|
||||
---
|
||||
|
||||
## Academic, Independent and Other Projects
|
||||
|
||||
- **LLM-Powered Multimodal Airline Chatbot**: Built a chatbot with GPT-4o-mini, supporting both text and voice, generating pop-art city images. Stack: Python, Gradio, custom tools.
|
||||
- **Future Stock Price Prediction for MAANG**: Used yfinance, Stateful LSTM vs XGBoost. LSTM outperformed with ~0.02 MAE.
|
||||
- **Duplicate Question Detection**: LSTM Siamese Network with Word2Vec and GloVe. GloVe performed better.
|
||||
- **Music Genre Classification**: Used MFCCs and spectral features. Best result: 76% ± 3% accuracy with SVM.
|
||||
- **Algorithm Implementation from Scratch**: PCA, LDA, GMM, TF-IDF, and backpropagation for DNNs.
|
||||
|
||||
---
|
||||
|
||||
## Skills
|
||||
|
||||
**Knowledge Areas:**
|
||||
Model Risk Management, Machine Learning, Deep Learning, High-Performance Computing
|
||||
|
||||
**Programming Languages:**
|
||||
Python, C, C++ (OpenMP, MPI, CUDA), SQL
|
||||
|
||||
**Python Libraries & Tools:**
|
||||
Numpy, Pandas, Scikit-Learn, PyTorch, TensorFlow (Keras), PySpark, Matplotlib
|
||||
|
||||
---
|
||||
|
||||
## Relevant Courses
|
||||
|
||||
- Machine Learning for Signal Processing (IISc)
|
||||
- Advanced Data Science with IBM (Coursera)
|
||||
- Deep Learning (NPTEL)
|
||||
- Pattern Recognition and Neural Networks (NPTEL)
|
||||
- Numerical Linear Algebra (IISc)
|
||||
- Data Analysis and Visualization (IISc)
|
||||
- Numerical Solution of Differential Equations (IISc)
|
||||
- Parallel Programming (IISc)
|
||||
- Introduction to Big Data (Coursera)
|
||||
- LLM Engineering: Master AI, Large Language Models & Agents (Udemy)
|
||||
|
||||
---
|
||||
|
||||
## Extracurricular Activities
|
||||
|
||||
- **Project Associate** at MATRIX Lab, CDS Department, IISc.
|
||||
- **Teaching Assistant** for “DS284: Numerical Linear Algebra” at IISc.
|
||||
- Led suspension operations for SAE BAJA Team at GCoE Amravati.
|
||||
- Organized Annual Social Gathering as Joint Secretary at GCoE Amravati.
|
||||
|
||||
---
|
||||
|
||||
## Top Skills
|
||||
|
||||
- Data Reporting
|
||||
- SQL
|
||||
- Microsoft Excel
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,510 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "12934dbc-ff4f-4dfc-8cc1-d92cc8826cf2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 🔍 Predicting Item Prices from Descriptions (Part 4)\n",
|
||||
"---\n",
|
||||
"- Data Curation & Preprocessing\n",
|
||||
"- Model Benchmarking – Traditional ML vs LLMs\n",
|
||||
"- E5 Embeddings & RAG\n",
|
||||
"- ➡️ Fine-Tuning GPT-4o Mini\n",
|
||||
"- Evaluating LLaMA 3.1 8B Quantized\n",
|
||||
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
|
||||
"- Evaluating Fine-Tuned LLaMA\n",
|
||||
"- Summary & Leaderboard\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"# 🔧 Part 4: Fine-Tuning GPT-4o Mini\n",
|
||||
"\n",
|
||||
"- 🧑💻 Skill Level: Advanced\n",
|
||||
"- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
|
||||
"- 🛠️ Requirements: 🔑 HF Token, Open API Key, wandb API Key\n",
|
||||
"- Tasks:\n",
|
||||
" - Convert chat data to .jsonl format for OpenAI\n",
|
||||
" - Fine-tune the model and monitor with Weights & Biases\n",
|
||||
" - Test the fine-tuned GPT-4o Mini \n",
|
||||
"\n",
|
||||
"Can fine-tuning GPT-4o Mini outperform both its zero-shot baseline and RAG-enhanced version? \n",
|
||||
"Time to find out.\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5809630f-d3ea-41df-86ec-9cbf59a46f5c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import importlib\n",
|
||||
"import json\n",
|
||||
"import re\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"from openai import OpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4120c84d-c310-4d31-9e1f-1549ea4a4186",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"load_dotenv(override=True)\n",
|
||||
"\n",
|
||||
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
||||
"if not openai_api_key:\n",
|
||||
" print(\"❌ OPENAI_API_KEY is missing\")\n",
|
||||
"\n",
|
||||
"openai = OpenAI(api_key=openai_api_key)\n",
|
||||
"\n",
|
||||
"hf_token = os.getenv('HF_TOKEN')\n",
|
||||
"if not hf_token:\n",
|
||||
" print(\"❌ HF_TOKEN is missing\")\n",
|
||||
"\n",
|
||||
"login(hf_token, add_to_git_credential=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "31d3aa97-68a8-4f71-a43f-107f7c8553c5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 📥 Load Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f2bae96a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
|
||||
"# %pip install -U datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c45e23d6-1304-4859-81f0-35a9ddf1c755",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"HF_USER = \"lisekarimi\"\n",
|
||||
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
|
||||
"\n",
|
||||
"dataset = load_dataset(DATASET_NAME)\n",
|
||||
"train = dataset['train']\n",
|
||||
"test = dataset['test']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "667adda8-add8-41b6-9e60-7870bad20c02",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b85d86d0-b6b1-49cd-9ef0-9214c1267199",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 🛠️ Step 1 : Data Preparation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d3ba760d-467a-4cd9-8d3f-e6ce84273610",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To fine-tune GPT-4o-mini, OpenAI requires training data in **.jsonl format**. \n",
|
||||
"\n",
|
||||
"`make_jsonl` converts our chat data :\n",
|
||||
"\n",
|
||||
"from \n",
|
||||
"\n",
|
||||
"[\n",
|
||||
" {\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"},\n",
|
||||
" {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"into the .jsonl format \n",
|
||||
"\n",
|
||||
"{\"messages\": [{\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"}, {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"}, {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}]}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec254755-67f6-4676-b67f-c1376ea00124",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Mask the price in the test item\n",
|
||||
"def mask_price_value(text):\n",
|
||||
" return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5e51957-b0ec-49f9-ae70-74771a101756",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def messages_for(datapoint):\n",
|
||||
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
|
||||
" user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
|
||||
" assistant_response = f\"Price is ${datapoint['price']:.2f}\"\n",
|
||||
" return [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
||||
" {\"role\": \"assistant\", \"content\": assistant_response}\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"messages_for(train[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "03583d32-b0f2-44c0-820e-62c8e7e48247",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def make_jsonl(datapoints):\n",
|
||||
" result = \"\"\n",
|
||||
" for datapoint in datapoints:\n",
|
||||
" messages = messages_for(datapoint)\n",
|
||||
" messages_str = json.dumps(messages, ensure_ascii=False)\n",
|
||||
" result += '{\"messages\": ' + messages_str + '}\\n'\n",
|
||||
" return result.strip()\n",
|
||||
"\n",
|
||||
"make_jsonl(train.select([0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "36c9cf60-0bcb-44cb-8df6-ff2ed4110cd2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ft_train = train.select(range(100))\n",
|
||||
"ft_validation = train.select(range(100, 150))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "494eaecd-ae5d-4396-b694-6faf88fb7fd6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert the items into jsonl and write them to a file\n",
|
||||
"\n",
|
||||
"def write_jsonl(datapoints, filename):\n",
|
||||
" with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
|
||||
" jsonl = make_jsonl(datapoints)\n",
|
||||
" f.write(jsonl)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ae42986d-ab02-4a11-aa0c-ede9c63ec7a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"write_jsonl(ft_train, \"data/ft_train.jsonl\")\n",
|
||||
"write_jsonl(ft_validation, \"data/ft_val.jsonl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b9bed22d-73ad-4820-a983-cbdccd8dbbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"data/ft_train.jsonl\", \"rb\") as f:\n",
|
||||
" train_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
|
||||
"with open(\"data/ft_val.jsonl\", \"rb\") as f:\n",
|
||||
" validation_file = openai.files.create(file=f, purpose=\"fine-tune\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1e6c6ce8-6600-4068-9ec5-32c6428ce9ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_file"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "26943fad-4301-4bb4-97e8-be52a9743322",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"validation_file"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "edb0a3ec-1607-4c5b-ab06-852f951cae8b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 🚀 Step 2: Run Fine-Tuning & Monitor with wandb\n",
|
||||
"We will use https://wandb.ai to monitor the training runs\n",
|
||||
"\n",
|
||||
"1- Create an API key in wandb\n",
|
||||
"\n",
|
||||
"2- Add this key in OpenAI dashboard https://platform.openai.com/account/organization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "59f552fe-5e80-4742-94a8-5492556a6543",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer\"}}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "144088d7-7c30-439a-9282-1e6096c181ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run the fine tuning\n",
|
||||
"\n",
|
||||
"openai.fine_tuning.jobs.create(\n",
|
||||
" training_file=train_file.id,\n",
|
||||
" validation_file=validation_file.id,\n",
|
||||
" model=\"gpt-4o-mini-2024-07-18\",\n",
|
||||
" seed=42,\n",
|
||||
" hyperparameters={\"n_epochs\": 1},\n",
|
||||
" integrations = [wandb_integration],\n",
|
||||
" suffix=\"pricer\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "330e75f5-0208-4c74-8dd3-07bc06047b2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_id = openai.fine_tuning.jobs.list(limit=1).data[0].id\n",
|
||||
"job_id\n",
|
||||
"\n",
|
||||
"# Then check your wandb dashboard to view the run of this job ID"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4a92dac5-e6d8-439c-b55e-507becb37a6c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Use this command to track the fine-tuning progress here\n",
|
||||
"\n",
|
||||
"openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=2).data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b6b65677-06b2-47d3-b0e6-51210a3d832b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 📧 You’ll get an email once fine-tuning is complete. ☕ You can take a break until then. ▶️ Once you receive it, run the cells below to continue."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0a7af4be-0b55-4654-af7a-f47485babc52",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 3 : Test the fine tuned model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c8497eb8-49ee-4a05-9e51-fc1b4b2b41d4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ft_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
|
||||
"ft_model_name"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "12bed33f-be31-4d7c-8651-3f267c529304",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can find the entire fine-tuning process in the **Fine-tuning** dashboard on OpenAI.\n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ac6a89ef-f982-457a-bad7-bd84b6132a07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build LLM messages\n",
|
||||
"def build_messages(datapoint):\n",
|
||||
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
|
||||
" user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
|
||||
" return [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
||||
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"def get_price(s):\n",
|
||||
" s = s.replace('$','').replace(',','')\n",
|
||||
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
|
||||
" return float(match.group()) if match else 0\n",
|
||||
"\n",
|
||||
"def gpt_ft(datapoint):\n",
|
||||
" response = openai.chat.completions.create(\n",
|
||||
" model=ft_model_name,\n",
|
||||
" messages=build_messages(datapoint),\n",
|
||||
" seed=42,\n",
|
||||
" max_tokens=7\n",
|
||||
" )\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" return get_price(reply)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "93a93017-458c-4769-b81c-b2dad2af7552",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(test[0][\"price\"])\n",
|
||||
"print(gpt_ft(test[0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "87a5ad10-ed60-4533-ad61-225ceb847e6c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"🔔 **Reminder:** \n",
|
||||
"- In **Part 2**, GPT-4o Mini (zero-shot) scored: \n",
|
||||
" Avg. Error: ~$99 | RMSLE: 0.75 | Accuracy: 44.8% \n",
|
||||
"\n",
|
||||
"- In **Part 3**, with **RAG**, performance improved to: \n",
|
||||
" Avg. Error: ~$59.54 | RMSLE: 0.42 | Accuracy: 69.2%\n",
|
||||
"\n",
|
||||
"🧪 **Now it’s time to see** if fine-tuning can push GPT-4o Mini even further and outperform both baselines."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0adf1500-9cc7-491a-9ea6-88932af85dca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import helpers.testing\n",
|
||||
"importlib.reload(helpers.testing)\n",
|
||||
"\n",
|
||||
"from helpers.testing import Tester # noqa: E402\n",
|
||||
"\n",
|
||||
"tester = Tester(gpt_ft, test)\n",
|
||||
"tester.run()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "37439666",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Gpt Ft Error=$129.16 RMSLE=0.94 Hits=35.2%"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5487da30-e1a8-4db5-bf17-80bc4f109524",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Fine-tuning GPT-4o Mini led to worse performance than both its zero-shot and RAG-enhanced versions.**\n",
|
||||
"\n",
|
||||
"⚠️ When Fine-Tuning Isn’t Needed:\n",
|
||||
"- For tasks like price prediction, GPT-4o performs well with prompting alone — thanks to strong pretraining and generalization.\n",
|
||||
"- 💡 Fine-tuning isn’t always better. Use it when prompting fails — not by default.\n",
|
||||
"\n",
|
||||
"✅ **When Fine-Tuning Is Worth It (based on OpenAI’s own guidelines)**\n",
|
||||
"- Custom tone/style – e.g., mimicking a brand voice or writing like a specific author\n",
|
||||
"- More consistent output – e.g., always following a strict format\n",
|
||||
"- Fix prompt failures – e.g., when multi-step instructions get ignored\n",
|
||||
"- Handle edge cases – e.g., rare product types or weird inputs\n",
|
||||
"- Teach new tasks – e.g., estimating prices in a custom format no model has seen before\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"Now that we’ve explored both frontier closed-source models and traditional ML, it’s time to turn to open-source.\n",
|
||||
"\n",
|
||||
"🚀 **Next up: Fine-tuned LLaMA 3.1 8B (quantized)** — can it beat its base version, outperform GPT-4o Mini, or even challenge the big players?\n",
|
||||
"\n",
|
||||
"🔍 Let’s find out in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part5_llama31_8b_quant.ipynb)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
1500
week6/community-contributions/lisekarimi/data/human_output.csv
Normal file
1500
week6/community-contributions/lisekarimi/data/human_output.csv
Normal file
File diff suppressed because it is too large
Load Diff
120
week6/community-contributions/lisekarimi/helpers/items.py
Normal file
120
week6/community-contributions/lisekarimi/helpers/items.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Optional # A variable might be a certain type or None
|
||||
from transformers import AutoTokenizer
|
||||
import re
|
||||
|
||||
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"
|
||||
|
||||
MIN_TOKENS = 150 # Minimum tokens required to accept an item
|
||||
MAX_TOKENS = 160 # We limit to 160 tokens so that after adding prompt text, the total stays around 180 tokens.
|
||||
|
||||
MIN_CHARS = 300 # Reject items with less than 300 characters
|
||||
CEILING_CHARS = MAX_TOKENS * 7 # Truncate long text to about 1120 characters (approx 160 tokens)
|
||||
|
||||
class Item:
|
||||
"""
|
||||
An Item is a cleaned, curated datapoint of a Product with a Price
|
||||
"""
|
||||
|
||||
# Load tokenizer for the model
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
|
||||
|
||||
# Define PRICE_LABEL and question for the training prompt
|
||||
PRICE_LABEL = "Price is $"
|
||||
QUESTION = "How much does this cost to the nearest dollar?"
|
||||
|
||||
# A list of useless phrases to remove to reduce noise for price prediction
|
||||
REMOVALS = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "]
|
||||
|
||||
# Attributes for each item
|
||||
title: str
|
||||
price: float
|
||||
category: str
|
||||
token_count: int = 0 # How many tokens in the final prompt
|
||||
|
||||
# Optional fields
|
||||
details: Optional[str] # The value can be a string or can be None
|
||||
prompt: Optional[str] = None
|
||||
include = False # Whether to keep the item or not
|
||||
|
||||
def __init__(self, data, price):
|
||||
self.title = data['title']
|
||||
self.price = price
|
||||
self.parse(data)
|
||||
|
||||
def scrub_details(self):
|
||||
"""
|
||||
Removes useless phrases from details, which often has repeated specs or boilerplate text.
|
||||
"""
|
||||
details = self.details
|
||||
for remove in self.REMOVALS:
|
||||
details = details.replace(remove, "")
|
||||
return details
|
||||
|
||||
def scrub(self, stuff):
|
||||
"""
|
||||
Clean up the provided text by removing unnecessary characters and whitespace
|
||||
Also remove words that are 7+ chars and contain numbers, as these are likely irrelevant product numbers
|
||||
"""
|
||||
stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip()
|
||||
stuff = stuff.replace(" ,", ",").replace(",,,",",").replace(",,",",")
|
||||
words = stuff.split(' ')
|
||||
select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)]
|
||||
return " ".join(select)
|
||||
|
||||
def parse(self, data):
|
||||
"""
|
||||
Prepares the text, checks length, tokenizes it, and sets include = True if it’s valid.
|
||||
"""
|
||||
# Builds a full contents string by combining description, features, and cleaned details.
|
||||
contents = '\n'.join(data['description'])
|
||||
if contents:
|
||||
contents += '\n'
|
||||
features = '\n'.join(data['features'])
|
||||
if features:
|
||||
contents += features + '\n'
|
||||
self.details = data['details']
|
||||
if self.details:
|
||||
contents += self.scrub_details() + '\n'
|
||||
|
||||
# If content is long enough, trim it to max char limit before processing.
|
||||
if len(contents) > MIN_CHARS:
|
||||
contents = contents[:CEILING_CHARS]
|
||||
|
||||
# Clean and tokenize text, then check token count.
|
||||
text = f"{self.scrub(self.title)}\n{self.scrub(contents)}"
|
||||
tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
if len(tokens) > MIN_TOKENS:
|
||||
# Truncate tokens, decode them back and create the training prompt
|
||||
tokens = tokens[:MAX_TOKENS]
|
||||
text = self.tokenizer.decode(tokens)
|
||||
self.make_prompt(text)
|
||||
|
||||
# Mark the item as valid and ready to be used in training
|
||||
self.include = True # Only items with more than MIN_TOKENS tokens are kept (truncated to MAX_TOKENS)
|
||||
|
||||
|
||||
def make_prompt(self, text):
|
||||
"""
|
||||
Builds the training prompt using the question, text, and price. Then counts the tokens.
|
||||
"""
|
||||
self.prompt = f"{self.QUESTION}\n\n{text}\n\n"
|
||||
self.prompt += f"{self.PRICE_LABEL }{str(round(self.price))}.00"
|
||||
self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))
|
||||
|
||||
def test_prompt(self):
|
||||
"""
|
||||
Returns the prompt without the actual price, useful for testing/inference.
|
||||
"""
|
||||
return self.prompt.split(self.PRICE_LABEL)[0] + self.PRICE_LABEL
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Defines how the Item object looks when printed — it shows the title and price.
|
||||
"""
|
||||
return f"<{self.title} = ${self.price}>"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
106
week6/community-contributions/lisekarimi/helpers/loaders.py
Normal file
106
week6/community-contributions/lisekarimi/helpers/loaders.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from datetime import datetime # Measure how long loading takes
|
||||
from tqdm import tqdm # Shows a progress bar while processing data
|
||||
from datasets import load_dataset # Load a dataset from Hugging Face Hub
|
||||
from concurrent.futures import ProcessPoolExecutor # For parallel processing (speed)
|
||||
from items import Item
|
||||
|
||||
CHUNK_SIZE = 1000 # Process the dataset in chunks of 1000 datapoints at a time (for efficiency)
|
||||
MIN_PRICE = 0.5
|
||||
MAX_PRICE = 999.49
|
||||
WORKER = 4 # Set the number of workers here
|
||||
|
||||
class ItemLoader:
|
||||
|
||||
def __init__(self, name):
|
||||
"""
|
||||
Initialize the loader with a dataset name.
|
||||
"""
|
||||
self.name = name # Store the category name
|
||||
self.dataset = None # Placeholder for the dataset (we load it later in load())
|
||||
|
||||
def process_chunk(self, chunk):
|
||||
"""
|
||||
Convert a chunk of datapoints into valid Item objects.
|
||||
"""
|
||||
batch = [] # Initialize the list to hold valid items
|
||||
|
||||
# Loop through each datapoint in the chunk
|
||||
for datapoint in chunk:
|
||||
try:
|
||||
# Extract price from datapoint
|
||||
price_str = datapoint['price']
|
||||
if price_str:
|
||||
price = float(price_str)
|
||||
|
||||
# Check if price is within valid range
|
||||
if MIN_PRICE <= price <= MAX_PRICE:
|
||||
item = Item(datapoint, price)
|
||||
|
||||
# Keep only valid items
|
||||
if item.include:
|
||||
batch.append(item)
|
||||
except ValueError:
|
||||
continue # Skip datapoints with invalid price format
|
||||
return batch # Return the list of valid items
|
||||
|
||||
|
||||
def load_in_parallel(self, workers):
|
||||
"""
|
||||
Split the dataset into chunks and process them in parallel.
|
||||
"""
|
||||
results = []
|
||||
size = len(self.dataset)
|
||||
chunk_count = (size // CHUNK_SIZE) + 1
|
||||
|
||||
# Build chunks directly here (no separate function)
|
||||
chunks = [
|
||||
self.dataset.select(range(i, min(i + CHUNK_SIZE, size)))
|
||||
for i in range(0, size, CHUNK_SIZE)
|
||||
]
|
||||
|
||||
# Process chunks in parallel using multiple CPU cores
|
||||
with ProcessPoolExecutor(max_workers=workers) as pool:
|
||||
for batch in tqdm(pool.map(self.process_chunk, chunks), total=chunk_count):
|
||||
results.extend(batch)
|
||||
|
||||
# Add the category name to each result
|
||||
for result in results:
|
||||
result.category = self.name
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load(self, workers=WORKER):
|
||||
"""
|
||||
Load and process the dataset, returning valid items.
|
||||
"""
|
||||
# Record start time
|
||||
start = datetime.now()
|
||||
|
||||
# Print loading message
|
||||
print(f"Loading dataset {self.name}", flush=True)
|
||||
|
||||
# Load dataset from Hugging Face (based on category name)
|
||||
self.dataset = load_dataset(
|
||||
"McAuley-Lab/Amazon-Reviews-2023",
|
||||
f"raw_meta_{self.name}",
|
||||
split="full",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Process the dataset in parallel and collect valid items
|
||||
results = self.load_in_parallel(workers)
|
||||
|
||||
# Record end time and print summary
|
||||
finish = datetime.now()
|
||||
print(
|
||||
f"Completed {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins",
|
||||
flush=True
|
||||
)
|
||||
|
||||
# Return the list of valid items
|
||||
return results
|
||||
|
||||
|
||||
|
||||
|
||||
84
week6/community-contributions/lisekarimi/helpers/testing.py
Normal file
84
week6/community-contributions/lisekarimi/helpers/testing.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import math
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
COLOR_MAP = {"red":RED, "orange": YELLOW, "green": GREEN}
|
||||
|
||||
class Tester:
|
||||
|
||||
def __init__(self, predictor, data, title=None, size=250):
|
||||
self.predictor = predictor
|
||||
self.data = data
|
||||
self.title = title or predictor.__name__.replace("_", " ").title()
|
||||
self.size = size
|
||||
self.guesses = []
|
||||
self.truths = []
|
||||
self.errors = []
|
||||
self.sles = []
|
||||
self.colors = []
|
||||
|
||||
def color_for(self, error, truth):
|
||||
if error<40 or error/truth < 0.2:
|
||||
return "green"
|
||||
elif error<80 or error/truth < 0.4:
|
||||
return "orange"
|
||||
else:
|
||||
return "red"
|
||||
|
||||
def run_datapoint(self, i):
|
||||
datapoint = self.data[i]
|
||||
guess = self.predictor(datapoint)
|
||||
truth = datapoint["price"]
|
||||
error = abs(guess - truth)
|
||||
log_error = math.log(truth+1) - math.log(guess+1)
|
||||
sle = log_error ** 2
|
||||
color = self.color_for(error, truth)
|
||||
title = datapoint["text"][:40] + "..." if len(datapoint["text"]) > 40 else datapoint["text"]
|
||||
self.guesses.append(guess)
|
||||
self.truths.append(truth)
|
||||
self.errors.append(error)
|
||||
self.sles.append(sle)
|
||||
self.colors.append(color)
|
||||
# print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}")
|
||||
|
||||
def chart(self, title):
|
||||
max_error = max(self.errors)
|
||||
plt.figure(figsize=(15, 6))
|
||||
max_val = max(max(self.truths), max(self.guesses))
|
||||
plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)
|
||||
plt.scatter(self.truths, self.guesses, s=3, c=self.colors)
|
||||
plt.xlabel('Ground Truth')
|
||||
plt.ylabel('Model Estimate')
|
||||
plt.xlim(0, max_val)
|
||||
plt.ylim(0, max_val)
|
||||
plt.title(title)
|
||||
|
||||
# Add color legend
|
||||
from matplotlib.lines import Line2D
|
||||
legend_elements = [
|
||||
Line2D([0], [0], marker='o', color='w', label='Accurate (green)', markerfacecolor='green', markersize=8),
|
||||
Line2D([0], [0], marker='o', color='w', label='Medium error (orange)', markerfacecolor='orange', markersize=8),
|
||||
Line2D([0], [0], marker='o', color='w', label='High error (red)', markerfacecolor='red', markersize=8)
|
||||
]
|
||||
plt.legend(handles=legend_elements, loc='upper left')
|
||||
plt.show()
|
||||
|
||||
def report(self):
|
||||
average_error = sum(self.errors) / self.size
|
||||
rmsle = math.sqrt(sum(self.sles) / self.size)
|
||||
hits = sum(1 for color in self.colors if color=="green")
|
||||
title = f"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%"
|
||||
self.chart(title)
|
||||
|
||||
def run(self):
|
||||
for i in range(self.size):
|
||||
self.run_datapoint(i)
|
||||
self.report()
|
||||
|
||||
@classmethod
|
||||
def test(cls, function, data):
|
||||
cls(function, data).run()
|
||||
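|
||||
# Example usage (illustrative):
|
||||
#     Tester.test(gpt_ft, test)
|
||||
# runs `size` predictions, draws the truth-vs-estimate scatter, and reports Error / RMSLE / Hits.
|
||||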
File diff suppressed because one or more lines are too long
@@ -0,0 +1,907 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 🔍 Predicting Item Prices from Descriptions (Part 6)\n",
|
||||
"---\n",
|
||||
"- Data Curation & Preprocessing\n",
|
||||
"- Model Benchmarking – Traditional ML vs LLMs\n",
|
||||
"- E5 Embeddings & RAG\n",
|
||||
"- Fine-Tuning GPT-4o Mini\n",
|
||||
"- Evaluating LLaMA 3.1 8B Quantized\n",
|
||||
"- ➡️ Fine-Tuning LLaMA 3.1 with QLoRA\n",
|
||||
"- Evaluating Fine-Tuned LLaMA\n",
|
||||
"- Summary & Leaderboard\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"# ⚙️ Part 6: Fine-Tuning LLaMA 3.1 with QLoRA\n",
|
||||
"\n",
|
||||
"- 🧑💻 Skill Level: Advanced\n",
|
||||
"- ⚙️ Hardware: ⚠️ GPU required - use Google Colab (A100)\n",
|
||||
"- 🛠️ Requirements: 🔑 HF Token, wandb API Key ([Weights & Biases](https://wandb.ai))\n",
|
||||
"- Tasks:\n",
|
||||
" - Load and split dataset (Train/validation); set up [Weights & Biases](https://wandb.ai) logging\n",
|
||||
" - Load quantized LLaMA 3.1 8B and tokenizer\n",
|
||||
" - Prepare data with a collator for fine-tuning\n",
|
||||
" - Configure QLoRA (LoRAConfig), training settings (SFTConfig), and tune key hyperparameters\n",
|
||||
" - Fine-tune and push best model to Hugging Face Hub\n",
|
||||
"\n",
|
||||
"⚠️ I attempted to fine-tune the model on the full 400K dataset using an A100 on Google Colab, but it consistently crashed. So for now, I’m training on a 20K subset to understand the process, play with hyperparameters, track progress in Weights & Biases, and push the best checkpoint to the Hub.\n",
|
||||
"\n",
|
||||
"⏱️ Training on 20,000 examples took over 2 hours.\n",
|
||||
"\n",
|
||||
"The full model fine-tuned on the complete 400K dataset is available thanks to our instructor, [Ed](https://www.linkedin.com/in/eddonner) — much appreciated! \n",
|
||||
"We’ll dive into that model in the next notebook — **stay tuned** 😉\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "MDyR63OTNUJ6",
|
||||
"outputId": "525372ce-f614-44f1-b894-80e289958197"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install required packages in Google Colab\n",
|
||||
"%pip install -q datasets transformers torch peft bitsandbytes trl accelerate wandb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "-yikV8pRBer9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import wandb\n",
|
||||
"from google.colab import userdata\n",
|
||||
"from datetime import datetime\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, EarlyStoppingCallback\n",
|
||||
"from peft import LoraConfig\n",
|
||||
"from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Google Colab User Data\n",
|
||||
"# Ensure you have set the following in your Google Colab environment:\n",
|
||||
"hf_token = userdata.get('HF_TOKEN')\n",
|
||||
"login(hf_token, add_to_git_credential=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "B48QsPsvUs_x"
|
||||
},
|
||||
"source": [
|
||||
"## 🔀 Load Dataset from HF and Split into Train/Validation"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
|
||||
"# %pip install -U datasets (for Google Colab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 177,
|
||||
"referenced_widgets": [
|
||||
"6f1f8dca2a334818a36fae380818001e",
|
||||
"6d3be1ece4a949d3b8d3736db02bcb5c",
|
||||
"c8c6bbacfe254c539f4acda8cdd5c04d",
|
||||
"db87c136ff15430892aa75fa47521b0c",
|
||||
"1d56af1140034021b2aecc5df846e499",
|
||||
"6238783102084e0c99626bf948ff5bb6",
|
||||
"f523b67e652049f7b13131d2750325bb",
|
||||
"f03cc2cf18c140c8b4a076ab99ac86e3",
|
||||
"472bb957b0e149df8ef0c26c3a3ffc19",
|
||||
"86dfcc161f2d41a7a33041848766d091",
|
||||
"6a7ed9e79ebb4f9c9962d08c78b424ca",
|
||||
"efc4817d5f734852a844640ebe7eceed",
|
||||
"0b473a8e944c4b028f51f53f62b72deb",
|
||||
"1fd89859568440f58f3ab56f32183dd4",
|
||||
"2e4bd8853acc4faa92e461210df2c689",
|
||||
"3fb588f271db4b7abb9a3631582cc7d6",
|
||||
"8f9c00ca63ca47e9873ec2a743fa1512",
|
||||
"afdae504b36845b9a98874cced112721",
|
||||
"8afd0ddfdeca43b59207a8b35a35e13c",
|
||||
"0be7a6fdb206420d88b2b2e45a37432c",
|
||||
"00f0983c1d204862b589011100297ffe",
|
||||
"8c7de85bcec742ec85f1e8b854351056",
|
||||
"5847c75b6dd74bc1b13116d91431ccf2",
|
||||
"bcb0ad86493f45848895c02c0b9deaf6",
|
||||
"18d70754531248b1ab22e1fd0df061ae",
|
||||
"028d806f909f42e2b6a7ec630f6e3cb5",
|
||||
"ff00d3192c734b398f779c7fffde57c8",
|
||||
"55388dcb89f84c7ebe7f5f7051f2d98b",
|
||||
"d3cab2b162a740fb82f78f030ea32b45",
|
||||
"cea0149336be4c92952bacb8aa820926",
|
||||
"6b560f8a028c4ba39896fd97f48f18ad",
|
||||
"2a3ed922dab44648b6d6ed63e21c549d",
|
||||
"885e1f4b9c3d45d5acd8d0a368ca557d",
|
||||
"73e42dca7c4b455f8be4b34236e6ced2",
|
||||
"c36aec28025e4baab8a3c4a293297f15",
|
||||
"7569e26e1e2b46e4a7018e1bd2bc92d5",
|
||||
"9f5795d223e74f1e8e49709ec1e4ddf1",
|
||||
"5638ccb893164fc79980eb48d06909f9",
|
||||
"70a528a0a08e4931b845ecc0992e07d6",
|
||||
"669bbecd55804849bff5a850438d905d",
|
||||
"245de1eaef2840b69e6c82afee68b4dc",
|
||||
"ad57405b8f474c0aa92833f83dde70e8",
|
||||
"cb3391329a7f4d0b93f5efffb9b0dcfe",
|
||||
"cb0007dffa284be8aff41efacdfc31cb",
|
||||
"c7de048747a24f9a9ce85396b87b8250",
|
||||
"066b3f278ec24b299504cea66b3c3e63",
|
||||
"0e1069c5bf644531902c51283a6d68e1",
|
||||
"06bd7477f9fe45d0ad4138fc21bd29dc",
|
||||
"adb68e7a8bea4b77b960e412c67a6286",
|
||||
"39ec099d38f04f4e8ea334d0c5335e2f",
|
||||
"044bf34d53024427801e24fbca808dc1",
|
||||
"e3d2839112ff4b7f9ab5bc04900ff522",
|
||||
"f620e7774fa04ed0a88d2f78d2243906",
|
||||
"7a12c0d7b32b445f978809c9aee2c62d",
|
||||
"5a230441445746d59ea8a10a4d5bb467"
|
||||
]
|
||||
},
|
||||
"id": "XEE1FrSIh-EF",
|
||||
"outputId": "8cd19745-2f6f-41e0-96dd-5a2f72ac3a63"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"HF_USER = \"lisekarimi\" # your HF name here!\n",
|
||||
"\n",
|
||||
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
|
||||
"dataset = load_dataset(DATASET_NAME)\n",
|
||||
"train = dataset['train']\n",
|
||||
"test = dataset['test']\n",
|
||||
"split_ratio = 0.1 # 10% for validation\n",
|
||||
"\n",
|
||||
"##############################################################################\n",
|
||||
"# Optional: limit training dataset to TRAIN_SIZE for testing/debugging\n",
|
||||
"# Comment the two lines below to use the full dataset\n",
|
||||
"TRAIN_SIZE = 20000\n",
|
||||
"train = train.select(range(TRAIN_SIZE))\n",
|
||||
"##############################################################################\n",
|
||||
"\n",
|
||||
"total_size = len(train)\n",
|
||||
"val_size = int(total_size * split_ratio)\n",
|
||||
"\n",
|
||||
"val_data = train.select(range(val_size))\n",
|
||||
"train_data = train.select(range(val_size, total_size))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "lUPNqb2Bse21",
|
||||
"outputId": "a3d09c8f-ce5a-46b0-e1b0-b4471a659f69"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(f\"Train data size : {len(train_data)}\")\n",
|
||||
"print(f\"Validation data size: {len(val_data)}\")\n",
|
||||
"print(f\"Test data size : {len(test)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "wixbM-VeVfsR"
|
||||
},
|
||||
"source": [
|
||||
"## 🛠️ Hugging Face Configuration"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"id": "OixVUG06VmZk",
|
||||
"outputId": "3cb523e0-fd03-4a18-913b-c22fa90e3bdd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PROJECT_NAME = \"llama3-pricer\"\n",
|
||||
"\n",
|
||||
"# Run name for saving the model in the hub\n",
|
||||
"\n",
|
||||
"RUN_NAME = f\"{datetime.now():%Y-%m-%d_%H.%M.%S}-size{total_size}\"\n",
|
||||
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
|
||||
"HUB_MODEL_NAME = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
|
||||
"HUB_MODEL_NAME"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1-t1nGgnVTU4"
|
||||
},
|
||||
"source": [
|
||||
"## 🛠️ wandb Configuration"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load from Colab's secure storage\n",
|
||||
"wandb_api_key = userdata.get('WANDB_API_KEY')\n",
|
||||
"\n",
|
||||
"# Load from environment variables (.env file) if running Locally (GPU setup)\n",
|
||||
"# wandb_api_key = os.getenv('WANDB_API_KEY')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ[\"WANDB_API_KEY\"] = wandb_api_key\n",
|
||||
"wandb.login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 156
|
||||
},
|
||||
"id": "yJNOv3cVvJ68",
|
||||
"outputId": "0c03623e-6887-49e3-8989-bbe45dfc5d35"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configure Weights & Biases to record against our project\n",
|
||||
"\n",
|
||||
"LOG_TO_WANDB = True\n",
|
||||
"\n",
|
||||
"os.environ[\"WANDB_PROJECT\"] = PROJECT_NAME\n",
|
||||
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" if LOG_TO_WANDB else \"end\"\n",
|
||||
"os.environ[\"WANDB_WATCH\"] = \"gradients\"\n",
|
||||
"\n",
|
||||
"if LOG_TO_WANDB:\n",
|
||||
" wandb.init(project=PROJECT_NAME, name=RUN_NAME)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qJWQ0a3wZ0Bw"
|
||||
},
|
||||
"source": [
|
||||
"## 📥 Load the Tokenizer and Model"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 418,
|
||||
"referenced_widgets": [
|
||||
"1b88f6d4010f4451a58abe2c46b74f62",
|
||||
"139758ba39964f49b65eb67182eef68e",
|
||||
"9c138d12dcb644fe9b72bd9eb5d26637",
|
||||
"3bf8626162904a15932480ddbcea0ebd",
|
||||
"a919a41b53604ccd91331d3f713e1310",
|
||||
"5b8cdfe01f9a4c248e3de30442411ad4",
|
||||
"e14d38a4c3e04d68ac30d475b0db1a73",
|
||||
"dadfd3c2a521420890092be265c0aa50",
|
||||
"761e88b179104dbbb6455ba81bd1f833",
|
||||
"11f5b4df0c7344ba9e188f4eca82886f",
|
||||
"125aa3f0dbd744eb82f8e4de94199736",
|
||||
"6ca21586e6fc4a608adedba7889eadb5",
|
||||
"023eb92e8a2b4323bfd12582e3c23962",
|
||||
"c7c76b9845174e9687107595df27c050",
|
||||
"78d4a28e03db4775b6e8e071c0b02d5d",
|
||||
"8483c625762c49679877a37ab0ddcef9",
|
||||
"1df5f6fe2fc04e60bfcb1f78689824ba",
|
||||
"add10c416e334928af303d51dfd745c6",
|
||||
"5e9e9dac85014292b94d347cc4bad3fe",
|
||||
"d665aa6480624ab697f4e426b51d59de",
|
||||
"03cce0d3f3a443fc808915b101576e4b",
|
||||
"f15714023f234c39863b34d1a3721a8e",
|
||||
"8f7a48d803eb4d2182c9da07af743ac7",
|
||||
"74892e7b343d410bbbef60c64a823a9a",
|
||||
"d6a70560831144e39dc9762d397f4c90",
|
||||
"9b969f7fbcdc491cab71aac42761cd2a",
|
||||
"d31f9443d1c646309c7a5e1ec39ffc0e",
|
||||
"0f5a81846ab143bebf6ec422cda3f145",
|
||||
"f0b05f3f7f37414c9d09470c94e304d7",
|
||||
"d18784692c9c4ca99e277e6ed51e2bf1",
|
||||
"f58addfac7c3438a90ebf10c88348d56",
|
||||
"451deac2eeec45598590579340be0d4b",
|
||||
"848e0651caf34ef288cca451e3d11274",
|
||||
"5adf041222f843429c3a9f1b99becfed",
|
||||
"a4764f36570b4752a1ec4392d2f0146c",
|
||||
"511a4c6a898346acac9d98fd3a7cdf2c",
|
||||
"26da7435a2614201a9e5b8087749f0e0",
|
||||
"6054fa015ae44659beb7473c084c7b5b",
|
||||
"3b9fc447a9ae4506a1edaf0fa449d9d5",
|
||||
"6acef8f1820545ef90b22d90ac80427d",
|
||||
"2a5cbad0b8fd45dc9ee25715b1015aef",
|
||||
"86a9428f39be4d65a1e922bd9afb3800",
|
||||
"96d919a1a7f14e91b8e6c91d855e36d5",
|
||||
"82d7484aa2774015b7ea18d933afa9b6",
|
||||
"b9d2d4f2c44a4d7cad2b3803c7f6e7be",
|
||||
"9f3a176a6ae6426a8c1567a835da8680",
|
||||
"006763d2301f4205a588adf5c19876a0",
|
||||
"b44eb6596c3441bbaab288030f953a04",
|
||||
"bf91666a0c054c79acb03d2e1bb38c37",
|
||||
"f0185f1b4b23445c920a873eb63a9372",
|
||||
"8e1ac15b677d4c21ad42ea1dda68fe05",
|
||||
"87746d8d6d3d413ebb46b4e12fb74cc8",
|
||||
"bb5ea1e92c434a46838f943648de87bd",
|
||||
"1abcfcba332b40eb901d1331ed84f9bd",
|
||||
"52fa5fcc629742619fa3105f73d90767",
|
||||
"1bcc2d5771034c2dbc372031e83a2384",
|
||||
"221cfaa2a5db4cf1ac399363c3589025",
|
||||
"793f9bdc92a545519dd3279023e4ab50",
|
||||
"55e25f5cc12f44f3a39fae501fccd060",
|
||||
"59463b5e6286483394dedb602991ac95",
|
||||
"fc95344ea44d40f28702360542afcff7",
|
||||
"ffb3af537d6c41548ad88027505b04d6",
|
||||
"6afcf0f6131d4dddbeda796e9c0c5bc5",
|
||||
"93f65b3bc071453f86fe8f0f6c17d8fd",
|
||||
"2ac9926ee4644232b43d84cfa95c584d",
|
||||
"0c5a7738132b4f0f8b4810333b37c588",
|
||||
"99d41ffa37134be9a57fe5e50a59b67d",
|
||||
"50e71304ab4f42c29f1994fed9b595b8",
|
||||
"76b4b0d63e524eb783429169a25be74e",
|
||||
"441cfadbe4b446f4b61391b7be4d2865",
|
||||
"6751f0c35b634d7c9b06c4e41f9ff851",
|
||||
"6a5dc276bbf64bf9b5a99751068ee228",
|
||||
"b3ac6055014642a285435f877d5651f5",
|
||||
"e9137600b29c4ecaad4ef8bca5fd5f91",
|
||||
"634afb9c1b8c4e29b3ec7b76a1108ae4",
|
||||
"6be0ac91035548fbbe778e3d7fd58e7e",
|
||||
"e8e9d5c979ac4afba526e38b6d0851be",
|
||||
"a4ae8ca9c0e7478fbad3b9ed67bc21a2",
|
||||
"faf3a64e316a43ddbac8ba14573c4eb4",
|
||||
"a395885e39434f9f98246d0fb1c94c8f",
|
||||
"d13552c90ead4804a4d5a21121f25536",
|
||||
"c25b94002c2246a9aa7f6ed1e4a22cfa",
|
||||
"e3892cf602cb4a49948f26cae1e7644c",
|
||||
"bc290a324a7147c5b6a722acb41ed05a",
|
||||
"2b556f5aa6324958ac6fe36bddf17909",
|
||||
"67c6a0534b3a4345b9c11af1bffdfbf0",
|
||||
"d767921bb23c485396282cb79a4d1836",
|
||||
"d598468ad8f94146976f70d873f0b56d",
|
||||
"b547888cd5494b21911b7d457ab6fbac",
|
||||
"28362e43274848109c2624e5668942b0",
|
||||
"7a27fc65bc0b44ce9bd959f4be13514d",
|
||||
"73bc97e6d9cc4ccd8d134092ce970026",
|
||||
"c042bf08ab23410098e6d16e837d19ce",
|
||||
"d2930ad2c08748d0883bb77c68acf940",
|
||||
"c2a1291730874e8e94232c0d51575f81",
|
||||
"cb92871b11a0410eb295cc323e5872a7",
|
||||
"150a5ce5d8124b0eb9e44d8715b8b1ab",
|
||||
"7a6f05ad1f2e483dbcdca102c66530b0",
|
||||
"626a29aee42e4e6d8c18d8ea5889734a",
|
||||
"c549ca0548d04a7d8749a0842c4aa62b",
|
||||
"958c0ff0f47f4c0fa4e2085f5243d84f",
|
||||
"a8171febcac94a4b902ff737592f3f47",
|
||||
"22630cdb7d6f4975bc31cc189987573d",
|
||||
"2f8a9ccee6ea4cdd8c8c225575cae0ce",
|
||||
"e40f81c5c4334accbca947964146d238",
|
||||
"d6849da8e89546469188dc047c66ea25",
|
||||
"8a67d8a2ac0a4fd7a41aa5c890049525",
|
||||
"5bf18445be0e46e087cbcd377ccfffbe",
|
||||
"72b2020c9479471681ce0f42898cfe1c",
|
||||
"c114fd62eb4b4fdca94654668c8f2374",
|
||||
"401580df26fc40abb2b774c3d9684921",
|
||||
"e756b825b211476994a69fb65f4bbf7c",
|
||||
"b2c26cf10e5a4d4fa8961f5c9cca18ce",
|
||||
"c288256c73dd44d08916db4e9cf989f0",
|
||||
"250a72e9650845d2b274bc3c157439f8",
|
||||
"94281c7e5be049c1a9f3dfa082805133",
|
||||
"f004f9f743ae4229aa90c92abba6ded6",
|
||||
"bd8ca5b8aaed4809a93f553d5cb4a887",
|
||||
"4cec4c2d73de4d52b2143082645536ac",
|
||||
"893b96616a0e47bfaa0434e10eca1341",
|
||||
"74e7d88dd4894894ac2c16fdfd29233b",
|
||||
"9e1f1e4288df407fa03415664dc361d5",
|
||||
"81dc3f390b9a49f8b1be5c43580b070d",
|
||||
"917a225a9bb74f8ab034dcdcee3c7247",
|
||||
"bc6c698857ce4f8eabc1571ba0ff0edf",
|
||||
"e9ae1c247ae5409f9da4db84ce71a6e3",
|
||||
"55071660223e4022a6a7836572077c0c",
|
||||
"8364e661011743af9fd40dabc5a7dfe4",
|
||||
"ac65442e0d5e43e2998d7c700573228a",
|
||||
"666f3434ae8a495f8ada8fedb50b7051",
|
||||
"1977e9f07f104faead7dfcfa8aaed6f2",
|
||||
"ebe2257c07f345fea72f162542a45142"
|
||||
]
|
||||
},
|
||||
"id": "R_O04fKxMMT-",
|
||||
"outputId": "29aa1cf7-2a2e-492e-adc9-cd0a5bfb123e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
|
||||
"\n",
|
||||
"quant_config = BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True, # Reduce the precision to 4 bits\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
|
||||
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||
"tokenizer.padding_side = \"right\"\n",
|
||||
"\n",
|
||||
"base_model = AutoModelForCausalLM.from_pretrained(\n",
|
||||
" BASE_MODEL,\n",
|
||||
" quantization_config=quant_config,\n",
|
||||
" device_map=\"auto\",\n",
|
||||
")\n",
|
||||
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
|
||||
"\n",
|
||||
"print(f\"Memory footprint: {base_model.get_memory_footprint() / 1e6:.1f} MB\")"
|
||||
]
|
||||
},
|
||||
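{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Back-of-the-envelope check (illustrative): 8B parameters at 4 bits each is\n",
|
||||
"# about 8e9 * 0.5 bytes = ~4,000 MB; embeddings and norms stay in higher\n",
|
||||
"# precision, so the printed footprint should land around 5-6 GB.\n",
|
||||
"print(f\"Rough 4-bit weight estimate: {8e9 * 0.5 / 1e6:,.0f} MB\")"
|
||||
]
|
||||
},
|
||||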
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SrCE2Le7RBRj"
|
||||
},
|
||||
"source": [
|
||||
"## ⚙️ Fine-tune our LLaMA 3 8B (4-bit quantized) model with QLoRA\n",
|
||||
"- 1. Prepare the Data with a Data Collator\n",
|
||||
"- 2. Define the QLoRA Configuration (LoraConfig)\n",
|
||||
"- 3. Set the Training Parameters (SFTConfig)\n",
|
||||
"- 4. Initialize the Fine-Tuning Trainer (SFTTrainer)\n",
|
||||
"- 5. Run Fine-Tuning and Push to Hub"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9BYO0If4uWys"
|
||||
},
|
||||
"source": [
|
||||
"### 🔄 1. Prepare the Data with a Data Collator\n",
|
||||
"\n",
|
||||
"We only want the model to learn the price, not the product description. Everything before \"Price is $\" is context, not training target. HuggingFace’s DataCollatorForCompletionOnlyLM handles this masking automatically:\n",
|
||||
"\n",
|
||||
"1. Tokenizes the response_template (\"Price is $\")\n",
|
||||
"2. Finds its token position in each input\n",
|
||||
"3. Masks all tokens before it (context)\n",
|
||||
"4. Trains the model only on tokens after it (the price)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Example:\n",
|
||||
"\n",
|
||||
"Input: \"Product: Red T-shirt. Price is $12.99\"\n",
|
||||
"\n",
|
||||
"Masked: \"Product: Red T-shirt. Price is $\" → masked (no loss)\n",
|
||||
"\n",
|
||||
"\"12.99\" → not masked (model is trained to predict this)\n",
|
||||
"\n",
|
||||
"So the model learns to generate 12.99 given the context, but isn’t trained to repeat or memorize the description."
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "2omVEaPIVJZa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response_template = \"Price is $\"\n",
|
||||
"collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)"
|
||||
]
|
||||
},
|
||||
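{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optional sanity check (illustrative - assumes each row's 'text' ends with\n",
|
||||
"# 'Price is $<price>'): collate one example and confirm that context tokens\n",
|
||||
"# are masked with -100, so loss is computed only on the price tokens.\n",
|
||||
"sample = tokenizer(train_data[0][\"text\"])\n",
|
||||
"batch = collator([sample])\n",
|
||||
"print(batch[\"labels\"][0][-10:])  # expect a run of -100s, then the price token ids"
|
||||
]
|
||||
},
|
||||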
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "4DaOeBhyy9eS"
|
||||
},
|
||||
"source": [
|
||||
"### 🧠 2. Define the QLoRA Configuration (LoraConfig)"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0HKuVS_XR3cw"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"LORA_R = 32\n",
|
||||
"LORA_ALPHA = 64\n",
|
||||
"TARGET_MODULES = [\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"]\n",
|
||||
"LORA_DROPOUT = 0.1\n",
|
||||
"\n",
|
||||
"lora_parameters = LoraConfig(\n",
|
||||
" r=LORA_R,\n",
|
||||
" lora_alpha=LORA_ALPHA,\n",
|
||||
" target_modules=TARGET_MODULES,\n",
|
||||
" lora_dropout=LORA_DROPOUT,\n",
|
||||
" bias=\"none\",\n",
|
||||
" task_type=\"CAUSAL_LM\", # Specifies we're doing causal language modeling\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
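{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optional (illustrative): preview how few weights QLoRA actually trains.\n",
|
||||
"# Note: SFTTrainer applies the adapters itself, and get_peft_model mutates the\n",
|
||||
"# model in place, so keep this commented out unless you are just inspecting.\n",
|
||||
"# from peft import get_peft_model\n",
|
||||
"# get_peft_model(base_model, lora_parameters).print_trainable_parameters()\n",
|
||||
"# -> roughly 27M trainable parameters out of ~8B with r=32 on q/k/v/o_proj"
|
||||
]
|
||||
},
|
||||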
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "uLfFsfNQSBAm"
|
||||
},
|
||||
"source": [
|
||||
"### ⚙️ 3. Set the Training Parameters (SFTConfig)"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "7PKXdhPXSJot"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 📦 Training Setup:\n",
|
||||
"EPOCHS = 1\n",
|
||||
"BATCH_SIZE = 16 # A100 GPU can go up to 16\n",
|
||||
"GRADIENT_ACCUMULATION_STEPS = 2\n",
|
||||
"MAX_SEQUENCE_LENGTH = 182 # Max token length per input\n",
|
||||
"\n",
|
||||
"# ⚙️ Optimization:\n",
|
||||
"LEARNING_RATE = 1e-4\n",
|
||||
"LR_SCHEDULER_TYPE = 'cosine'\n",
|
||||
"WARMUP_RATIO = 0.03\n",
|
||||
"OPTIMIZER = \"paged_adamw_32bit\"\n",
|
||||
"\n",
|
||||
"# 💾 Checkpointing & Logging:\n",
|
||||
"SAVE_STEPS = 200 # Checkpoint\n",
|
||||
"STEPS = 20 # Log every 20 steps\n",
|
||||
"save_total_limit = 10 # Keep latest 10 only\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"LOG_TO_WANDB = True\n",
|
||||
"\n",
|
||||
"HUB_MODEL_NAME = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
|
||||
"\n",
|
||||
"train_parameters = SFTConfig(\n",
|
||||
" # Output & Run\n",
|
||||
" output_dir=PROJECT_RUN_NAME,\n",
|
||||
" run_name=RUN_NAME,\n",
|
||||
" dataset_text_field=\"text\",\n",
|
||||
" max_seq_length=MAX_SEQUENCE_LENGTH,\n",
|
||||
"\n",
|
||||
" # Training\n",
|
||||
" num_train_epochs=EPOCHS,\n",
|
||||
" per_device_train_batch_size=BATCH_SIZE,\n",
|
||||
" gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
|
||||
" max_steps=-1,\n",
|
||||
" group_by_length=True,\n",
|
||||
"\n",
|
||||
" # Evaluation\n",
|
||||
" eval_strategy=\"steps\",\n",
|
||||
" eval_steps=STEPS,\n",
|
||||
" per_device_eval_batch_size=1,\n",
|
||||
"\n",
|
||||
" # Optimization\n",
|
||||
" learning_rate=LEARNING_RATE,\n",
|
||||
" lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
|
||||
" warmup_ratio=WARMUP_RATIO,\n",
|
||||
" optim=OPTIMIZER,\n",
|
||||
" weight_decay=0.001,\n",
|
||||
" max_grad_norm=0.3,\n",
|
||||
"\n",
|
||||
" # Precision\n",
|
||||
" fp16=False,\n",
|
||||
" bf16=True,\n",
|
||||
"\n",
|
||||
" # Logging & Saving\n",
|
||||
" logging_steps=STEPS, # See loss after each {STEP} batches\n",
|
||||
" save_strategy=\"steps\",\n",
|
||||
" save_steps=SAVE_STEPS, # Model Checkpointed locally\n",
|
||||
" save_total_limit=save_total_limit,\n",
|
||||
" report_to=\"wandb\" if LOG_TO_WANDB else None,\n",
|
||||
"\n",
|
||||
" # Hub\n",
|
||||
" push_to_hub=True,\n",
|
||||
" hub_strategy=\"end\", # Only push once, at the end\n",
|
||||
" load_best_model_at_end=True, # Loads the best eval_loss checkpoint\n",
|
||||
" metric_for_best_model=\"eval_loss\", # Monitors eval_loss\n",
|
||||
" greater_is_better=False, # Lower eval_loss = better model\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
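{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Illustrative arithmetic: with 18,000 training rows (20K minus the 10%\n",
|
||||
"# validation split) and an effective batch of 16 * 2 = 32, one epoch is\n",
|
||||
"# roughly 18,000 / 32 = ~562 optimizer steps.\n",
|
||||
"effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS\n",
|
||||
"print(f\"Effective batch size: {effective_batch}\")\n",
|
||||
"print(f\"Approx. optimizer steps per epoch: {len(train_data) // effective_batch}\")"
|
||||
]
|
||||
},
|
||||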
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1q-a3LHDSoxQ"
|
||||
},
|
||||
"source": [
|
||||
"### 🧩 4. Initialize the Fine-Tuning Trainer (SFTTrainer)\n",
|
||||
"Combining everything"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 290,
|
||||
"referenced_widgets": [
|
||||
"6753caf741414a4c8fa309978253c8cd",
|
||||
"aeade430d57b4338910ad0c3645fd06a",
|
||||
"eb7081b71cc14aff9b99dba8f9368def",
|
||||
"8eb16171df804d06a02351f74bb28dc4",
|
||||
"9d60a205ebda49ca88220cc4eec716ca",
|
||||
"d8ff973b90374423b4b5e17a1937111c",
|
||||
"4bf3bf107f2c4e28a58387c96916e97f",
|
||||
"d66cb8c1829c439095f4691fa32d7b6e",
|
||||
"567c8321685045c5a873b3b1edecdc96",
|
||||
"96ff596facb94acab611201b4adac13f",
|
||||
"de65507ce09a4ef4ad8f28d46d335acc",
|
||||
"e40fe92fe9094a58b53f0eeb97d3d629",
|
||||
"592615cc81624de5a9934f5671d6c188",
|
||||
"fadf75d91df54f49acef3f178ea53ce3",
|
||||
"5ccca8ab6cb94a88bb27bd482f7948a9",
|
||||
"d74dcc2ef9b8442d9ae99db2a79e0c48",
|
||||
"580ebfa370d34426933e8c7389872e2b",
|
||||
"1187f05dc99641e9a68d9cf49216c370",
|
||||
"7deffbba68ba4f018374bd6bec62dd18",
|
||||
"d24cdc40a6a34d6eb0efbfde17505d6f",
|
||||
"31d44a308b4b4557934ec887e0b6a817",
|
||||
"76112ce6fdc4496dba783451efa28cfd",
|
||||
"15a85e4a77484c9392b2e5cb8767b336",
|
||||
"4524d775b9034a1f890673a9c005d123",
|
||||
"5ab6a6b427f84ec685ac52f6ff0d63b5",
|
||||
"427ee9e90a844313989f623aba124498",
|
||||
"6d2b7c059e6b42afa955fe01bf38011d",
|
||||
"5d821ed8ffe14927be799c4d31043a82",
|
||||
"12f9fab59e9849dcb7b3b17c5674580f",
|
||||
"dd4a2876db37476fa438e8758c855393",
|
||||
"f115f97428764c53ac780131fd75bd17",
|
||||
"1a1e0e562a844ed098e97ce8a62695ee",
|
||||
"0a7ae7cc902243a5996f730f0fe05cdb",
|
||||
"07205ea24c3f4959bf9ebd393f5c921d",
|
||||
"723bb8342ac84eedabd91e3eef178967",
|
||||
"28714d0cf3d84a48975c8ad31e29691d",
|
||||
"dd1d90d76d914839a1dad1cddab2c09f",
|
||||
"e2d55edf98784523bcbeaad0cc2be494",
|
||||
"d00ecfa9dc44428b989ec1a9deb27eae",
|
||||
"ba2717985bc342e9827f8901ef655b00",
|
||||
"6669dc8f20e3461f93c95cef7a90b201",
|
||||
"29cb36c1943c4e1b9898534aaf32bd37",
|
||||
"14a1449c13a14afda16bc7c05b7fd840",
|
||||
"259d315eb4584c699b1c738d411eab7e",
|
||||
"a4bb13eb7cee4f87b0e3e1a3a1be18e7",
|
||||
"14d8a699a92044cda33802d96aaa41a2",
|
||||
"d345350fd5ad4a028fbbc45cfc9f6db3",
|
||||
"6953210353f840d59457fc54f4f8b829",
|
||||
"d6cd9e1196f04ecbba83dc0b446b2c65",
|
||||
"9e380ef863204da5863c9b6e7a2c8340",
|
||||
"1d1bb803831d46309619f6a0c51c2eeb",
|
||||
"6a50aaf7ad304a5aa3f29113121e8fe0",
|
||||
"7a573a39c2b245f5a84626d951584f67",
|
||||
"a57e66367d4245f6bcd4ad0463535583",
|
||||
"d6f3327d39a34ec5a44d976f239a61ce",
|
||||
"8f450df9f161409a8102c1f0b63edad8",
|
||||
"95d932d12cb8442da17adb8e9782c40c",
|
||||
"41c5f295b45f4828a9327b699b85ca01",
|
||||
"9e4f3fd6bf7749f88ccd7ba65dd9446f",
|
||||
"a8f8cb0d9fb14f30a537977f3d51a2c4",
|
||||
"4e9e4ed0f2db4d7ba5a5bb0d00676a0c",
|
||||
"1fe2bab9c9aa4de48e6e2512f9a7d0a1",
|
||||
"d93ac5affccf404fa3916e7f3dd62943",
|
||||
"92346fc65f48493d80198ac6d7adf4d8",
|
||||
"647bfb2a24cc44a0adaf69ced8e99213",
|
||||
"5c96424cff314aa484e4bc905bcbd761",
|
||||
"cec2fcfb30194d5ab8c0a3868bad3598",
|
||||
"35df7031c4964cef9c53bba6eabbe91d",
|
||||
"e15c772e14264c9889e6dae34015e04b",
|
||||
"e85b65cb497c48c2b844ae3e5d9efc60",
|
||||
"52c8495d46ca4a3c8c6694a700d05e95",
|
||||
"3db6d8a5ce2a40daaae6714807a27997",
|
||||
"051d74df7ef1468aa968cac5792e7b00",
|
||||
"75838a7c887545ff9fbbf5887a1336bc",
|
||||
"59f698c1829148ac90edda008d5c6f69",
|
||||
"35921436c69643aab792bd1333c749ef",
|
||||
"2dd51cc6033746e1a8def460e5e51ff5",
|
||||
"a8a3e5973ee5441087d10dfb17bfa1d6",
|
||||
"64c3b3c02e844df6bfd3acf1ee23d765",
|
||||
"83016eccdd7f4dedab9d3ea6e6852977",
|
||||
"9d4c5a62214f4649b77365349ae4ac88",
|
||||
"07cb9756d1814a7ba7fb49cccb2763cb",
|
||||
"492454ad524742bd8bb3f5c3d5b37feb",
|
||||
"e98053f6b7f045da812088d1e76d3a31",
|
||||
"f2aeb3ae99cc4b7ca97fb959df1150ad",
|
||||
"f92e18b6ab0147b1b428724f5155ca61",
|
||||
"14356b2447e349ee8478478eb231fa81",
|
||||
"f244a7e331d941f5a99712dcbc5550ea"
|
||||
]
|
||||
},
|
||||
"id": "fCwmDmkSATvj",
|
||||
"outputId": "2b4adc75-e0db-4e0b-c90b-9f9ff2dfd3c6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The latest version of trl is showing a warning about labels - please ignore this warning\n",
|
||||
"fine_tuning = SFTTrainer(\n",
|
||||
" model=base_model,\n",
|
||||
" train_dataset=train_data,\n",
|
||||
" eval_dataset=val_data,\n",
|
||||
" peft_config=lora_parameters, # QLoRA config\n",
|
||||
" args=train_parameters, # SFTConfig\n",
|
||||
" data_collator=collator,\n",
|
||||
" callbacks=[EarlyStoppingCallback(early_stopping_patience=5)] # Early stop if no val improvement for 5 steps\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "vHz6JA5_XJ07"
|
||||
},
|
||||
"source": [
|
||||
"### 🚀 5. Run Fine-Tuning and Push to Hub"
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 1000
|
||||
},
|
||||
"id": "GfvAxnXPvB7w",
|
||||
"outputId": "d351d89a-b3d7-4e2b-fee2-5ba2e929837e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fine_tuning.train()\n",
|
||||
"print(f\"✅ Best model pushed to HF Hub: {HUB_MODEL_NAME}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This chart shows training loss vs evaluation loss over steps during fine-tuning of Llama 31 8B 4-Bit FT (20K Samples).\n",
|
||||
"\n",
|
||||
"- Blue line (train/loss): Decreasing overall, with some noise. Final value: 1.8596.\n",
|
||||
"- Orange line (eval/loss): Smoother and consistently lower than training loss. Final value: 1.8103.\n",
|
||||
"\n",
|
||||
"- No overfitting: Eval loss < train loss throughout — a good sign.\n",
|
||||
"- Stable convergence: Both curves flatten around step 500, suggesting the model is reaching training stability.\n",
|
||||
"- Final eval loss is low, indicating decent generalization to unseen data.\n",
|
||||
"\n",
|
||||
"This fine-tuning run looks healthy. We can likely push further with more data - 400K run."
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 938
|
||||
},
|
||||
"id": "32vvrYRVAUNg",
|
||||
"outputId": "bb4ab0f6-c390-48f3-a71c-2d259bb0ec0b"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if LOG_TO_WANDB:\n",
|
||||
" wandb.finish()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "IyKZ0r38IfT3"
|
||||
},
|
||||
"source": [
|
||||
"Now that our best model is pushed to Hugging Face, let’s put it to the test.\n",
|
||||
"\n",
|
||||
"🔜 See you in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part7_eval_llama_qlora.ipynb)"
|
||||
],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "A100",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,75 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "GHsssBgWM_l0"
|
||||
},
|
||||
"source": [
|
||||
"# 🔍 Predicting Item Prices from Descriptions (Part 8)\n",
|
||||
"---\n",
|
||||
"- Data Curation & Preprocessing\n",
|
||||
"- Model Benchmarking – Traditional ML vs LLMs\n",
|
||||
"- E5 Embeddings & RAG\n",
|
||||
"- Fine-Tuning GPT-4o Mini\n",
|
||||
"- Evaluating LLaMA 3.1 8B Quantized\n",
|
||||
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
|
||||
"- Evaluating Fine-Tuned LLaMA\n",
|
||||
"- ➡️ Summary & Leaderboard\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"# 🧪 Part 8: Summary & Leaderboard\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 🥇 The winner is the LLaMA 3.1 8B (4-bit) fine-tuned on 400K samples \n",
|
||||
"\n",
|
||||
"LLaMA 3.1 8B (4-bit) fine-tuned on 400K samples is outperforming even the big guy GPT-4o — with the lowest error and highest accuracy (75.6%).\n",
|
||||
"\n",
|
||||
"RAG + GPT-4o Mini also did well, proving that retrieval adds real value.\n",
|
||||
"\n",
|
||||
"On the other hand, traditional ML models and even human guesses, gave weaker results and fell behind the top models.\n",
|
||||
"\n",
|
||||
"💡 As we’ve seen, a **well-tuned open-source small model** can do amazing things on a focused task — sometimes even better than large, closed models.\n",
|
||||
"It’s not about size — it’s about fit, focus, and fine-tuning.\n",
|
||||
"\n",
|
||||
"# ✨ Conclusion\n",
|
||||
"What a journey! From classic ML to state-of-the-art LLMs, from embeddings to retrieval and fine-tuning — we explored it all to answer: who predicts prices best?\n",
|
||||
"\n",
|
||||
"Thanks for following along — see you in the next challenge! 🚀\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
|
||||
],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
387
week8/community_contributions/lisekarimi/10_part2_modal.ipynb
Normal file
387
week8/community_contributions/lisekarimi/10_part2_modal.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,33 @@
|
||||
import logging
|
||||
|
||||
class Agent:
|
||||
"""
|
||||
An abstract superclass for Agents
|
||||
Used to log messages in a way that can identify each Agent
|
||||
"""
|
||||
|
||||
# Foreground colors
|
||||
RED = '\033[31m'
|
||||
GREEN = '\033[32m'
|
||||
YELLOW = '\033[33m'
|
||||
BLUE = '\033[34m'
|
||||
MAGENTA = '\033[35m'
|
||||
CYAN = '\033[36m'
|
||||
WHITE = '\033[37m'
|
||||
|
||||
# Background color
|
||||
BG_BLACK = '\033[40m'
|
||||
|
||||
# Reset code to return to default color
|
||||
RESET = '\033[0m'
|
||||
|
||||
name: str = ""
|
||||
color: str = '\033[37m'
|
||||
|
||||
def log(self, message):
|
||||
"""
|
||||
Log this as an info message, identifying the agent
|
||||
"""
|
||||
color_code = self.BG_BLACK + self.color
|
||||
message = f"[{self.name}] {message}"
|
||||
logging.info(color_code + message + self.RESET)
|
||||
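|
||||
# Example (illustrative): a subclass only needs a name and a color, then logs via self.log()
|
||||
#
|
||||
# class DemoAgent(Agent):
|
||||
#     name = "Demo Agent"
|
||||
#     color = Agent.CYAN
|
||||
#
|
||||
# DemoAgent().log("hello")  # -> INFO: [Demo Agent] hello (cyan on black)
|
||||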
@@ -0,0 +1,29 @@
|
||||
import modal
|
||||
from agents.base_agent import Agent
|
||||
|
||||
|
||||
class FTPriceAgent(Agent):
|
||||
"""
|
||||
An Agent that runs the fine-tuned LLM that's running remotely on Modal
|
||||
"""
|
||||
|
||||
name = "FTPrice Agent"
|
||||
color = Agent.RED
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Set up this Agent by creating an instance of the modal class
|
||||
"""
|
||||
self.log("FTPrice Agent is initializing - connecting to modal")
|
||||
Pricer = modal.Cls.from_name("llm-ft-pricer", "Pricer") # 1st API call: to fetch Pricer (remote class)
|
||||
self.pricer = Pricer()
|
||||
self.log("FTPrice Agent is ready")
|
||||
|
||||
def price(self, description: str) -> float:
|
||||
"""
|
||||
Make a remote call to return the estimate of the price of this item
|
||||
"""
|
||||
self.log("FTPrice Agent is calling remote fine-tuned model")
|
||||
result = self.pricer.price.remote(description) # 2nd API call: to run the price method in the remote Pricer class
|
||||
self.log(f"FTPrice Agent completed - predicting ${result:.2f}")
|
||||
return result
|
||||
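|
||||
# Example usage (illustrative - assumes the "llm-ft-pricer" app is deployed on Modal):
|
||||
#
|
||||
#     agent = FTPriceAgent()
|
||||
#     agent.price("HyperX QuadCast condenser microphone for podcasting")
|
||||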
120
week8/community_contributions/lisekarimi/helpers/items.py
Normal file
120
week8/community_contributions/lisekarimi/helpers/items.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Optional # A variable might be a certain type or None
|
||||
from transformers import AutoTokenizer
|
||||
import re
|
||||
|
||||
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"
|
||||
|
||||
MIN_TOKENS = 150 # Minimum tokens required to accept an item
|
||||
MAX_TOKENS = 160 # We limit to 160 tokens so that after adding prompt text, the total stays around 180 tokens.
|
||||
|
||||
MIN_CHARS = 300 # Reject items with less than 300 characters
|
||||
CEILING_CHARS = MAX_TOKENS * 7 # Truncate long text to about 1120 characters (approx 160 tokens)
|
||||
|
||||
class Item:
|
||||
"""
|
||||
An Item is a cleaned, curated datapoint of a Product with a Price
|
||||
"""
|
||||
|
||||
# Load tokenizer for the model
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
|
||||
|
||||
# Define PRICE_LABEL and question for the training prompt
|
||||
PRICE_LABEL = "Price is $"
|
||||
QUESTION = "How much does this cost to the nearest dollar?"
|
||||
|
||||
# A list of useless phrases to remove to reduce noise for price prediction
|
||||
REMOVALS = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "]
|
||||
|
||||
# Attributes for each item
|
||||
title: str
|
||||
price: float
|
||||
category: str
|
||||
token_count: int = 0 # How many tokens in the final prompt
|
||||
|
||||
# Optional fields
|
||||
details: Optional[str] # The value can be a string or can be None
|
||||
prompt: Optional[str] = None
|
||||
include = False # Whether to keep the item or not
|
||||
|
||||
def __init__(self, data, price):
|
||||
self.title = data['title']
|
||||
self.price = price
|
||||
self.parse(data)
|
||||
|
||||
def scrub_details(self):
|
||||
"""
|
||||
Removes useless phrases from details, which often has repeated specs or boilerplate text.
|
||||
"""
|
||||
details = self.details
|
||||
for remove in self.REMOVALS:
|
||||
details = details.replace(remove, "")
|
||||
return details
|
||||
|
||||
def scrub(self, stuff):
|
||||
"""
|
||||
Clean up the provided text by removing unnecessary characters and whitespace
|
||||
Also remove words that are 7+ chars and contain numbers, as these are likely irrelevant product numbers
|
||||
"""
|
||||
stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip()
|
||||
stuff = stuff.replace(" ,", ",").replace(",,,",",").replace(",,",",")
|
||||
words = stuff.split(' ')
|
||||
select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)]
|
||||
return " ".join(select)
|
||||
|
||||
def parse(self, data):
|
||||
"""
|
||||
Prepares the text, checks length, tokenizes it, and sets include = True if it’s valid.
|
||||
"""
|
||||
# Builds a full contents string by combining description, features, and cleaned details.
|
||||
contents = '\n'.join(data['description'])
|
||||
if contents:
|
||||
contents += '\n'
|
||||
features = '\n'.join(data['features'])
|
||||
if features:
|
||||
contents += features + '\n'
|
||||
self.details = data['details']
|
||||
if self.details:
|
||||
contents += self.scrub_details() + '\n'
|
||||
|
||||
# If content is long enough, trim it to max char limit before processing.
|
||||
if len(contents) > MIN_CHARS:
|
||||
contents = contents[:CEILING_CHARS]
|
||||
|
||||
# Clean and tokenize text, then check token count.
|
||||
text = f"{self.scrub(self.title)}\n{self.scrub(contents)}"
|
||||
tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
if len(tokens) > MIN_TOKENS:
|
||||
# Truncate tokens, decode them back and create the training prompt
|
||||
tokens = tokens[:MAX_TOKENS]
|
||||
text = self.tokenizer.decode(tokens)
|
||||
self.make_prompt(text)
|
||||
|
||||
# Mark the item as valid and ready to be used in training
|
||||
self.include = True # Only items with more than MIN_TOKENS tokens are kept (truncated to MAX_TOKENS)
|
||||
|
||||
|
||||
def make_prompt(self, text):
|
||||
"""
|
||||
Builds the training prompt using the question, text, and price. Then counts the tokens.
|
||||
"""
|
||||
self.prompt = f"{self.QUESTION}\n\n{text}\n\n"
|
||||
self.prompt += f"{self.PRICE_LABEL }{str(round(self.price))}.00"
|
||||
self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))
|
||||
|
||||
def test_prompt(self):
|
||||
"""
|
||||
Returns the prompt without the actual price, useful for testing/inference.
|
||||
"""
|
||||
return self.prompt.split(self.PRICE_LABEL)[0] + self.PRICE_LABEL
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Defines how the Item object looks when printed — it shows the title and price.
|
||||
"""
|
||||
return f"<{self.title} = ${self.price}>"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
106
week8/community_contributions/lisekarimi/helpers/loaders.py
Normal file
106
week8/community_contributions/lisekarimi/helpers/loaders.py
Normal file
@@ -0,0 +1,106 @@
from datetime import datetime  # Measure how long loading takes
from tqdm import tqdm  # Show a progress bar while processing data
from datasets import load_dataset  # Load a dataset from the Hugging Face Hub
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor  # For parallel processing (speed)
from items import Item

CHUNK_SIZE = 1000  # Process the dataset in chunks of 1000 datapoints at a time (for efficiency)
MIN_PRICE = 0.5
MAX_PRICE = 999.49
WORKER = 4  # Default number of parallel workers


class ItemLoader:

    def __init__(self, name):
        """
        Initialize the loader with a dataset name.
        """
        self.name = name     # Store the category name
        self.dataset = None  # Placeholder for the dataset (loaded later in load())

    def process_chunk(self, chunk):
        """
        Convert a chunk of datapoints into valid Item objects.
        """
        batch = []  # Holds the valid items

        # Loop through each datapoint in the chunk
        for datapoint in chunk:
            try:
                # Extract the price from the datapoint
                price_str = datapoint['price']
                if price_str:
                    price = float(price_str)

                    # Check if the price is within the valid range
                    if MIN_PRICE <= price <= MAX_PRICE:
                        item = Item(datapoint, price)

                        # Keep only valid items
                        if item.include:
                            batch.append(item)
            except ValueError:
                continue  # Skip datapoints with an invalid price format
        return batch  # Return the list of valid items

    def load_in_parallel(self, workers):
        """
        Split the dataset into chunks and process them in parallel.
        """
        results = []
        size = len(self.dataset)
        chunk_count = (size // CHUNK_SIZE) + 1

        # Build the chunks directly here (no separate function)
        chunks = [
            self.dataset.select(range(i, min(i + CHUNK_SIZE, size)))
            for i in range(0, size, CHUNK_SIZE)
        ]

        # Process the chunks in parallel across multiple CPU cores
        with ProcessPoolExecutor(max_workers=workers) as pool:
            for batch in tqdm(pool.map(self.process_chunk, chunks), total=chunk_count):
                results.extend(batch)

        # Add the category name to each result
        for result in results:
            result.category = self.name

        return results

    def load(self, workers=WORKER):
        """
        Load and process the dataset, returning the valid items.
        """
        # Record the start time
        start = datetime.now()

        print(f"Loading dataset {self.name}", flush=True)

        # Load the dataset from Hugging Face (based on the category name)
        self.dataset = load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_meta_{self.name}",
            split="full",
            trust_remote_code=True
        )

        # Process the dataset in parallel and collect the valid items
        results = self.load_in_parallel(workers)

        # Record the end time and print a summary
        finish = datetime.now()
        print(
            f"Completed {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins",
            flush=True
        )

        # Return the list of valid items
        return results
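
A hedged usage sketch for ItemLoader ("Appliances" is one example category from the Amazon dataset; the __main__ guard matters because ProcessPoolExecutor re-imports the module in its worker processes):

# Hedged usage sketch, assuming this file and items.py are on the path
if __name__ == "__main__":
    loader = ItemLoader("Appliances")
    items = loader.load(workers=8)  # Override the default WORKER count
    print(f"Kept {len(items):,} items, e.g. {items[0]!r}")
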
week8/community_contributions/lisekarimi/helpers/testing.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import math
import matplotlib.pyplot as plt

# ANSI color codes for terminal output
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
RESET = "\033[0m"
COLOR_MAP = {"red": RED, "orange": YELLOW, "green": GREEN}


class Tester:

    def __init__(self, predictor, data, title=None, size=250):
        self.predictor = predictor
        self.data = data
        self.title = title or predictor.__name__.replace("_", " ").title()
        self.size = size
        self.guesses = []
        self.truths = []
        self.errors = []
        self.sles = []
        self.colors = []

    def color_for(self, error, truth):
        if error < 40 or error / truth < 0.2:
            return "green"
        elif error < 80 or error / truth < 0.4:
            return "orange"
        else:
            return "red"

    def run_datapoint(self, i):
        datapoint = self.data[i]
        guess = self.predictor(datapoint)
        truth = datapoint["price"]
        error = abs(guess - truth)
        log_error = math.log(truth + 1) - math.log(guess + 1)
        sle = log_error ** 2
        color = self.color_for(error, truth)
        title = datapoint["text"][:40] + "..." if len(datapoint["text"]) > 40 else datapoint["text"]
        self.guesses.append(guess)
        self.truths.append(truth)
        self.errors.append(error)
        self.sles.append(sle)
        self.colors.append(color)
        # print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}")

    def chart(self, title):
        max_error = max(self.errors)
        plt.figure(figsize=(15, 6))
        max_val = max(max(self.truths), max(self.guesses))
        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)
        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)
        plt.xlabel('Ground Truth')
        plt.ylabel('Model Estimate')
        plt.xlim(0, max_val)
        plt.ylim(0, max_val)
        plt.title(title)

        # Add a color legend
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', label='Accurate (green)', markerfacecolor='green', markersize=8),
            Line2D([0], [0], marker='o', color='w', label='Medium error (orange)', markerfacecolor='orange', markersize=8),
            Line2D([0], [0], marker='o', color='w', label='High error (red)', markerfacecolor='red', markersize=8)
        ]
        plt.legend(handles=legend_elements, loc='upper left')
        plt.show()

    def report(self):
        average_error = sum(self.errors) / self.size
        rmsle = math.sqrt(sum(self.sles) / self.size)
        hits = sum(1 for color in self.colors if color == "green")
        title = f"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%"
        self.chart(title)

    def run(self):
        self.error = 0
        for i in range(self.size):
            self.run_datapoint(i)
        self.report()

    @classmethod
    def test(cls, function, data):
        cls(function, data).run()
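
Note that report() uses RMSLE, i.e. sqrt(mean((log(1+truth) - log(1+guess))^2)), which penalizes relative rather than absolute misses. A hedged usage sketch with synthetic data (the datapoints and baseline predictor below are made up):

# Hedged usage sketch; a real run would pass test datapoints with
# "text" and "price" keys plus a genuine predictor function
import random

test_data = [{"text": f"Sample product {i}", "price": random.uniform(1, 500)}
             for i in range(250)]

def midpoint_guess(datapoint):
    return 250.0  # Hypothetical baseline: always guess the midpoint of the price range

Tester.test(midpoint_guess, test_data)  # Runs 250 datapoints, then charts and reports
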
@@ -0,0 +1,140 @@
import modal
from modal import App, Volume, Image

import logging
logging.basicConfig(level=logging.INFO)

# ─────────────────────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────────────────────

GPU = "T4"             # Use a T4 GPU for inference
CACHE_PATH = "/cache"  # Mount point for the Modal volume

# Hugging Face model references
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"
FINETUNED_MODEL = "ed-donner/pricer-2024-09-13_13.04.39"
REVISION = "e8d637df551603dc86cd7a1598a8f44af4d7ae36"  # Commit of the fine-tuned model

# Local cache paths (inside the volume)
BASE_MODEL_DIR = f"{CACHE_PATH}/llama_base_model"
FINETUNED_MODEL_DIR = f"{CACHE_PATH}/llama_finetuned_model"

# ─────────────────────────────────────────────────────────────────────────────
# Structure
# ─────────────────────────────────────────────────────────────────────────────

# Container (App: llm-ft-pricer)
# ├── /app    ← Code + installed Python packages (from image)
# ├── /cache  ← Mounted Modal volume (`hf-hub-cache`)
# │   └── meta-llama/Meta-Llama-3.1-8B/...  ← Hugging Face model files downloaded via snapshot_download

QUESTION = "How much does this cost to the nearest dollar?"
PREFIX = "Price is $"  # Used to parse the generated output

# ─────────────────────────────────────────────────────────────────────────────
# Modal App, Image, Volume, Secrets
# ─────────────────────────────────────────────────────────────────────────────

app = modal.App("llm-ft-pricer")  # Define the Modal app

image = (
    Image.debian_slim()
    .pip_install("huggingface", "torch", "transformers", "bitsandbytes", "accelerate", "peft")  # All the libraries we need
    .env({"HF_HUB_CACHE": CACHE_PATH})  # Hugging Face will store model files in /cache
)

cache_vol = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)  # Persisted volume for caching models
secrets = [modal.Secret.from_name("HF_TOKEN")]  # Hugging Face auth token

# ─────────────────────────────────────────────────────────────────────────────
# Modal Class: Pricer
# ─────────────────────────────────────────────────────────────────────────────

# All methods in this class run inside a container with the image, volume, secrets, and GPU configured below.
@app.cls(
    image=image,
    secrets=secrets,
    volumes={CACHE_PATH: cache_vol},  # Mount the volume into /cache
    gpu=GPU,
    timeout=1800,          # 30-minute max runtime
    min_containers=0,      # Setting this to 1 keeps a container warm, but burns credits continuously if you forget to stop it
    scaledown_window=300,  # Shut the container down after 5 idle minutes
)
class Pricer:
    @modal.enter()
    def setup(self):
        import os, torch
        import logging
        from huggingface_hub import snapshot_download
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from peft import PeftModel

        # Create the cache path if it doesn't exist
        os.makedirs(CACHE_PATH, exist_ok=True)

        # Download the base and fine-tuned models into the volume
        logging.info("Downloading base model...")
        snapshot_download(BASE_MODEL, local_dir=BASE_MODEL_DIR)

        logging.info("Downloading fine-tuned model...")
        snapshot_download(FINETUNED_MODEL, revision=REVISION, local_dir=FINETUNED_MODEL_DIR)

        # Quantization config (4-bit)
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4"
        )

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # Load the base model (quantized)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_DIR,
            quantization_config=quant_config,
            device_map="auto"
        )

        # Apply the fine-tuned weights
        self.fine_tuned_model = PeftModel.from_pretrained(
            base_model,
            FINETUNED_MODEL_DIR,
            revision=REVISION
        )
        self.fine_tuned_model.generation_config.pad_token_id = self.tokenizer.pad_token_id

    @modal.method()
    def price(self, description: str) -> float:
        import re, torch
        from transformers import set_seed

        set_seed(42)  # Deterministic output

        # Construct the prompt
        prompt = f"{QUESTION}\n\n{description}\n\n{PREFIX}"
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        attention_mask = torch.ones(inputs.shape, device="cuda")

        # Generate the model output (max 5 new tokens)
        outputs = self.fine_tuned_model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=5,
            num_return_sequences=1
        )
        result = self.tokenizer.decode(outputs[0])

        # Extract the number after "Price is $"
        contents = result.split(PREFIX)[1]
        contents = contents.replace(',', '')
        match = re.search(r"[-+]?\d*\.\d+|\d+", contents)
        return float(match.group()) if match else 0.0  # Return the parsed price, or 0 if none found
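
As a rough check on the GPU choice: an 8B-parameter model at 4-bit NF4 is about 8B × 0.5 bytes ≈ 4 GB of weights plus overhead, which fits comfortably inside a T4's 16 GB. Once this file is deployed with `modal deploy`, a client can look the class up by name; a hedged client-side sketch (app and class names taken from this file, product description made up):

# Hedged client sketch, run from a separate script after `modal deploy`
import modal

Pricer = modal.Cls.from_name("llm-ft-pricer", "Pricer")  # Look up the deployed class
pricer = Pricer()
print(pricer.price.remote("Quadcast HyperX condenser mic for podcasting"))  # Prints the model's price estimate
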
@@ -0,0 +1,12 @@
import sys, modal

app = modal.App("example-hello-world")


@app.function()
def f(i: int) -> int:
    if i % 2 == 0:
        print("hello", i)
    else:
        print("world", i, file=sys.stderr)

    return i * i
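
A hedged sketch (an assumed addition, not part of this file) of how such a function is typically driven: a local entrypoint lets `modal run` call it, with .remote() for a single cloud call and .map() for a parallel fan-out:

# Hedged sketch: a local entrypoint so `modal run <thisfile>.py` can drive f()
@app.local_entrypoint()
def main():
    print(f.remote(7))            # One remote call
    print(sum(f.map(range(10))))  # Parallel fan-out; results summed locally
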
@@ -44,7 +44,6 @@
 "from sentence_transformers import SentenceTransformer\n",
 "from datasets import load_dataset\n",
 "import chromadb\n",
-"from items import Item\n",
 "from sklearn.manifold import TSNE\n",
 "import plotly.graph_objects as go"
 ]
@@ -77,6 +76,18 @@
 "login(hf_token, add_to_git_credential=True)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "8491f550-df4a-4c8f-a260-a7a419e8efb6",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Another import after Logging in to Hugging Face - thank you Trung N.!\n",
+"\n",
+"from items import Item"
+]
+},
 {
 "cell_type": "markdown",
 "id": "3d4995a4-f67f-4871-87df-8c6439b06366",
@@ -44,7 +44,6 @@
 "from sentence_transformers import SentenceTransformer\n",
 "from datasets import load_dataset\n",
 "import chromadb\n",
-"from items import Item\n",
 "from sklearn.manifold import TSNE\n",
 "import plotly.graph_objects as go"
 ]
@@ -174,7 +173,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.11"
+"version": "3.11.12"
 }
 },
 "nbformat": 4,
@@ -44,7 +44,6 @@
 "from sentence_transformers import SentenceTransformer\n",
 "from datasets import load_dataset\n",
 "import chromadb\n",
-"from items import Item\n",
 "from sklearn.manifold import TSNE\n",
 "import plotly.graph_objects as go"
 ]
@@ -166,7 +165,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.11"
+"version": "3.11.12"
 }
 },
 "nbformat": 4,
@@ -48,7 +48,6 @@
 "from sentence_transformers import SentenceTransformer\n",
 "from datasets import load_dataset\n",
 "import chromadb\n",
-"from items import Item\n",
 "from testing import Tester"
 ]
 },
@@ -66,6 +65,31 @@
 "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "ce73b034-9ec1-4533-ba41-3e57c7878b61",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Log in to HuggingFace\n",
+"\n",
+"hf_token = os.environ['HF_TOKEN']\n",
+"login(hf_token, add_to_git_credential=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "4c01daad-86b0-4bc0-91ba-20a64df043ed",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Another import after Logging in to Hugging Face - thank you Trung N.!\n",
+"\n",
+"from items import Item"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -495,7 +519,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.11"
+"version": "3.11.12"
 }
 },
 "nbformat": 4,
@@ -84,6 +84,31 @@
 "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "1006966f-96b7-4e1a-93f0-2bb9a09057c8",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Log in to HuggingFace\n",
+"\n",
+"hf_token = os.environ['HF_TOKEN']\n",
+"login(hf_token, add_to_git_credential=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "de0e4b22-ee61-4b79-95bc-3cd707d5f83d",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Another import after Logging in to Hugging Face - thank you Trung N.!\n",
+"\n",
+"from items import Item"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -78,7 +78,7 @@
 " </td>\n",
 " <td>\n",
 " <h2 style=\"color:#f71;\">Additional resource: more sophisticated planning agent</h2>\n",
-" <span style=\"color:#f71;\">The Planning Agent that we use in the next cell is simply a python script that calls the other Agents; frankly that's all we require for this project. But if you're intrigued to see a more Autonomous version in which we give the Planning Agent tools and allow it to decide which Agents to call, see my implementation of <a href=\"https://github.com/ed-donner/agentic/blob/main/workshop/agents/autonomous_planning_agent.py\">AutonomousPlanningAgent</a> in my related repo, <a href=\"https://github.com/ed-donner/agentic\">Agentic</a>. This is an example with multiple tools that dynamically decides which function to call.\n",
+" <span style=\"color:#f71;\">The Planning Agent that we use in the next cell is simply a python script that calls the other Agents; frankly that's all we require for this project. But if you're intrigued to see a more Autonomous version in which we give the Planning Agent tools and allow it to decide which Agents to call, see my implementation of <a href=\"https://github.com/ed-donner/agentic/blob/main/workshop/price_agents/autonomous_planning_agent.py\">AutonomousPlanningAgent</a> in my related repo, <a href=\"https://github.com/ed-donner/agentic\">Agentic</a>. This is an example with multiple tools that dynamically decides which function to call.\n",
 " </span>\n",
 " </td>\n",
 " </tr>\n",
@@ -144,7 +144,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.11"
+"version": "3.11.12"
 }
 },
 "nbformat": 4,