Merge pull request #867 from chimwemwekachaje/main
Added Week 6 Andela GenAI Bootcamp exercise
This commit is contained in:
4
week6/community-contributions/kachaje-andelaGenAi-bootcamp/.gitignore
vendored
Normal file
4
week6/community-contributions/kachaje-andelaGenAi-bootcamp/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
items.py
|
||||||
|
loaders.py
|
||||||
|
llama32_pricer_lora/
|
||||||
|
testing.py
|
||||||
@@ -0,0 +1,347 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Fine-tune Llama 3.2 1B Locally with LoRA\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook fine-tunes Llama 3.2 1B model for product pricing using Low-Rank Adaptation (LoRA), which is memory-efficient and suitable for local training.\n",
|
||||||
|
"\n",
|
||||||
|
"**macOS Compatibility:** This notebook uses Hugging Face transformers and PEFT (instead of Unsloth) for better macOS compatibility. Works on CPU, Apple Silicon (Metal), or NVIDIA GPU.\n",
|
||||||
|
"\n",
|
||||||
|
"**Optimizations:**\n",
|
||||||
|
"- LoRA for memory-efficient fine-tuning (only ~1% of parameters trained)\n",
|
||||||
|
"- bfloat16 mixed precision training when available\n",
|
||||||
|
"- Gradient checkpointing for additional memory savings\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Install PyTorch first (required for other packages on macOS ARM64)\n",
|
||||||
|
"! uv pip -q install torch torchvision torchaudio\n",
|
||||||
|
"\n",
|
||||||
|
"# Install required packages for fine-tuning with LoRA (works on macOS without GPU)\n",
|
||||||
|
"! uv pip -q install trl peft accelerate datasets transformers"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Imports\n",
|
||||||
|
"import os\n",
|
||||||
|
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
||||||
|
"\n",
|
||||||
|
"import re\n",
|
||||||
|
"import json\n",
|
||||||
|
"import pickle\n",
|
||||||
|
"from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments\n",
|
||||||
|
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
||||||
|
"from datasets import Dataset\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from items import Item\n",
|
||||||
|
"from testing import Tester\n",
|
||||||
|
"\n",
|
||||||
|
"# Import SFTTrainer - try SFTConfig if available, otherwise use old API\n",
|
||||||
|
"try:\n",
|
||||||
|
" from trl import SFTTrainer, SFTConfig\n",
|
||||||
|
" USE_SFT_CONFIG = True\n",
|
||||||
|
"except ImportError:\n",
|
||||||
|
" from trl import SFTTrainer\n",
|
||||||
|
" USE_SFT_CONFIG = False\n",
|
||||||
|
" print(\"Note: Using older TRL API without SFTConfig\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load Training Data\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load the training and test datasets\n",
|
||||||
|
"with open('train_lite.pkl', 'rb') as f:\n",
|
||||||
|
" train_data = pickle.load(f)\n",
|
||||||
|
"\n",
|
||||||
|
"with open('test_lite.pkl', 'rb') as f:\n",
|
||||||
|
" test_data = pickle.load(f)\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Training samples: {len(train_data)}\")\n",
|
||||||
|
"print(f\"Test samples: {len(test_data)}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Convert Data to Chat Format\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def messages_for(item):\n",
|
||||||
|
" \"\"\"Convert item to chat format for fine-tuning\"\"\"\n",
|
||||||
|
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
|
||||||
|
" user_prompt = item.test_prompt().replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n",
|
||||||
|
" return [\n",
|
||||||
|
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
||||||
|
" {\"role\": \"assistant\", \"content\": f\"Price is ${item.price:.2f}\"}\n",
|
||||||
|
" ]\n",
|
||||||
|
"\n",
|
||||||
|
"# Convert to chat format\n",
|
||||||
|
"def format_for_training(items):\n",
|
||||||
|
" texts = []\n",
|
||||||
|
" for item in items:\n",
|
||||||
|
" messages = messages_for(item)\n",
|
||||||
|
" # Format as instruction following format for unsloth\n",
|
||||||
|
" text = f\"### System:\\n{messages[0]['content']}\\n\\n### User:\\n{messages[1]['content']}\\n\\n### Assistant:\\n{messages[2]['content']}\"\n",
|
||||||
|
" texts.append(text)\n",
|
||||||
|
" return texts\n",
|
||||||
|
"\n",
|
||||||
|
"train_texts = format_for_training(train_data)\n",
|
||||||
|
"print(f\"Example training text:\\n{train_texts[0]}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create dataset\n",
|
||||||
|
"train_dataset = Dataset.from_dict({\"text\": train_texts})\n",
|
||||||
|
"print(f\"Dataset created with {len(train_dataset)} samples\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load Model with LoRA Configuration\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load model and tokenizer\n",
|
||||||
|
"model_name = \"unsloth/Llama-3.2-1B-Instruct\"\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||||
|
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||||
|
"tokenizer.padding_side = \"right\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Check if CUDA is available (won't be on macOS without GPU)\n",
|
||||||
|
"device_map = \"auto\" if torch.cuda.is_available() else None\n",
|
||||||
|
"\n",
|
||||||
|
"# Load model (use dtype=bfloat16 for Apple Silicon)\n",
|
||||||
|
"model = AutoModelForCausalLM.from_pretrained(\n",
|
||||||
|
" model_name,\n",
|
||||||
|
" dtype=torch.bfloat16 if torch.backends.mps.is_available() else torch.float32,\n",
|
||||||
|
" device_map=device_map,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Configure LoRA\n",
|
||||||
|
"lora_config = LoraConfig(\n",
|
||||||
|
" task_type=TaskType.CAUSAL_LM,\n",
|
||||||
|
" r=16,\n",
|
||||||
|
" lora_alpha=16,\n",
|
||||||
|
" lora_dropout=0.1,\n",
|
||||||
|
" bias=\"none\",\n",
|
||||||
|
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
||||||
|
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Add LoRA adapters\n",
|
||||||
|
"model = get_peft_model(model, lora_config)\n",
|
||||||
|
"model.print_trainable_parameters()\n",
|
||||||
|
"\n",
|
||||||
|
"# Attach tokenizer to model for SFTTrainer\n",
|
||||||
|
"model.tokenizer = tokenizer\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Model loaded with LoRA adapters\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Configure Training Arguments\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Configure training arguments\n",
|
||||||
|
"training_args = TrainingArguments(\n",
|
||||||
|
" output_dir=\"./llama32_pricer_lora\",\n",
|
||||||
|
" per_device_train_batch_size=2,\n",
|
||||||
|
" gradient_accumulation_steps=4,\n",
|
||||||
|
" warmup_steps=10,\n",
|
||||||
|
" max_steps=100, # Adjust based on dataset size\n",
|
||||||
|
" learning_rate=2e-4,\n",
|
||||||
|
" bf16=torch.backends.mps.is_available() or torch.cuda.is_available(), # Use bf16 if available\n",
|
||||||
|
" logging_steps=10,\n",
|
||||||
|
" save_strategy=\"steps\",\n",
|
||||||
|
" save_steps=25,\n",
|
||||||
|
" eval_steps=25,\n",
|
||||||
|
" save_total_limit=2,\n",
|
||||||
|
" load_best_model_at_end=False,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Training arguments configured\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Initialize Trainer and Start Fine-tuning\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Initialize trainer\n",
|
||||||
|
"# Model is already wrapped with PEFT (LoRA), so we use basic parameters\n",
|
||||||
|
"trainer = SFTTrainer(\n",
|
||||||
|
" model=model,\n",
|
||||||
|
" train_dataset=train_dataset,\n",
|
||||||
|
" args=training_args,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Trainer initialized\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Train the model\n",
|
||||||
|
"trainer.train()\n",
|
||||||
|
"print(\"Training completed!\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Save the Fine-tuned Model\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Save the model\n",
|
||||||
|
"model.save_pretrained(\"llama32_pricer_lora\")\n",
|
||||||
|
"tokenizer.save_pretrained(\"llama32_pricer_lora\")\n",
|
||||||
|
"print(\"Model saved to llama32_pricer_lora/\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Test the Fine-tuned Model\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Helper function to extract price from response\n",
|
||||||
|
"def get_price(s):\n",
|
||||||
|
" s = s.replace('$','').replace(',','')\n",
|
||||||
|
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
|
||||||
|
" return float(match.group()) if match else 0\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to test the fine-tuned model\n",
|
||||||
|
"def llama32_finetuned_model(item):\n",
|
||||||
|
" messages = messages_for(item)\n",
|
||||||
|
" \n",
|
||||||
|
" # Format the prompt\n",
|
||||||
|
" prompt = f\"### System:\\n{messages[0]['content']}\\n\\n### User:\\n{messages[1]['content']}\\n\\n### Assistant:\\n\"\n",
|
||||||
|
" \n",
|
||||||
|
" # Move to appropriate device\n",
|
||||||
|
" device = next(model.parameters()).device\n",
|
||||||
|
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" outputs = model.generate(\n",
|
||||||
|
" **inputs,\n",
|
||||||
|
" max_new_tokens=50,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" do_sample=True,\n",
|
||||||
|
" pad_token_id=tokenizer.eos_token_id\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
||||||
|
" return get_price(response)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Test on the test dataset\n",
|
||||||
|
"print(\"Testing fine-tuned model...\")\n",
|
||||||
|
"Tester.test(llama32_finetuned_model, test_data)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
@@ -0,0 +1,512 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a246687d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# The Product Pricer\n",
|
||||||
|
"\n",
|
||||||
|
"A model that can estimate how much something costs, from its description\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3792ce5b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"! uv -q pip install langchain-ollama"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "390c3ce3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# imports\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
||||||
|
"\n",
|
||||||
|
"from dotenv import load_dotenv\n",
|
||||||
|
"from huggingface_hub import login\n",
|
||||||
|
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import pickle\n",
|
||||||
|
"import re\n",
|
||||||
|
"from langchain_ollama import OllamaLLM\n",
|
||||||
|
"from openai import OpenAI\n",
|
||||||
|
"from testing import Tester\n",
|
||||||
|
"import json\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8a8ff331",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"load_dotenv(override=True)\n",
|
||||||
|
"hf_token = os.getenv(\"HF_TOKEN\")\n",
|
||||||
|
"login(hf_token, add_to_git_credential=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1051e21e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from items import Item\n",
|
||||||
|
"from loaders import ItemLoader\n",
|
||||||
|
"\n",
|
||||||
|
"%matplotlib inline"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "290fa868",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset_names = [\n",
|
||||||
|
" \"Appliances\",\n",
|
||||||
|
"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "12ffad66",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"items = []\n",
|
||||||
|
"for dataset_name in dataset_names:\n",
|
||||||
|
" loader = ItemLoader(dataset_name)\n",
|
||||||
|
" items.extend(loader.load())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0b3890d7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(f\"A grand total of {len(items):,} items\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "246ab22a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Plot the distribution of token counts again\n",
|
||||||
|
"\n",
|
||||||
|
"tokens = [item.token_count for item in items]\n",
|
||||||
|
"plt.figure(figsize=(15, 6))\n",
|
||||||
|
"plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
|
||||||
|
"plt.xlabel('Length (tokens)')\n",
|
||||||
|
"plt.ylabel('Count')\n",
|
||||||
|
"plt.hist(tokens, rwidth=0.7, color=\"skyblue\", bins=range(0, 300, 10))\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3a49a4d4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Plot the distribution of prices\n",
|
||||||
|
"\n",
|
||||||
|
"prices = [item.price for item in items]\n",
|
||||||
|
"plt.figure(figsize=(15, 6))\n",
|
||||||
|
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
|
||||||
|
"plt.xlabel('Price ($)')\n",
|
||||||
|
"plt.ylabel('Count')\n",
|
||||||
|
"plt.hist(prices, rwidth=0.7, color=\"blueviolet\", bins=range(0, 1000, 10))\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "57e4ea1b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# How does the price vary with the character count of the prompt?\n",
|
||||||
|
"\n",
|
||||||
|
"sample = items\n",
|
||||||
|
"\n",
|
||||||
|
"sizes = [len(item.prompt) for item in sample]\n",
|
||||||
|
"prices = [item.price for item in sample]\n",
|
||||||
|
"\n",
|
||||||
|
"# Create the scatter plot\n",
|
||||||
|
"plt.figure(figsize=(15, 8))\n",
|
||||||
|
"plt.scatter(sizes, prices, s=0.2, color=\"red\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Add labels and title\n",
|
||||||
|
"plt.xlabel('Size')\n",
|
||||||
|
"plt.ylabel('Price')\n",
|
||||||
|
"plt.title('Is there a simple correlation?')\n",
|
||||||
|
"\n",
|
||||||
|
"# Display the plot\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e6620daa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def report(item):\n",
|
||||||
|
" prompt = item.prompt\n",
|
||||||
|
" tokens = Item.tokenizer.encode(item.prompt)\n",
|
||||||
|
" print(prompt)\n",
|
||||||
|
" print(tokens[-10:])\n",
|
||||||
|
" print(Item.tokenizer.batch_decode(tokens[-10:]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "af71d177",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"report(sample[50])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "75ab3c21",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import random\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"random.seed(42)\n",
|
||||||
|
"random.shuffle(sample)\n",
|
||||||
|
"train = sample[:25_000]\n",
|
||||||
|
"test = sample[25_000:27_000]\n",
|
||||||
|
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6d5cbd3a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(train[0].prompt)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "39de86d6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(test[0].test_prompt())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "65480df9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Plot the distribution of prices in the first 250 test points\n",
|
||||||
|
"\n",
|
||||||
|
"prices = [float(item.price) for item in test[:250]]\n",
|
||||||
|
"plt.figure(figsize=(15, 6))\n",
|
||||||
|
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
|
||||||
|
"plt.xlabel('Price ($)')\n",
|
||||||
|
"plt.ylabel('Count')\n",
|
||||||
|
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7a315b10",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"filtered_prices = [float(item.price) for item in test if item.price > 99.999]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5693c9c6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Confirm that the tokenizer tokenizes all 3 digit prices into 1 token"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "99e8cfc3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for price in filtered_prices:\n",
|
||||||
|
" tokens = Item.tokenizer.encode(f\"{price}\", add_special_tokens=False)\n",
|
||||||
|
" assert len(tokens) == 3\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "f3159195",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Helpers"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7bdc5dd5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def messages_for(item):\n",
|
||||||
|
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
|
||||||
|
" user_prompt = item.test_prompt().replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n",
|
||||||
|
" return [\n",
|
||||||
|
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
||||||
|
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
|
||||||
|
" ]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "211b0658",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# A utility function to extract the price from a string\n",
|
||||||
|
"\n",
|
||||||
|
"def get_price(s):\n",
|
||||||
|
" s = s.replace('$','').replace(',','')\n",
|
||||||
|
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
|
||||||
|
" return float(match.group()) if match else 0"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ee01da84",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Convert the items into a list of json objects - a \"jsonl\" string\n",
|
||||||
|
"# Each row represents a message in the form:\n",
|
||||||
|
"# {\"messages\" : [{\"role\": \"system\", \"content\": \"You estimate prices...\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def make_jsonl(items):\n",
|
||||||
|
" result = \"\"\n",
|
||||||
|
" for item in items:\n",
|
||||||
|
" messages = messages_for(item)\n",
|
||||||
|
" messages_str = json.dumps(messages)\n",
|
||||||
|
" result += '{\"messages\": ' + messages_str +'}\\n'\n",
|
||||||
|
" return result.strip()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f23e8959",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Convert the items into jsonl and write them to a file\n",
|
||||||
|
"\n",
|
||||||
|
"def write_jsonl(items, filename):\n",
|
||||||
|
" with open(filename, \"w\") as f:\n",
|
||||||
|
" jsonl = make_jsonl(items)\n",
|
||||||
|
" f.write(jsonl)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b6a83580",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "451b974f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open('train_lite.pkl', 'rb') as f:\n",
|
||||||
|
" train_lite = pickle.load(f)\n",
|
||||||
|
"\n",
|
||||||
|
"with open('test_lite.pkl', 'rb') as f:\n",
|
||||||
|
" test_lite = pickle.load(f)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f365d65c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"messages_for(test_lite[0])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "57b0b160",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"get_price(\"The price is roughly $99.99 because blah blah\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ff3e4670",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Models"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9f62c94b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"MODEL_LLAMA3_2 = \"llama3.2\"\n",
|
||||||
|
"MODEL_MISTRAL = \"mistral\"\n",
|
||||||
|
"MODEL_TINY_LLAMA = \"tinyllama\"\n",
|
||||||
|
"\n",
|
||||||
|
"llm3_2 = OllamaLLM(model=MODEL_LLAMA3_2)\n",
|
||||||
|
"llmMistral = OllamaLLM(model=MODEL_MISTRAL)\n",
|
||||||
|
"llmTinyLlama = OllamaLLM(model=MODEL_TINY_LLAMA)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d18394fb",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Model Tests"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7dac335f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def llama3_2_model(item):\n",
|
||||||
|
" response = llm3_2.invoke(messages_for(item))\n",
|
||||||
|
" return get_price(response)\n",
|
||||||
|
"\n",
|
||||||
|
"def mistral_model(item):\n",
|
||||||
|
" response = llmMistral.invoke(messages_for(item))\n",
|
||||||
|
" return get_price(response)\n",
|
||||||
|
"\n",
|
||||||
|
"def tinyllama_model(item):\n",
|
||||||
|
" response = llmTinyLlama.invoke(messages_for(item))\n",
|
||||||
|
" return get_price(response)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "062e78c2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"test_lite[0].price"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c58756f2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"Tester.test(llama3_2_model, test_lite)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "899e2401",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"Tester.test(mistral_model, test_lite)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2f5bc9ad",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"Tester.test(tinyllama_model, test_lite)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user