Adding links to W7 Google Colab
@@ -6,553 +6,9 @@
"id": "GHsssBgWM_l0"
},
"source": [
"# Predict Product Prices\n",
"### Link\n",
"\n",
"## Day 3: Training!\n",
"\n",
"# IMPORTANT please read me!!\n",
"\n",
"When you run the pip installs below, you may get an error from pip complaining about an incompatible version of fsspec.\n",
"\n",
"You should ignore that error! The version of fsspec is the right version, needed by HuggingFace.\n",
"\n",
"If you ask ChatGPT, it will encourage you to pip install a more recent version of fsspec. But that would be problematic; HuggingFace will fail to load the dataset later with an obscure error about file systems.\n",
"\n",
"So please run the pip installs as they appear below, and look the other way if you get an error!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDyR63OTNUJ6"
},
"outputs": [],
"source": [
"# pip installs\n",
"\n",
"#!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"#!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-yikV8pRBer9"
},
"outputs": [],
"source": [
"# imports\n",
"# With much thanks to Islam S. for identifying that there was a missing import!\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"from tqdm import tqdm\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import torch\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, set_seed, BitsAndBytesConfig\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"import wandb\n",
"from peft import LoraConfig\n",
"from trl import SFTTrainer, SFTConfig\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uuTX-xonNeOK"
},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"PROJECT_NAME = \"pricer\"\n",
"HF_USER = \"nikhilr12\" # your HF name here!\n",
"\n",
"# Data\n",
"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"# Or just use the one I've uploaded\n",
"# DATASET_NAME = \"ed-donner/pricer-data\"\n",
"MAX_SEQUENCE_LENGTH = 182\n",
"\n",
"# Run name for saving the model in the hub\n",
"\n",
"RUN_NAME = f\"{datetime.now():%Y-%m-%d_%H.%M.%S}\"\n",
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
"HUB_MODEL_NAME = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"\n",
"# Hyperparameters for QLoRA\n",
"\n",
"LORA_R = 32\n",
"LORA_ALPHA = 64\n",
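"# (note, my addition: setting alpha = 2 * r is a common LoRA heuristic)\n",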
"TARGET_MODULES = [\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"]\n",
"TARGET_MODULES = [\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"]\n",
"LORA_DROPOUT = 0.1\n",
"QUANT_4_BIT = True\n",
"\n",
"# Hyperparameters for Training\n",
"\n",
"EPOCHS = 1 # you can do more epochs if you wish, but only 1 is needed - more is probably overkill\n",
"BATCH_SIZE = 4 # on an A100 box this can go up to 16\n",
"GRADIENT_ACCUMULATION_STEPS = 1\n",
"LEARNING_RATE = 1e-4\n",
"LR_SCHEDULER_TYPE = 'cosine'\n",
"WARMUP_RATIO = 0.03\n",
"OPTIMIZER = \"paged_adamw_32bit\"\n",
"\n",
"# Admin config - note that SAVE_STEPS is how often it will upload to the hub\n",
"# I've changed this from 5000 to 2000 so that you get more frequent saves\n",
"\n",
"STEPS = 50\n",
"SAVE_STEPS = 2000\n",
"LOG_TO_WANDB = True\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "QyHOj-c4FmkM",
"outputId": "a9d22f5c-ad62-48cd-9a29-74dc8390c00c"
},
"outputs": [],
"source": [
"HUB_MODEL_NAME"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfvyitbgCMMQ"
},
"source": [
"# More on Optimizers\n",
"\n",
"https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#optimizer-choice\n",
"\n",
"The most common is Adam or AdamW (Adam with Weight Decay). \n",
"Adam achieves good convergence by storing the rolling average of the previous gradients; however, it adds an additional memory footprint of the order of the number of model parameters.\n",
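"\n",
"As a rough illustration (my own note, not from the linked docs): AdamW keeps two extra fp32 states (momentum and variance) per parameter, so fully fine-tuning an 8B-parameter model would need roughly 8e9 x 2 x 4 bytes, about 64 GB, of optimizer state alone. That is one reason QLoRA trains only the small set of LoRA adapter weights, and why paged variants such as paged_adamw_32bit exist."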
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8JArT3QAQAjx"
},
"source": [
"### Log in to HuggingFace and Weights & Biases\n",
"\n",
"If you don't already have a HuggingFace account, visit https://huggingface.co to sign up and create a token.\n",
"\n",
"Then select the Secrets for this Notebook by clicking on the key icon on the left, and add a new secret called `HF_TOKEN` with your token as the value.\n",
"\n",
"Repeat this for Weights & Biases at https://wandb.ai and add a secret called `WANDB_API_KEY`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WyFPZeMcM88v"
},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yJNOv3cVvJ68",
"outputId": "5b1445de-ec41-476d-878e-f747c4f13b87"
},
"outputs": [],
"source": [
"# Log in to Weights & Biases\n",
"wandb_api_key = userdata.get('WANDB_API_KEY')\n",
"os.environ[\"WANDB_API_KEY\"] = wandb_api_key\n",
"wandb.login()\n",
"\n",
"# Configure Weights & Biases to record against our project\n",
"os.environ[\"WANDB_PROJECT\"] = PROJECT_NAME\n",
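"# \"checkpoint\" uploads checkpoints to W&B as artifacts during the run; \"end\" would upload only the final model\n",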
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" if LOG_TO_WANDB else \"end\"\n",
"os.environ[\"WANDB_WATCH\"] = \"gradients\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvXVoJH8LS6u"
},
"outputs": [],
"source": [
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rJb9IDVjOAn9"
},
"outputs": [],
"source": [
"# reduce the training dataset for a quicker run - currently 1,000 points; comment this line out to train on the full dataset:\n",
"train = train.select(range(1000))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
},
"id": "8_SUsKqA23Gc",
"outputId": "2d52d813-fb5c-4477-a66c-769e803e8709"
},
"outputs": [],
"source": [
"if LOG_TO_WANDB:\n",
" wandb.init(project=PROJECT_NAME, name=RUN_NAME)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qJWQ0a3wZ0Bw"
},
"source": [
"## Now load the Tokenizer and Model\n",
"\n",
"The model is \"quantized\" - we are reducing the precision to 4 bits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9lb7M9xn46wx"
},
"outputs": [],
"source": [
"# pick the right quantization\n",
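"# NF4 (\"normal float 4\") packs weights into 4 bits; double quantization also compresses the quantization constants themselves\n",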
"\n",
"if QUANT_4_BIT:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" bnb_4bit_quant_type=\"nf4\"\n",
" )\n",
"else:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_8bit=True,\n",
" bnb_8bit_compute_dtype=torch.bfloat16\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 67,
"referenced_widgets": [
"0027ee1677114da58d2d1dfe0e7ebc63",
"400cfd3f1d1c44d59f3ec606b3e96458",
"1704ee1b4dd74bceb600add874adac17",
"db9dffdcf2274e18be2060b61c946a7e",
"93e93a10c577468faa2ccfa16c6cfdca",
"577d8bbc02614e128c39fcb10f7237a9",
"6ef16480ae234804906aa94346cb16fc",
"eb0ce8001d6348e3807b79e1ca1c5c92",
"5b4ef788e4b747b084f49265719e0d49",
"c56db5a9708e408298cb21ded221979b",
"904bb8b2f8374e6f856d772da322609b"
]
},
"id": "R_O04fKxMMT-",
"outputId": "e5bd004d-1143-4109-9ea9-1122634e5496"
},
"outputs": [],
"source": [
"# Load the Tokenizer and the Model\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=quant_config,\n",
" device_map=\"auto\",\n",
")\n",
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"print(f\"Memory footprint: {base_model.get_memory_footprint() / 1e6:.1f} MB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9BYO0If4uWys"
},
"source": [
"# Data Collator\n",
"\n",
"It's important to ensure during Training that we are not training the model to predict the description of products; only their price.\n",
"\n",
"We need to tell the trainer that everything up to \"Price is $\" is there to give context to the model to predict the next token, but does not need to be learned.\n",
"\n",
"The trainer needs to teach the model to predict the token(s) after \"Price is $\".\n",
"\n",
"There is a complicated way to do this by setting Masks, but luckily HuggingFace provides a super simple helper class to take care of this for us."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2omVEaPIVJZa"
},
"outputs": [],
"source": [
"from trl import DataCollatorForCompletionOnlyLM\n",
"response_template = \"Price is $\"\n",
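"\n",
"# The collator sets the label for every token up to and including the response template to -100,\n",
"# which the loss function ignores - so only the price tokens after \"Price is $\" are learned\n",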
"collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4DaOeBhyy9eS"
},
"source": [
"# AND NOW\n",
"\n",
"## We set up the configuration for Training\n",
"\n",
"We need to create 2 objects:\n",
"\n",
"A LoraConfig object with our hyperparameters for LoRA\n",
"\n",
"An SFTConfig with our overall Training parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 49,
"referenced_widgets": [
"dc29a54964904f39b56a28463ef7d983",
"72ac906a836c410facee1378d0449671",
"e553f6d313d641c3a3599ce6cc4f4511",
"71eaeacbcbc9437aa1d5cdb90c092bac",
"ea12f77a72614bf494559cb7b339e689",
"22de7398b1704e7e97aad8e2af1c8ec9",
"03f246d93610455bb6c03b626d7349e1",
"c0d8e56eea03422aa774560e6e1c0dd5",
"747fdc33f79549e3bf52e62c1be9f053",
"1dfb95f7cf8b4ab48ceed2ff89769388",
"2f5901b8ad5a47cf8e9126b2b1f46e6b"
]
},
"id": "fCwmDmkSATvj",
"outputId": "75e6caec-be5b-4ae1-a19b-272d7b35d92a"
},
"outputs": [],
"source": [
"# First, specify the configuration parameters for LoRA\n",
"\n",
"lora_parameters = LoraConfig(\n",
" lora_alpha=LORA_ALPHA,\n",
" lora_dropout=LORA_DROPOUT,\n",
" r=LORA_R,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" target_modules=TARGET_MODULES,\n",
")\n",
"\n",
"# Next, specify the general configuration parameters for training\n",
"\n",
"train_parameters = SFTConfig(\n",
" output_dir=PROJECT_RUN_NAME,\n",
" num_train_epochs=EPOCHS,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=1,\n",
" eval_strategy=\"no\",\n",
" gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
" optim=OPTIMIZER,\n",
" save_steps=SAVE_STEPS,\n",
" save_total_limit=10,\n",
" logging_steps=STEPS,\n",
" learning_rate=LEARNING_RATE,\n",
" weight_decay=0.001,\n",
" fp16=False,\n",
" bf16=True,\n",
" max_grad_norm=0.3,\n",
" max_steps=-1,\n",
" warmup_ratio=WARMUP_RATIO,\n",
" group_by_length=True,\n",
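" # (group_by_length batches rows of similar length together, reducing wasted padding)\n",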
" lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
" lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
" report_to=\"wandb\" if LOG_TO_WANDB else None,\n",
" run_name=RUN_NAME,\n",
" max_seq_length=MAX_SEQUENCE_LENGTH,\n",
" dataset_text_field=\"text\",\n",
" save_strategy=\"steps\",\n",
" hub_strategy=\"every_save\",\n",
" push_to_hub=True,\n",
" hub_model_id=HUB_MODEL_NAME,\n",
" hub_private_repo=True\n",
")\n",
"\n",
"# And now, the Supervised Fine Tuning Trainer will carry out the fine-tuning\n",
"# Given these 2 sets of configuration parameters\n",
"# The latest version of trl is showing a warning about labels - please ignore this warning\n",
"# But let me know if you don't see good training results (loss coming down).\n",
"\n",
"fine_tuning = SFTTrainer(\n",
" model=base_model,\n",
" train_dataset=train,\n",
" peft_config=lora_parameters,\n",
" args=train_parameters,\n",
" data_collator=collator\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ArjP7_OCQOin"
},
"source": [
"## In the next cell, we kick off fine-tuning!\n",
"\n",
"This will run for some time, uploading to the hub every SAVE_STEPS steps.\n",
"\n",
"After some time, Google might stop your colab. For people on free plans, this can happen whenever Google is low on resources. On paid plans, you can get up to 24 hours, but there's no guarantee.\n",
"\n",
"If your server is stopped, you can follow my colab here to resume from your last save:\n",
"\n",
"https://colab.research.google.com/drive/1qGTDVIas_Vwoby4UVi2vwsU0tHXy8OMO#scrollTo=R_O04fKxMMT-\n",
"\n",
"I've saved this colab with my final run in the output so you can see the example. The trick is that I needed to set `is_trainable=True` when loading the fine_tuned model.\n",
"\n",
"### Anyway, with that in mind, let's kick this off!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 469,
"referenced_widgets": [
"6b976e17e00b4a9195dfc9907a50cbe4",
"332715fb37444bdb93f88cb8b799f9aa",
"22dce334249b4bc19fe8d097f6315a05",
"4937f94b87d0485c9491e8920944f6c7",
"8a0719a796cf4bab876e49624a10eaf6",
"eea597dbb971446e9c617e325d4801ec",
"c09c60be549546a9a6320145945427ce",
"f28bd564e8af40a0841e7b3f66a581fe",
"45e10d3089ae41edbc93a8a7b7baf303",
"2ad6b4650b194b2c9cbea37dbeeb3d5b",
"51e92866fe6447c8b0dc77baee0557cc",
"e9a0b46383254cc386a5418a40a14fb6",
"af040db48acc4f68bca3c9b86e895632",
"a29cfdcb36c6473ab71eed91d8690283",
"c430f4fd7c47439092f873b8f14ee96a",
"31e634785164408e85274bb1551ffd8a",
"308d5947a8734e189db4a5edda4b8725",
"db1d98d7bb984594bc8d872b5adba3a8",
"1fc084960ffe44d9ae1df31f54404165",
"8f296c9e5a8b41339477d26765e3c2fe",
"74a228cd618240ca9ae23ed33a9ea313",
"4d2fe7dc55de4445a78592b75b412041",
"eabe30b3d32b488684b3a56d3eacceb5",
"3529120de7d748fb975b8ad15c478158",
"39422b90cec74671948bf58cca168e88",
"8f0faed5ca4a4b0fb5276ae3cc99335a",
"e3eda94aba8e4b69916dbd06a30819e1",
"b03bb213471f442fb27869076fd0d114",
"b527b9e6ed024456ac8c86d564037144",
"dfee2733fb1e4442b11dfd02461ba25e",
"815e1788df7343d5b0a4dea892fafd2f",
"0ab93a2d9b3040c7b27fa24293bd36e1",
"521b23912f1d4788992f896d4f511f97",
"32ae94ab5ca744a3a31c42eef7833ccd",
"d86b69d253954d378a008c9f7a7b2563",
"b32566f421f9423ab2f918ba2582bd5a",
"51e8c2fbd6c7472e95c0fdf84617a900",
"2dcd79e6c5134a009a5db8e057ba3db0",
"8e5e3e78758647c8ab3ab089614eb13a",
"395d1442667d42ce851fbad753665ec1",
"84206cac147a44bfa40b51ddd5afd973",
"5240ecb5f0cb4b9eaa9be217022f003d",
"29b890e29cce42c79e79698e62cd68e2",
"4afdb6f7146f464e95375c375cbe0477"
]
},
"id": "GfvAxnXPvB7w",
"outputId": "5685ec02-7744-4aa2-e2e1-29421266a020"
},
"outputs": [],
"source": [
"# Fine-tune!\n",
"fine_tuning.train()\n",
"\n",
"# Push our fine-tuned model to Hugging Face\n",
"fine_tuning.model.push_to_hub(PROJECT_RUN_NAME, private=True)\n",
"print(f\"Saved to the hub: {PROJECT_RUN_NAME}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 597
},
"id": "32vvrYRVAUNg",
"outputId": "a40f5806-9b88-4100-d662-c3db3683484a"
},
"outputs": [],
"source": [
"if LOG_TO_WANDB:\n",
" wandb.finish()\n",
"\n",
"# https://colab.research.google.com/drive/1W1SKxFMIXxDEs8bO12pYkl5wiXgFBtGp?usp=sharing"
]
}
],
@@ -6,590 +6,10 @@
"id": "GHsssBgWM_l0"
},
"source": [
"### Link\n",
"\n",
"\n",
"## Predict Product Prices\n",
"\n",
"### And now, to evaluate our fine-tuned open source model\n",
"\n",
"https://colab.research.google.com/drive/1DRszViJE_yytQu2vW635uBbovYzbtGIU?usp=sharing\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDyR63OTNUJ6"
},
"outputs": [],
"source": [
"# pip installs\n",
"\n",
"# !pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"# !pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 87,
"referenced_widgets": [
"fae345983a9b4c65ba5f4486968482e0",
"5e06dd16afae4d25997cc696e242e02d",
"6f9debbfa7b84497b4a7c6ae1dbadede",
"0bd09927a77a46c791feddf9bd5a433e",
"e3259f4ba2be4bfea262079f8b40ee78",
"1f7b6a5cd65249b6a784b3d993ebf125",
"5cf9f93381a947c99583afce4cc36f99",
"b3845ddc2a804b3888cdf38dda1ec35b",
"946ea0d7f775428a8d2f040e513d14b5",
"3304d3241718445883c7f449f3eb9ba8",
"d6a16fb6fd3546ad9a50d60e1032ba26"
]
},
"id": "-yikV8pRBer9",
"outputId": "74bb3897-aa65-4e07-c24d-481224f59cd2"
},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"from tqdm import tqdm\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from datetime import datetime\n",
"from peft import PeftModel\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uuTX-xonNeOK"
},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"PROJECT_NAME = \"pricer\"\n",
"HF_USER = \"nikhilr12\" # your HF name here! Or use mine if you just want to reproduce my results.\n",
"\n",
"# The run itself\n",
"\n",
"RUN_NAME = \"2025-10-27_07.15.03\"\n",
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
"REVISION = None # or set this to a specific checkpoint revision (a commit hash from the hub repo)\n",
"FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"\n",
"# Uncomment this line if you wish to use my model\n",
"# FINETUNED_MODEL = f\"ed-donner/{PROJECT_RUN_NAME}\"\n",
"\n",
"# Data\n",
"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"# Or just use the one I've uploaded\n",
"# DATASET_NAME = \"ed-donner/pricer-data\"\n",
"\n",
"# Hyperparameters for QLoRA\n",
"\n",
"QUANT_4_BIT = True\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Used for writing to output in color\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8JArT3QAQAjx"
},
"source": [
"### Log in to HuggingFace\n",
"\n",
"If you don't already have a HuggingFace account, visit https://huggingface.co to sign up and create a token.\n",
"\n",
"Then select the Secrets for this Notebook by clicking on the key icon on the left, and add a new secret called `HF_TOKEN` with your token as the value.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WyFPZeMcM88v"
},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvXVoJH8LS6u"
},
"outputs": [],
"source": [
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xb86e__Wc7j_",
"outputId": "283972ae-e4ec-4178-89ec-781a2c1df941"
},
"outputs": [],
"source": [
"test[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qJWQ0a3wZ0Bw"
},
"source": [
"## Now load the Tokenizer and Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lAUAAcEC6ido"
},
"outputs": [],
"source": [
"# pick the right quantization (thank you Robert M. for spotting the bug with the 8 bit version!)\n",
"\n",
"if QUANT_4_BIT:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" bnb_4bit_quant_type=\"nf4\"\n",
" )\n",
"else:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_8bit=True,\n",
" bnb_8bit_compute_dtype=torch.bfloat16\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 355,
"referenced_widgets": [
"5a9d3dfe76044f099752012d94e284d7",
"7610f3ae45b844afb86e81d64c6f7881",
"8749efb404234da4b98dd133b18d7d95",
"9634570d5d4a44638873758ad8df9f56",
"72e73bd3afb944bc90fe7b7dfec999d6",
"5e0c49da33fb4a7bbb8500dc4f0c7a65",
"932f9059f3e346f7804b47b32e8dfd65",
"23818fac9f164141871ed0a6fe629e7b",
"35863a078a0b495d9350e11779bcf207",
"14d4e3ad3b15405380c5f1bc283ba3d0",
"50722a00fc834405b2304ecb57ec832c",
"ae900d2317a94a468306ee3f0be63ffa",
"7d8d5b41af754338a361f421693c74dd",
"ac40d68809834db0acb5541f8a1962f3",
"4e7a9f9fa72a45599c8ca15c824c2dae",
"62970d70378946e3b0a14e501434356d",
"f2601bb0b92740dab682479554ebfd40",
"091b7a56c02443c2bf7a33a745c0c5be",
"d51654fa5fcd4ae4ad9f87a405c8fda1",
"20f24c83262c43e7a58fd7f3a9e7cc3c",
"926c3996d2794b97a9b67a6b3bf27a09",
"2272028ce824446a9865bd44c167752f",
"0a4576cf0eaf40a382bc5e54337cddb0",
"dc5bf3cdb1b74056990dd44d273b8a9a",
"627fcd384f034c088832ef0f2c9288df",
"3f60821b601447ef8dd1ceb2a2dc0aba",
"8185c16afbb7486f94d32dfeff914e81",
"5efd72db23914252842bb3364843e43b",
"d7cae279b79248a98a0f92bb939d1457",
"6ec169fa557c493a8fa4d591657e45f3",
"68cdaaa86f0b4209982b095b62640c5d",
"e270e36535594a3ba041708e8559ab94",
"a56e440cea7646e9899273a306d7b38a",
"8832fa21bdb840f2a9b7e374796777e1",
"329ead36db4b476e8abc0eb0d15cec37",
"19016be7686e469e9d0c2461f3929edb",
"999c954e83b94e258eae2085ec5b748d",
"a4f7fb3d15e64003af840170937a285b",
"481ae5f8fed3469f82687790fdad6236",
"de5cec82e1e9495f8f35071591fa4697",
"3d841d0ed8a24baeafada4c5d5fdcd8d",
"82f3c4c94a98416aa8e8ea022b684ff0",
"f6c2f058e9554b7796d67bef99051fdc",
"2636f2d4c7c7485f849eb07d91f8e945",
"fbbbf4669db643c69d623a96e4317b2b",
"f44c70885e0d48be9583977df7b82999",
"1e996ee50b844cb2a1f32944c23c2471",
"4b7e4de9e61b426fbcd7549b382f6060",
"581a45a563da4ed28127994fb0d3f5fa",
"a20a37d53b3c4c44bbd0da301a9e7914",
"0ef0eff7425d4bb298cefef5808fc3d6",
"4b7dc3d57b694600a21eb100bcee6110",
"e2a8c7d4e2b44de9a651770f2aebb9de",
"efa563e444374368a4ddf96873f1bb9a",
"0475882549294f86ad4eeadcdfd88db2",
"ce40e46d65b04431bc7bcfea28ffa41f",
"907d58bf33b34f9bbccda35aeb437bd7",
"325e18eebdc34a1f848794030cbeaa72",
"ce67ad96eff247ceb3492e9bca2fba86",
"cbbe1416591a427689d238bce37f016e",
"1df9dd2071124086b9552dc7ed2f6458",
"33b742a6f64d48a6bd10acdc3b262a07",
"572489d76d144b2e9d65c5c4449108bc",
"b4496fc3a4854cfabacf9125585cab16",
"a798a7bf5d624683afced668c97b0814",
"da9d6ce9d4804325990ae6f121e66d38",
"1a560f503bbb44e9ace4633d0799ffdb",
"f10cef0c6b0f45b08709a38205b96c62",
"924a858e6a684b4393f856a8fc084e1b",
"7f1b810fef674dbfb9fe66fd742d152e",
"38be9ef6237a437688eda77503258e14",
"c62905f3eed24180be709b695fd1bf0d",
"7ab94f29e5144dfab8861570272ac24e",
"b6a86921f7dd4f9ab4fb0e4b6a0cbe23",
"1b1bb8cb47c646f3abeee43b075e24ff",
"320c8f0e74b1429793cae77a661e5e4d",
"dd4c897c785d4f1188de31fb7a20e281",
"4284fff454dd45688058556a19380cc9",
"47441dac50f148e99751809e11b79903",
"b5e44bb30a944fc380fb5c4bc4a1363b",
"e42ccf6b1e0f4e679f9e961b89d0da06",
"45b5ae1f114c4c6e9d71e56fdecbc63b",
"055d68da77a5491a8f14d78ff59a1a5f",
"9d3464e2ae264b01ade876842e7678ab",
"edd40d2944cb4a3a8b09d31d42489705",
"d48eee36b6d441a2ac35d6f850b500bb",
"e9c9b06e881a435b860e9b3246b04e54",
"89779ab7fb6049e7b437e452fcca78d6",
"a6e75de2860d47e6a942f77dc474210c",
"8dacb9554470435ba0a496d35cdb3c92",
"6b7c7a5a758f49b895e2f6a0df35465f",
"b7906982da7f419396278602e399e853",
"b7c85b576e5846448cb4403aac5ad671",
"2215030c89f844b39bd83b797037010e",
"396a0b39fe6f41d69a91a2bafc280ebe",
"98e07a651dfc494d865267c0e231378f",
"1295e277482c43bfa23ad2d7464b1281",
"9fd29b4744c64dc4acb712ef64efa348",
"721ad0a7e54d42d78399f7f27b740467",
"eb9bdaf911f04e44a9089b2f54e37144",
"d41332f717e64c90943b9eb10219374d",
"fdef479f2fe34746b586d0be8b785f46",
"09b427e15ea7487fa3abb9f6c9b4e5be",
"51ff51e5e66849cfb559144026e861fa",
"c8e2f9a017004645bdbeba4e90afc36b",
"d27811a568b54effb577adc450b6c805",
"048e312ae9464f5aa8a64f0a140a0a8e",
"0fb380968c9845988b58395d0830ac2a",
"3782e192c2d147cbbc4ab8447a112fa3",
"cd82596e540040728a08854a3afddc00"
]
},
"id": "R_O04fKxMMT-",
"outputId": "a478f7f7-9ac5-4083-aef9-7d132906b66e"
},
"outputs": [],
"source": [
"# Load the Tokenizer and the Model\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=quant_config,\n",
" device_map=\"auto\",\n",
")\n",
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"# Load the fine-tuned model with PEFT\n",
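"# PeftModel layers the saved LoRA adapter weights on top of the quantized base model\n",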
"if REVISION:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
"else:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
"\n",
"\n",
"print(f\"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kD-GJtbrdd5t",
"outputId": "9300568c-0cc0-44a7-eae3-55e49b230924"
},
"outputs": [],
"source": [
"fine_tuned_model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UObo1-RqaNnT"
},
"source": [
"# THE MOMENT OF TRUTH!\n",
"\n",
"## Use the model in inference mode\n",
"\n",
"Remember, GPT-4o had an average error of \$76. \n",
"Llama 3.1 base model had an average error of \$396. \n",
"This human had an error of \$127. \n",
"\n",
"## Caveat\n",
"\n",
"Keep in mind that prices of goods vary considerably; the model can't predict things like sale prices that it doesn't have any information about."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qst1LhBVAB04"
},
"outputs": [],
"source": [
"def extract_price(s):\n",
" if \"Price is $\" in s:\n",
" contents = s.split(\"Price is $\")[1]\n",
" contents = contents.replace(',','')\n",
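" # take the first number (integer or decimal) that appears after the marker\n",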
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
" return float(match.group()) if match else 0\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jXFBW_5UeEcp",
"outputId": "28e8a3b6-74de-43eb-ae0d-2d0566d2767d"
},
"outputs": [],
"source": [
"extract_price(\"Price is $a fabulous 899.99 or so\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Oj_PzpdFAIMk"
},
"outputs": [],
"source": [
"# Original prediction function takes the most likely next token\n",
"\n",
"def model_predict(prompt):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
" attention_mask = torch.ones(inputs.shape, device=\"cuda\")\n",
" outputs = fine_tuned_model.generate(inputs, attention_mask=attention_mask, max_new_tokens=3, num_return_sequences=1)\n",
" response = tokenizer.decode(outputs[0])\n",
" return extract_price(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Je5dR8QEAI1d"
},
"outputs": [],
"source": [
"# An improved prediction function takes a weighted average of the top 3 choices\n",
"# This code would be more complex if we couldn't take advantage of the fact\n",
"# that Llama generates 1 token for any 3-digit number\n",
"\n",
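"# For example (my own worked numbers): if the top 3 tokens were \"150\" (p=0.50), \"149\" (p=0.30)\n",
"# and \"155\" (p=0.10), the prediction is the weighted mean (150*0.5 + 149*0.3 + 155*0.1) / 0.9 = 150.2 approx.\n",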
"top_K = 3\n",
"\n",
"def improved_model_predict(prompt, device=\"cuda\"):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
" attention_mask = torch.ones(inputs.shape, device=device)\n",
"\n",
" with torch.no_grad():\n",
" outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
" next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
"\n",
" next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
" top_prob, top_token_id = next_token_probs.topk(top_K)\n",
" prices, weights = [], []\n",
" for i in range(top_K):\n",
" predicted_token = tokenizer.decode(top_token_id[0][i])\n",
" probability = top_prob[0][i]\n",
" try:\n",
" result = float(predicted_token)\n",
" except ValueError as e:\n",
" result = 0.0\n",
" if result > 0:\n",
" prices.append(result)\n",
" weights.append(probability)\n",
" if not prices:\n",
" return 0.0\n",
" total = sum(weights)\n",
" weighted_prices = [price * weight / total for price, weight in zip(prices, weights)]\n",
" return sum(weighted_prices).item()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lQk7jNlm1oV9"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "30lzJXBH7BcK"
},
"outputs": [],
"source": [
"class Tester:\n",
"\n",
" def __init__(self, predictor, data, title=None, size=250):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = size\n",
" self.guesses = []\n",
" self.truths = []\n",
" self.errors = []\n",
" self.sles = []\n",
" self.colors = []\n",
"\n",
" def color_for(self, error, truth):\n",
" if error<40 or error/truth < 0.2:\n",
" return \"green\"\n",
" elif error<80 or error/truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
"\n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
" guess = self.predictor(datapoint[\"text\"])\n",
" truth = datapoint[\"price\"]\n",
" error = abs(guess - truth)\n",
" log_error = math.log(truth+1) - math.log(guess+1)\n",
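" # squared log error penalizes relative (not absolute) deviation\n",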
" sle = log_error ** 2\n",
" color = self.color_for(error, truth)\n",
" title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
" self.guesses.append(guess)\n",
" self.truths.append(truth)\n",
" self.errors.append(error)\n",
" self.sles.append(sle)\n",
" self.colors.append(color)\n",
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
" def chart(self, title):\n",
" max_error = max(self.errors)\n",
" plt.figure(figsize=(12, 8))\n",
" max_val = max(max(self.truths), max(self.guesses))\n",
" plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
" plt.xlabel('Ground Truth')\n",
" plt.ylabel('Model Estimate')\n",
" plt.xlim(0, max_val)\n",
" plt.ylim(0, max_val)\n",
" plt.title(title)\n",
" plt.show()\n",
"\n",
" def report(self):\n",
" average_error = sum(self.errors) / self.size\n",
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
" hits = sum(1 for color in self.colors if color==\"green\")\n",
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
" self.chart(title)\n",
"\n",
" def run(self):\n",
" self.error = 0\n",
" for i in range(self.size):\n",
" self.run_datapoint(i)\n",
" self.report()\n",
"\n",
" @classmethod\n",
" def test(cls, function, data):\n",
" cls(function, data).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "W_KcLvyt6kbb",
"outputId": "bb975323-841f-464c-d876-83c2238573e9"
},
"outputs": [],
"source": [
"Tester.test(improved_model_predict, test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M4NSMcKl3Bhw"
},
"outputs": [],
"source": []
}
],
"metadata": {