676 lines
23 KiB
Plaintext
676 lines
23 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "91ae778b"
|
|
},
|
|
"source": [
|
|
"# Getting Started\n",
|
|
"\n",
|
|
"Before running this notebook, please ensure you have the following:\n",
|
|
"\n",
|
|
"1. **Local Modules:** Upload the necessary local Python files (`items.py`, `loaders.py`, `testing.py`) to the Colab runtime's temporary storage. You can do this by clicking the folder icon on the left sidebar, then the upload icon, and selecting the files.\n",
|
|
"2. **Hugging Face Access Token:** Add your Hugging Face access token to Colab's user data secrets. Click the key icon on the left sidebar, click \"New secret\", and add your token with the name `HF_TOKEN`.\n",
|
|
"3. **Install Dependencies:** Run the first code cell to install the required libraries with the specified versions.\n",
|
|
"\n",
|
|
"Once these steps are completed, you can run the rest of the notebook cells sequentially."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "_fj3pImYM5dw"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Install exact versions from local environment to match the course's environment\n",
|
|
"!pip install --upgrade pip\n",
|
|
"\n",
|
|
"# Install specific versions of required libraries\n",
|
|
"!pip install datasets==3.6.0\n",
|
|
"!pip install transformers==4.51.3\n",
|
|
"!pip install huggingface_hub==0.31.2\n",
|
|
"!pip install matplotlib==3.10.3\n",
|
|
"!pip install numpy==1.26.4\n",
|
|
"!pip install python-dotenv==1.1.0\n",
|
|
"!pip install tqdm==4.67.1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Import necessary libraries\n",
|
|
"import os\n",
|
|
"import random\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from huggingface_hub import login\n",
|
|
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from collections import Counter, defaultdict\n",
|
|
"import numpy as np\n",
|
|
"import pickle"
|
|
],
|
|
"metadata": {
|
|
"id": "YQHruTKgPMRX"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
"# Retrieve the Hugging Face access token from Colab's user data secrets and log in.\n",
"# The token is assigned to a variable instead of being left as the cell's last\n",
"# expression: a bare `userdata.get(...)` would echo the secret into the cell output,\n",
"# where it is saved with the notebook.\n",
"from google.colab import userdata\n",
"\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"# Authenticate this session with the Hub (needed e.g. for the optional push_to_hub step below)\n",
"login(token=hf_token)"
|
|
],
|
|
"metadata": {
|
|
"id": "jBdHkdyVNj_f"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Import custom classes from local files (items.py and loaders.py)\n",
|
|
"# These files were manually added to the Colab runtime's temporary storage\n",
|
|
"from loaders import ItemLoader\n",
|
|
"from items import Item"
|
|
],
|
|
"metadata": {
|
|
"id": "FdBT3PPzNmq3"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Set the backend for matplotlib to display plots inline in the notebook\n",
|
|
"%matplotlib inline"
|
|
],
|
|
"metadata": {
|
|
"id": "vynEBaq6OGEZ"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
"# Load a single dataset (\"Appliances\") using the custom ItemLoader\n",
|
|
"# This was likely an initial test or example loading step\n",
|
|
"items = ItemLoader(\"Appliances\").load()"
|
|
],
|
|
"metadata": {
|
|
"id": "OFOJtH6FOG2u"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Define a list of dataset names (Amazon product categories) to be loaded\n",
|
|
"dataset_names = [\n",
|
|
" \"Automotive\",\n",
|
|
" \"Electronics\",\n",
|
|
" \"Office_Products\",\n",
|
|
" \"Tools_and_Home_Improvement\",\n",
|
|
" \"Cell_Phones_and_Accessories\",\n",
|
|
" \"Toys_and_Games\",\n",
|
|
" \"Appliances\",\n",
|
|
" \"Musical_Instruments\",\n",
|
|
"]"
|
|
],
|
|
"metadata": {
|
|
"id": "rkLXYtfhOJNn"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Check and print the available CPU cores and RAM in the Colab runtime environment\n",
|
|
"# This helps understand the resources available for data processing\n",
|
|
"import psutil\n",
|
|
"print(f\"CPU cores: {psutil.cpu_count()}\")\n",
|
|
"print(f\"Available RAM: {psutil.virtual_memory().available / (1024**3):.1f} GB\")"
|
|
],
|
|
"metadata": {
|
|
"id": "1oQSUpovOfKs"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"items = []\n",
|
|
"for dataset_name in dataset_names:\n",
|
|
" loader = ItemLoader(dataset_name)\n",
|
|
" items.extend(loader.load(workers=8))\n",
|
|
"\n",
|
|
"# Now, time for a coffee break!!\n",
|
|
"# By the way, I put the biggest datasets first.. it gets faster."
|
|
],
|
|
"metadata": {
|
|
"id": "UcV9RB2Go8nC"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Print the total number of items loaded from all datasets\n",
|
|
"print(f\"A grand total of {len(items):,} items\")"
|
|
],
|
|
"metadata": {
|
|
"id": "YdkGJ_X3oI1g"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Extract token counts from all loaded items\n",
|
|
"tokens = [item.token_count for item in items]\n",
|
|
"# Create and display a histogram of token counts\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
|
|
"plt.xlabel('Length (tokens)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(tokens, rwidth=0.7, color=\"skyblue\", bins=range(0, 300, 10))\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "8VzKJ7neo-wp"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Extract prices from all loaded items\n",
|
|
"prices = [item.price for item in items]\n",
|
|
"# Create and display a histogram of item prices\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
|
|
"plt.xlabel('Price ($)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(prices, rwidth=0.7, color=\"blueviolet\", bins=range(0, 1000, 10))\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "ZLFJycNZpDak"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Count the occurrences of each category in the loaded items\n",
|
|
"category_counts = Counter()\n",
|
|
"for item in items:\n",
|
|
" category_counts[item.category]+=1\n",
|
|
"\n",
|
|
"# Extract categories and their counts for plotting\n",
|
|
"categories = category_counts.keys()\n",
|
|
"counts = [category_counts[category] for category in categories]\n",
|
|
"\n",
|
|
"# Create and display a bar chart showing the count of items per category\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.bar(categories, counts, color=\"goldenrod\")\n",
|
|
"plt.title('How many in each category')\n",
|
|
"plt.xlabel('Categories')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"\n",
|
|
"# Rotate x-axis labels for better readability\n",
|
|
"plt.xticks(rotation=30, ha='right')\n",
|
|
"\n",
|
|
"# Add value labels on top of each bar for clarity\n",
|
|
"for i, v in enumerate(counts):\n",
|
|
" plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
|
|
"\n",
|
|
"# Display the chart\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "6oRa8rI6pGb0"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Create a dictionary where keys are rounded prices and values are lists of items with that price\n",
|
|
"# This is done to group items by price for sampling\n",
|
|
"slots = defaultdict(list)\n",
|
|
"for item in items:\n",
|
|
" slots[round(item.price)].append(item)"
|
|
],
|
|
"metadata": {
|
|
"id": "7mT5ZubkpJ06"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
"# Create a curated sample dataset with a more even distribution of prices and reduced bias towards 'Automotive' category\n",
"# Items with price >= $240 are included entirely\n",
"# For prices < $240, if the number of items is <= 1200, all are included\n",
"# If the number of items > 1200, a weighted random sample of 1200 items is taken,\n",
"# giving non-Automotive items higher weight\n",
"\n",
"# Set random seeds for reproducibility\n",
"np.random.seed(42)\n",
"random.seed(42)\n",
"sample = []\n",
"for price in range(1, 1000):\n",
"    slot = slots[price]\n",
"    if price >= 240 or len(slot) <= 1200:\n",
"        # High price points and sparsely populated slots are kept in full\n",
"        sample.extend(slot)\n",
"        continue\n",
"    # Assign weights: 1 for 'Automotive', 5 for other categories\n",
"    weights = np.array([1 if item.category == 'Automotive' else 5 for item in slot])\n",
"    # Normalize weights so they sum to 1, as required by the `p` argument below\n",
"    weights = weights / np.sum(weights)\n",
"    # Randomly select 1200 distinct indices based on weights\n",
"    selected_indices = np.random.choice(len(slot), size=1200, replace=False, p=weights)\n",
"    # Add the items corresponding to the chosen indices\n",
"    sample.extend([slot[idx] for idx in selected_indices])\n",
"\n",
"# Print the total number of items in the curated sample\n",
"print(f\"There are {len(sample):,} items in the sample\")"
|
|
],
|
|
"metadata": {
|
|
"id": "qHJdXynopMBp"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Extract prices from the curated sample\n",
|
|
"prices = [float(item.price) for item in sample]\n",
|
|
"# Create and display a histogram of prices for the sample dataset\n",
|
|
"# This helps visualize the effect of the sampling process on the price distribution\n",
|
|
"plt.figure(figsize=(15, 10))\n",
|
|
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
|
|
"plt.xlabel('Price ($)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "gtBkOdPGpOou"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Count the occurrences of each category in the curated sample\n",
|
|
"category_counts = Counter()\n",
|
|
"for item in sample:\n",
|
|
" category_counts[item.category]+=1\n",
|
|
"\n",
|
|
"# Extract categories and their counts for plotting\n",
|
|
"categories = category_counts.keys()\n",
|
|
"counts = [category_counts[category] for category in categories]\n",
|
|
"\n",
|
|
"# Create and display a bar chart showing the count of items per category in the sample\n",
|
|
"# This helps visualize the effect of weighted sampling on category distribution\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.bar(categories, counts, color=\"lightgreen\")\n",
|
|
"\n",
|
|
"# Customize the chart\n",
|
|
"plt.title('How many in each category')\n",
|
|
"plt.xlabel('Categories')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"\n",
|
|
"# Rotate x-axis labels for better readability\n",
|
|
"plt.xticks(rotation=30, ha='right')\n",
|
|
"\n",
|
|
"# Add value labels on top of each bar for clarity\n",
|
|
"for i, v in enumerate(counts):\n",
|
|
" plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
|
|
"\n",
|
|
"# Display the chart\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "-lYpt40jpTE1"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Create and display a pie chart showing the percentage distribution of items across categories in the sample\n",
|
|
"plt.figure(figsize=(12, 10))\n",
|
|
"plt.pie(counts, labels=categories, autopct='%1.0f%%', startangle=90)\n",
|
|
"\n",
|
|
"# Add a circle at the center to create a donut chart (optional)\n",
|
|
"centre_circle = plt.Circle((0,0), 0.70, fc='white')\n",
|
|
"fig = plt.gcf()\n",
|
|
"fig.gca().add_artist(centre_circle)\n",
|
|
"plt.title('Categories')\n",
|
|
"\n",
|
|
"# Equal aspect ratio ensures that pie is drawn as a circle\n",
|
|
"plt.axis('equal')\n",
|
|
"\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "5QPV4m2LpV3g"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
"cell_type": "markdown",
"source": [
"# Dataset Curated!\n",
"\n",
"We've crafted an excellent dataset.\n",
"\n",
"Let's do some final checks"
],
"metadata": {
"id": "3Xc2ZxjapZ0a"
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Extract prompt lengths (character counts) and prices from the curated sample\n",
|
|
"sizes = [len(item.prompt) for item in sample]\n",
|
|
"prices = [item.price for item in sample]\n",
|
|
"\n",
|
|
"# Create and display a scatter plot to visualize the relationship between prompt size and price\n",
|
|
"# This helps check for any simple correlation between the two\n",
|
|
"plt.figure(figsize=(15, 8))\n",
|
|
"plt.scatter(sizes, prices, s=0.2, color=\"red\")\n",
|
|
"\n",
|
|
"# Add labels and title\n",
|
|
"plt.xlabel('Size')\n",
|
|
"plt.ylabel('Price')\n",
|
|
"plt.title('Is there a simple correlation?')\n",
|
|
"\n",
|
|
"# Display the plot\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "VXYQkVarpceE"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
"# Helper for eyeballing how an item's prompt is tokenized\n",
"def report(item):\n",
"    \"\"\"Print the item's prompt, its last 10 token IDs, and those 10 tokens decoded.\"\"\"\n",
"    text = item.prompt\n",
"    tail_ids = Item.tokenizer.encode(text)[-10:]\n",
"    print(text)\n",
"    print(tail_ids)\n",
"    print(Item.tokenizer.batch_decode(tail_ids))"
|
|
],
|
|
"metadata": {
|
|
"id": "1BBJNDAKpgL_"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Use the report function to display information about a specific item in the sample\n",
|
|
"# This helps inspect the data and the tokenizer's behavior\n",
|
|
"report(sample[398000])"
|
|
],
|
|
"metadata": {
|
|
"id": "ZO2zF09wpiPp"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Observation\n",
|
|
"\n",
|
|
"An interesting thing about the Llama tokenizer is that every number from 1 to 999 gets mapped to 1 token, much as we saw with gpt-4o. The same is not true of qwen2, gemma and phi3, which all map individual digits to tokens. This does turn out to be a bit useful for our project, although it's not an essential requirement."
|
|
],
|
|
"metadata": {
|
|
"id": "GCkwmt_VpsaU"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Finally\n",
|
|
"\n",
|
|
"It's time to break down our data into a training, test and validation dataset.\n",
|
|
"\n",
|
|
"It's typical to use 5%-10% of your data for testing purposes, but actually we have far more than we need at this point. We'll take 400,000 points for training, and we'll reserve 2,000 for testing, although we won't use all of them.\n"
|
|
],
|
|
"metadata": {
|
|
"id": "dy6WGVAmpx0g"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
"# Shuffle a copy of the curated sample with a dedicated, seeded generator.\n",
"# random.Random(42).shuffle gives the same permutation as random.seed(42) followed by\n",
"# random.shuffle, but working on a copy keeps this cell idempotent: re-running it\n",
"# reproduces the same split instead of re-shuffling an already shuffled list.\n",
"shuffled = list(sample)\n",
"random.Random(42).shuffle(shuffled)\n",
"# Split the shuffled items into training (400,000 items) and testing (2,000 items) sets\n",
"train = shuffled[:400_000]\n",
"test = shuffled[400_000:402_000]\n",
"# Print the sizes of the training and testing sets\n",
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
|
|
],
|
|
"metadata": {
|
|
"id": "oY1ZSkW7p0VS"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Extract prices from the first 250 items of the test set\n",
|
|
"prices = [float(item.price) for item in test[:250]]\n",
|
|
"# Create and display a histogram of prices for the first 250 test items\n",
|
|
"# This provides a quick look at the price distribution in a small portion of the test data\n",
|
|
"plt.figure(figsize=(15, 6))\n",
|
|
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
|
|
"plt.xlabel('Price ($)')\n",
|
|
"plt.ylabel('Count')\n",
|
|
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
|
|
"plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "nLnRpUbtp17N"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Extract prompts from the training set\n",
|
|
"train_prompts = [item.prompt for item in train]\n",
|
|
"# Extract prices from the training set\n",
|
|
"train_prices = [item.price for item in train]\n",
|
|
"# Extract test prompts (using the test_prompt method) from the test set\n",
|
|
"test_prompts = [item.test_prompt() for item in test]\n",
|
|
"# Extract prices from the test set\n",
|
|
"test_prices = [item.price for item in test]"
|
|
],
|
|
"metadata": {
|
|
"id": "kpw1Y8IIp6kw"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Create Hugging Face Dataset objects from the training and testing data\n",
|
|
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
|
|
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
|
|
"# Create a DatasetDict containing the training and testing datasets\n",
|
|
"dataset = DatasetDict({\n",
|
|
" \"train\": train_dataset,\n",
|
|
" \"test\": test_dataset\n",
|
|
"})"
|
|
],
|
|
"metadata": {
|
|
"id": "WtEFiTlvp8hL"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Push the created DatasetDict to the Hugging Face Hub\n",
|
|
"# Replace \"aaron-official\" with your Hugging Face username\n",
|
|
"# The dataset will be named \"your-username/pricer-data\" and will be private\n",
|
|
"# HF_USER = \"aaron-official\" # Uncomment and replace with your HF username\n",
|
|
"# DATASET_NAME = f\"{HF_USER}/pricer-data\" # Uncomment\n",
|
|
"# dataset.push_to_hub(DATASET_NAME, private=True) # Uncomment to push to hub"
|
|
],
|
|
"metadata": {
|
|
"id": "sSnwZIxHp-VJ"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Pickle (serialize) the training and testing datasets and save them as files\n",
|
|
"# This allows for quick loading of the processed data in future sessions\n",
|
|
"with open('train.pkl', 'wb') as file:\n",
|
|
" pickle.dump(train, file)\n",
|
|
"\n",
|
|
"with open('test.pkl', 'wb') as file:\n",
|
|
" pickle.dump(test, file)"
|
|
],
|
|
"metadata": {
|
|
"id": "WRawIsrOqMQ-"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "bd72e246"
|
|
},
|
|
"source": [
|
|
"# Mount Google Drive to access files in your Drive\n",
|
|
"from google.colab import drive\n",
|
|
"drive.mount('/content/drive')"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "6fc5d915"
|
|
},
|
|
"source": [
|
"Once your Google Drive is mounted, you can copy the pickled files into it. The cells below save `train.pkl` and `test.pkl` to the top level of `My Drive`; to use a different folder, change `destination_path` to something like `/content/drive/My Drive/your_folder_name/train.pkl`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "f319129b"
|
|
},
|
|
"source": [
|
|
"# Import the shutil module for file operations\n",
|
|
"import shutil\n",
|
|
"\n",
|
|
"# Define the destination path in Google Drive and the source path of the pickled training data\n",
|
|
"# Replace 'My Drive/your_folder_name' with your desired folder path in Google Drive\n",
|
|
"destination_path = '/content/drive/My Drive/train.pkl'\n",
|
|
"source_path = '/content/train.pkl'\n",
|
|
"\n",
|
|
"# Copy the pickled training data file from the Colab environment to Google Drive\n",
|
|
"shutil.copyfile(source_path, destination_path)\n",
|
|
"\n",
|
|
"# Print a confirmation message\n",
|
|
"print(f\"Copied {source_path} to {destination_path}\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "d23d6cf0"
|
|
},
|
|
"source": [
|
|
"# Import the shutil module for file operations\n",
|
|
"import shutil\n",
|
|
"\n",
|
|
"# Define the destination path in Google Drive and the source path of the pickled test data\n",
|
|
"# Replace 'My Drive/your_folder_name' with your desired folder path in Google Drive\n",
|
|
"destination_path = '/content/drive/My Drive/test.pkl'\n",
|
|
"source_path = '/content/test.pkl'\n",
|
|
"\n",
|
|
"# Copy the pickled test data file from the Colab environment to Google Drive\n",
|
|
"shutil.copyfile(source_path, destination_path)\n",
|
|
"\n",
|
|
"# Print a confirmation message\n",
|
|
"print(f\"Copied {source_path} to {destination_path}\")"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |