Updates
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -1,27 +1,10 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"gpuType": "T4",
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/dkisselev-zz/llm_engineering/blob/wk7/Week_7_Excersise_fine_tuned_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
@@ -29,65 +12,67 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "GHsssBgWM_l0"
|
||||
},
|
||||
"source": [
|
||||
"# Predict Product Prices\n",
|
||||
"\n",
|
||||
"Model evaluation and inference tuning\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "GHsssBgWM_l0"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Libraries and configuration"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "HnwMdAP3IHad"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Libraries and configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
|
||||
"!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
|
||||
],
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "MDyR63OTNUJ6"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
|
||||
"!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "-yikV8pRBer9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import math\n",
|
||||
"import numpy as np\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"from google.colab import userdata\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"import wandb\n",
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import transformers\n",
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
|
||||
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
||||
"from datetime import datetime\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"from peft import PeftModel\n",
|
||||
"import matplotlib.pyplot as plt"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "-yikV8pRBer9"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "uuTX-xonNeOK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Models\n",
|
||||
"\n",
|
||||
@@ -140,61 +125,61 @@
|
||||
"BLUE = \"\\033[94m\"\n",
|
||||
"RESET = \"\\033[0m\"\n",
|
||||
"COLOR_MAP = {\"red\":RED, \"orange\": BLUE, \"green\": GREEN}"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "uuTX-xonNeOK"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8JArT3QAQAjx"
|
||||
},
|
||||
"source": [
|
||||
"### Load Data\n",
|
||||
"\n",
|
||||
"Data is loaded from Huggin Face\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "8JArT3QAQAjx"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "WyFPZeMcM88v"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Log in to HuggingFace\n",
|
||||
"hf_token = userdata.get('HF_TOKEN')\n",
|
||||
"login(hf_token)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "WyFPZeMcM88v"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "cvXVoJH8LS6u"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = load_dataset(DATASET_NAME)\n",
|
||||
"train = dataset['train']\n",
|
||||
"test = dataset['test']"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "cvXVoJH8LS6u"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Load Tokenizer and Model"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "qJWQ0a3wZ0Bw"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Load Tokenizer and Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "lAUAAcEC6ido"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 4 or 8 but quantization\n",
|
||||
"if QUANT_4_BIT:\n",
|
||||
@@ -208,29 +193,29 @@
|
||||
" quant_config = BitsAndBytesConfig(\n",
|
||||
" load_in_8bit=True\n",
|
||||
" )"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "lAUAAcEC6ido"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OQy4pCk-dutf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load model from w&b\n",
|
||||
"if ARTIFCAT_LOCATTION==\"WB\":\n",
|
||||
" artifact = wandb.Api().artifact(WANDB_ARTIFACT_PATH, type='model')\n",
|
||||
" artifact_dir = artifact.download() # Downloads to a local cache dir"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OQy4pCk-dutf"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "R_O04fKxMMT-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the Tokenizer and the Model\n",
|
||||
"\n",
|
||||
@@ -256,24 +241,24 @@
|
||||
" fine_tuned_model = PeftModel.from_pretrained(base_model, artifact_dir)\n",
|
||||
"\n",
|
||||
"print(f\"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "R_O04fKxMMT-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Hyperparameter helpers"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "UObo1-RqaNnT"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Hyperparameter helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "n4u27kbwlekE"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def calculate_weighted_price(prices, probabilities):\n",
|
||||
" \"\"\"\n",
|
||||
@@ -303,15 +288,15 @@
|
||||
" weighted_price = np.average(prices_array, weights=probs_array)\n",
|
||||
"\n",
|
||||
" return weighted_price"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "n4u27kbwlekE"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ROjIbGuH0FWS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_top_k_predictions(prompt, device=\"cuda\"):\n",
|
||||
" \"\"\"\n",
|
||||
@@ -351,15 +336,15 @@
|
||||
" return [], []\n",
|
||||
"\n",
|
||||
" return prices, probabilities"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ROjIbGuH0FWS"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "tnmTAiEG32xK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def make_prompt(text):\n",
|
||||
" if ARTIFCAT_LOCATTION==\"HF\":\n",
|
||||
@@ -373,15 +358,15 @@
|
||||
" # prompt = p_array[0] + \"\\n\\n\\n\" + p_title + \"\\n\\n\" + p_descr + \"\\n\\n\" + p_price\n",
|
||||
" # return text\n",
|
||||
" return prompt"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "tnmTAiEG32xK"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "VNAEw5Eg4ABk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
@@ -457,15 +442,15 @@
|
||||
" @classmethod\n",
|
||||
" def test(cls, function, data):\n",
|
||||
" cls(function, data).run()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "VNAEw5Eg4ABk"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dbWS1DPV4TPQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Search_K:\n",
|
||||
" \"\"\"\n",
|
||||
@@ -710,28 +695,28 @@
|
||||
" @classmethod\n",
|
||||
" def test(cls, function, data):\n",
|
||||
" cls(function, data).run()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "dbWS1DPV4TPQ"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Vtt13OuVE-t7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Search best K\n",
|
||||
"search_k = Search_K(get_top_k_predictions, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n",
|
||||
"best_k = search_k.run()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Vtt13OuVE-t7"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "tuwYu1NYljIv"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"top_K = best_k\n",
|
||||
"\n",
|
||||
@@ -785,39 +770,51 @@
|
||||
" final_price = np.average(prices_np, weights=probs_np)\n",
|
||||
"\n",
|
||||
" return float(final_price) # Return as a standard python float"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "tuwYu1NYljIv"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "3SxpLBJH70E-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt=make_prompt(test[80]['text'])\n",
|
||||
"print(prompt)\n",
|
||||
"\n",
|
||||
"improved_model_predict(prompt)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "3SxpLBJH70E-"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "W_KcLvyt6kbb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run Estimate vs Ground Truth\n",
|
||||
"tester = Tester(improved_model_predict, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n",
|
||||
"tester.run()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "W_KcLvyt6kbb"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"include_colab_link": true,
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Reference in New Issue
Block a user