398 lines
11 KiB
Plaintext
398 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "28a0673e-96b5-43f2-8a8b-bd033bf851b0",
|
|
"metadata": {},
|
|
"source": [
|
|
"# The Big Project begins!!\n",
|
|
"\n",
|
|
"## The Product Pricer\n",
|
|
"\n",
|
|
"A model that can estimate how much something costs, from its description.\n",
|
|
"\n",
|
|
"## Data Curation Part 1\n",
|
|
"\n",
|
|
"Today we'll begin scrubbing and curating our dataset by focusing on a subset of the data: Home Appliances.\n",
|
|
"\n",
|
|
"The dataset is here: \n",
|
|
"https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023\n",
|
|
"\n",
|
|
"And the folder with all the product datasets is here: \n",
|
|
"https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/tree/main/raw/meta_categories"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "67cedf85-8125-4322-998e-9375fe745597",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from huggingface_hub import login\n",
|
|
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
|
"from items import Item\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7390a6aa-79cb-4dea-b6d7-de7e4b13e472",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# environment\n",
|
|
"\n",
|
|
"load_dotenv()\n",
|
|
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
|
|
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
|
|
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0732274a-aa6a-44fc-aee2-40dc8a8e4451",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Log in to HuggingFace\n",
|
|
"\n",
|
|
"hf_token = os.environ['HF_TOKEN']\n",
|
|
"login(hf_token, add_to_git_credential=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1adcf323-de9d-4c24-a9c3-d7ae554d06ca",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "049885d4-fdfa-4ff0-a932-4a2ed73928e2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load in our dataset\n",
|
|
"\n",
|
|
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_Appliances\", split=\"full\", trust_remote_code=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cde08860-b393-49b8-a620-06a8c0990a64",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(f\"Number of Appliances: {len(dataset):,}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3e29a5ab-ca61-41cc-9b33-22d374681b85",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Investigate a particular datapoint\n",
|
|
"datapoint = dataset[2]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "40a4e10f-6710-4780-a95e-6c0030c3fb87",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Investigate\n",
|
|
"\n",
|
|
"print(datapoint[\"title\"])\n",
|
|
"print(datapoint[\"description\"])\n",
|
|
"print(datapoint[\"features\"])\n",
|
|
"print(datapoint[\"details\"])\n",
|
|
"print(datapoint[\"price\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9d356c6f-b6e8-4e01-98cd-c562d132aafa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# How many have prices?\n",
|
|
"\n",
|
|
"prices = 0\n",
|
|
"for datapoint in dataset:\n",
|
|
" try:\n",
|
|
" price = float(datapoint[\"price\"])\n",
|
|
" if price > 0:\n",
|
|
" prices += 1\n",
|
|
" except ValueError as e:\n",
|
|
" pass\n",
|
|
"\n",
|
|
"print(f\"There are {prices:,} with prices which is {prices/len(dataset)*100:,.1f}%\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "bd890259-aa25-4097-9524-f91c2bdd719b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# For those with prices, gather the price and the length\n",
|
|
"\n",
|
|
"prices = []\n",
|
|
"lengths = []\n",
|
|
"for datapoint in dataset:\n",
|
|
" try:\n",
|
|
" price = float(datapoint[\"price\"])\n",
|
|
" if price > 0:\n",
|
|
" prices.append(price)\n",
|
|
" contents = datapoint[\"title\"] + str(datapoint[\"description\"]) + str(datapoint[\"features\"]) + str(datapoint[\"details\"])\n",
|
|
" lengths.append(len(contents))\n",
|
|
" except ValueError as e:\n",
|
|
" pass"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "89078cb1-9679-4eb0-b295-599b8586bcd1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plot the distribution of lengths\n",
|
|
"\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Lengths: Avg {sum(lengths)/len(lengths):,.0f} and highest {max(lengths):,}\\n\")\n",
|
|
"plt.xlabel('Length (chars)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 6000, 100))\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c38e0c43-9f7a-450e-a911-c94d37d9b9c3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plot the distribution of prices\n",
|
|
"\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.2f} and highest {max(prices):,}\\n\")\n",
|
|
"plt.xlabel('Price ($)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 1000, 10))\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "eabc7c61-0cd2-41f4-baa1-b85400bbf87f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# So what is this item?? Print the title of anything priced above $21,000\n",
|
|
"\n",
|
|
"for datapoint in dataset:\n",
|
|
" try:\n",
|
|
" price = float(datapoint[\"price\"])\n",
|
|
" if price > 21000:\n",
|
|
" print(datapoint['title'])\n",
|
|
" except ValueError as e:\n",
|
|
" pass"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3668ae25-3461-4e6e-9ccb-221c1925a497",
|
|
"metadata": {},
|
|
"source": [
|
|
"This is the closest I can find - looks like it's going at a bargain price!!\n",
|
|
"\n",
|
|
"https://www.amazon.com/TurboChef-Electric-Countertop-Microwave-Convection/dp/B01D05U9NO/"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a0d02f58-23f6-4f81-a779-7c0555afd13d",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Now it's time to curate our dataset\n",
|
|
"\n",
|
|
"We select items that cost between 1 and 999 USD\n",
|
|
"\n",
|
|
"We will create Item instances, which truncate the text to fit within 180 tokens using the right Tokenizer\n",
|
|
"\n",
|
|
"And will create a prompt to be used during Training.\n",
|
|
"\n",
|
|
"Items will be rejected if they don't have sufficient characters."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "430b432f-b769-41da-9506-a238cb5cf1b6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create an Item object for each with a price\n",
|
|
"\n",
|
|
"items = []\n",
|
|
"for datapoint in dataset:\n",
|
|
" try:\n",
|
|
" price = float(datapoint[\"price\"])\n",
|
|
" if price > 0:\n",
|
|
" item = Item(datapoint, price)\n",
|
|
" if item.include:\n",
|
|
" items.append(item)\n",
|
|
" except ValueError as e:\n",
|
|
" pass\n",
|
|
"\n",
|
|
"print(f\"There are {len(items):,} items\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0d570794-6f1d-462e-b567-a46bae3556a1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Look at an example item (index 1, i.e. the second item in the list)\n",
|
|
"\n",
|
|
"items[1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "70219e99-22cc-4e08-9121-51f9707caef0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Investigate the prompt that will be used during training - the model learns to complete this\n",
|
|
"\n",
|
|
"print(items[100].prompt)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d9998b8d-d746-4541-9ac2-701108e0e8fb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Investigate the prompt that will be used during testing - the model has to complete this\n",
|
|
"\n",
|
|
"print(items[100].test_prompt())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7a116369-335a-412b-b70c-2add6675c2e3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plot the distribution of token counts\n",
|
|
"\n",
|
|
"tokens = [item.token_count for item in items]\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
|
|
"plt.xlabel('Length (tokens)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(tokens, rwidth=0.7, color=\"green\", bins=range(0, 300, 10))\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8d1744aa-71e7-435e-876e-91f06583211a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plot the distribution of prices again, this time for the curated items only\n",
|
|
"\n",
|
|
"prices = [item.price for item in items]\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
|
|
"plt.xlabel('Price ($)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(prices, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2b58dc61-747f-46f7-b9e0-c205db4f3e5e",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Sidenote\n",
|
|
"\n",
|
|
"If you like the variety of colors that matplotlib can use in its charts, you should bookmark this:\n",
|
|
"\n",
|
|
"https://matplotlib.org/stable/gallery/color/named_colors.html\n",
|
|
"\n",
|
|
"## Todos for you:\n",
|
|
"\n",
|
|
"- Review the Item class and check you're comfortable with it\n",
|
|
"- Examine some Item objects, look at the training prompt with `item.prompt` and test prompt with `item.test_prompt()`\n",
|
|
"- Make some more histograms to better understand the data\n",
|
|
"\n",
|
|
"## Next time we will combine with many other types of product\n",
|
|
"\n",
|
|
"Like Electronics and Automotive. This will give us a massive dataset, and we can then be picky about choosing a subset that will be most suitable for training."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "01401283-d111-40a7-96e5-0ca05bb20857",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|