from datetime import datetime  # Measure how long loading takes
from tqdm import tqdm  # Shows a progress bar while processing data
from datasets import load_dataset  # Load a dataset from the Hugging Face Hub
from concurrent.futures import ProcessPoolExecutor  # For parallel processing across CPU cores
from items import Item  # The Item class that wraps each datapoint

CHUNK_SIZE = 1000   # Process the dataset in chunks of 1000 datapoints at a time (for efficiency)
MIN_PRICE = 0.5     # Discard datapoints priced below this
MAX_PRICE = 999.49  # Discard datapoints priced above this
WORKERS = 4         # Default number of parallel worker processes


class ItemLoader:

    def __init__(self, name):
        """
        Initialize the loader with a dataset category name.
        """
        self.name = name     # Store the category name
        self.dataset = None  # Placeholder for the dataset (loaded later in load())

    def process_chunk(self, chunk):
        """
        Convert a chunk of datapoints into valid Item objects.
        """
        batch = []  # Initialize the list to hold valid items

        # Loop through each datapoint in the chunk
        for datapoint in chunk:
            try:
                # Extract the price from the datapoint
                price_str = datapoint['price']
                if price_str:
                    price = float(price_str)

                    # Check if the price is within the valid range
                    if MIN_PRICE <= price <= MAX_PRICE:
                        item = Item(datapoint, price)

                        # Keep only valid items
                        if item.include:
                            batch.append(item)
            except ValueError:
                continue  # Skip datapoints with an invalid price format

        return batch  # Return the list of valid items
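
    # Note: the 'price' field is treated as a string that may be empty or
    # non-numeric, hence the truthiness check and the float() conversion
    # guarded by ValueError in process_chunk above.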

    def load_in_parallel(self, workers):
        """
        Split the dataset into chunks and process them in parallel.
        """
        results = []
        size = len(self.dataset)
        chunk_count = (size + CHUNK_SIZE - 1) // CHUNK_SIZE  # Ceiling division, so an exact multiple of CHUNK_SIZE isn't overcounted

        # Build the chunks directly here (no separate function)
        chunks = [
            self.dataset.select(range(i, min(i + CHUNK_SIZE, size)))
            for i in range(0, size, CHUNK_SIZE)
        ]

        # Process chunks in parallel using multiple CPU cores
        with ProcessPoolExecutor(max_workers=workers) as pool:
            for batch in tqdm(pool.map(self.process_chunk, chunks), total=chunk_count):
                results.extend(batch)

        # Add the category name to each result
        for result in results:
            result.category = self.name

        return results
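
    # Note: ProcessPoolExecutor (used above) pickles the bound method
    # self.process_chunk and each chunk, then ships them to separate worker
    # processes. Unlike threads, this sidesteps the GIL for the CPU-bound
    # parsing work, but it means ItemLoader and Item must be picklable.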

    def load(self, workers=WORKERS):
        """
        Load and process the dataset, returning valid items.
        """
        # Record the start time and print a loading message
        start = datetime.now()
        print(f"Loading dataset {self.name}", flush=True)

        # Load the dataset from Hugging Face (based on the category name)
        self.dataset = load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_meta_{self.name}",
            split="full",
            trust_remote_code=True
        )

        # Process the dataset in parallel and collect the valid items
        results = self.load_in_parallel(workers)

        # Record the end time and print a summary
        finish = datetime.now()
        print(
            f"Completed {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins",
            flush=True
        )

        # Return the list of valid items
        return results
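

# Example usage: a minimal sketch, not part of the original module. It assumes
# the "Appliances" category exists as a raw_meta_* config of
# McAuley-Lab/Amazon-Reviews-2023 (any category from the dataset card works the
# same way). The __main__ guard matters: ProcessPoolExecutor may re-import this
# module in its worker processes (spawn start method), so unguarded top-level
# code could run in every worker.
if __name__ == "__main__":
    loader = ItemLoader("Appliances")
    items = loader.load(workers=WORKERS)
    print(f"Loaded {len(items):,} items priced between ${MIN_PRICE} and ${MAX_PRICE}")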