Merge pull request #912 from ranskills/week7-fine-tuning-os-model-qlora
Bootcamp(Ransford): Week7 - Fine-Tuning OS Model QLoRA
This commit is contained in:
@@ -0,0 +1,322 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "ogushXV4ZGMi"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Fine-Tuning an OpenSource Model using QLoRA"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "chQmy4_HXhgr"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install -qU peft trl bitsandbytes datasets wandb"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "zHZJoUeQZJNo"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"from transformers import (\n",
|
||||||
|
" AutoModelForCausalLM,\n",
|
||||||
|
" AutoTokenizer,\n",
|
||||||
|
" BitsAndBytesConfig,\n",
|
||||||
|
" TrainingArguments\n",
|
||||||
|
")\n",
|
||||||
|
"from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
|
||||||
|
"from trl import SFTConfig, SFTTrainer\n",
|
||||||
|
"from datasets import load_dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"MODEL_NAME = \"mistralai/Mistral-7B-v0.1\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"bnb_config = BitsAndBytesConfig(\n",
|
||||||
|
" load_in_4bit=True,\n",
|
||||||
|
" bnb_4bit_use_double_quant=True,\n",
|
||||||
|
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||||
|
" bnb_4bit_compute_dtype=torch.bfloat16\n",
|
||||||
|
")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "NyDHu1vrZ1gO"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"model = AutoModelForCausalLM.from_pretrained(\n",
|
||||||
|
" MODEL_NAME,\n",
|
||||||
|
" quantization_config=bnb_config,\n",
|
||||||
|
" device_map=\"auto\",\n",
|
||||||
|
" trust_remote_code=True\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
||||||
|
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||||
|
"tokenizer.padding_side = \"right\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"model = prepare_model_for_kbit_training(model)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "v2X2xM-dZ7fN"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"lora_config = LoraConfig(\n",
|
||||||
|
" r=8,\n",
|
||||||
|
" lora_alpha=16,\n",
|
||||||
|
" target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],\n",
|
||||||
|
" lora_dropout=0.05,\n",
|
||||||
|
" bias=\"none\",\n",
|
||||||
|
" task_type=\"CAUSAL_LM\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"model = get_peft_model(model, lora_config)\n",
|
||||||
|
"model.print_trainable_parameters()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "akK6pJnLaeIr"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"DATA_DIR = '/content/data'\n",
|
||||||
|
"\n",
|
||||||
|
"data_files = [\n",
|
||||||
|
" f'{DATA_DIR}/all_beauty_train.parquet',\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = load_dataset('parquet', data_files=data_files, split='train')\n",
|
||||||
|
"\n",
|
||||||
|
"train_test = dataset.train_test_split(train_size=100, test_size=20, seed=42)\n",
|
||||||
|
"train_dataset = train_test[\"train\"]\n",
|
||||||
|
"\n",
|
||||||
|
"test_dataset = load_dataset('parquet', data_files=[f'{DATA_DIR}/all_beauty_test.parquet'], split='train')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "Nm1O_IjBbNa0"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"sft_config = SFTConfig(\n",
|
||||||
|
" output_dir=\"./price-prediction-qlora\",\n",
|
||||||
|
" num_train_epochs=1,\n",
|
||||||
|
" per_device_train_batch_size=4,\n",
|
||||||
|
" gradient_accumulation_steps=2,\n",
|
||||||
|
" gradient_checkpointing=True,\n",
|
||||||
|
" optim=\"paged_adamw_8bit\",\n",
|
||||||
|
" learning_rate=2e-4,\n",
|
||||||
|
" lr_scheduler_type=\"cosine\",\n",
|
||||||
|
" warmup_steps=50,\n",
|
||||||
|
" logging_steps=10,\n",
|
||||||
|
" save_strategy=\"no\",\n",
|
||||||
|
" fp16=False,\n",
|
||||||
|
" bf16=True,\n",
|
||||||
|
" max_grad_norm=0.3,\n",
|
||||||
|
" save_total_limit=2,\n",
|
||||||
|
" group_by_length=True,\n",
|
||||||
|
" report_to=\"none\",\n",
|
||||||
|
" packing=False,\n",
|
||||||
|
" dataset_text_field=\"text\",\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"trainer = SFTTrainer(\n",
|
||||||
|
" model=model,\n",
|
||||||
|
" args=sft_config,\n",
|
||||||
|
" train_dataset=train_dataset,\n",
|
||||||
|
")\n",
|
||||||
|
"\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "a8Nx8GJTb-Wz"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"### Start Training"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "UmH5E6Xvn8so"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"print(\"Starting training...\")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer.train()\n",
|
||||||
|
"\n",
|
||||||
|
"trainer.model.save_pretrained(\"./price-prediction-final\")\n",
|
||||||
|
"tokenizer.save_pretrained(\"./price-prediction-final\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Training complete! LoRA adapters saved to ./price-prediction-final\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "-nXZ5O_ifFVh"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"def predict_price_inmemory(prompt, model, tokenizer):\n",
|
||||||
|
"\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
"\n",
|
||||||
|
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
|
||||||
|
"\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):\n",
|
||||||
|
" outputs = model.generate(\n",
|
||||||
|
" **inputs,\n",
|
||||||
|
" max_new_tokens=10,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" do_sample=False,\n",
|
||||||
|
" pad_token_id=tokenizer.eos_token_id\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" result = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
|
||||||
|
"\n",
|
||||||
|
" if \"Price is $\" in result:\n",
|
||||||
|
" predicted = result.split(\"Price is $\")[-1].strip()\n",
|
||||||
|
"\n",
|
||||||
|
" import re\n",
|
||||||
|
" match = re.search(r'(\\d+\\.?\\d*)', predicted)\n",
|
||||||
|
" if match:\n",
|
||||||
|
" return match.group(1)\n",
|
||||||
|
" return predicted"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "LSvPhf-3fYaZ"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Eye-test Validation\n",
|
||||||
|
"\n",
|
||||||
|
"Not the best I know, but I wanted to go through the entire process myself and not enough time on my hands."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "m6L5CET_sXmx"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"for item in test_dataset.take(5):\n",
|
||||||
|
" prompt = item[\"text\"]\n",
|
||||||
|
" actual_price = item[\"price\"]\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_price = float(predict_price_inmemory(prompt, model, tokenizer))\n",
|
||||||
|
" print(\"\\n\" + \"*\" * 80)\n",
|
||||||
|
" print(prompt)\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"Prediction: ${predicted_price}. Actual: ${actual_price}. Diff {abs(predicted_price - actual_price):,.2f}\")\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "6pSYhLn_kROQ"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## Loading Somewhere in the future\n",
|
||||||
|
"\n",
|
||||||
|
"It can even be loaded in a different notebook."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "OnJAD7YihyAD"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
|
||||||
|
"from peft import PeftModel\n",
|
||||||
|
"import torch\n",
|
||||||
|
"\n",
|
||||||
|
"bnb_config = BitsAndBytesConfig(\n",
|
||||||
|
" load_in_4bit=True,\n",
|
||||||
|
" bnb_4bit_use_double_quant=True,\n",
|
||||||
|
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||||
|
" bnb_4bit_compute_dtype=torch.bfloat16\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"base_model = AutoModelForCausalLM.from_pretrained(\n",
|
||||||
|
" MODEL_NAME,\n",
|
||||||
|
" quantization_config=bnb_config,\n",
|
||||||
|
" device_map=\"auto\",\n",
|
||||||
|
" trust_remote_code=True\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"model = PeftModel.from_pretrained(base_model, \"./price-prediction-final\")\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(\"./price-prediction-final\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "5RCClHQHijes"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"gpuType": "T4",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user