PyTorch prefetch, or rather the lack of it
Introduction
I had an issue at work where the question was whether I should stream the data from an S3 bucket via the Dataset
class, or download it first and simply read it from disk. I was hoping that increasing prefetch_factor
in the dataloader would speed up streaming from S3, and possibly even make it a viable alternative to downloading. For anyone who has not come across this flag: each dataloader worker fetches prefetch_factor
batches in advance while the GPU is busy, the idea being that there is little to no downtime when the next batch is needed.
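To make the flag concrete, here is a minimal sketch of where it sits; the dummy dataset and the exact numbers are placeholders rather than the setup from my experiments:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# A dummy dataset just to show where the flag goes.
dummy = TensorDataset(torch.randn(999, 3, 224, 224))

# Each of the 2 workers keeps up to 8 batches prepared ahead of time,
# so the loop consuming `dl` should rarely have to wait for data.
dl = DataLoader(dummy, batch_size=32, num_workers=2, prefetch_factor=8)
```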
Streaming via Dataloaders
In order to stream data instead of opening it from disk (as is done in many tutorials), the dataset class was set up as follows:
import requests
from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset

class Data(Dataset):
    def __init__(self, prefix, transform):
        self.prefix = "https://aft-vbi-pds.s3.amazonaws.com/bin-images"
        self.transform = transform

    def __len__(self):
        return 999

    def __getitem__(self, i):
        # fetch the image over HTTP instead of opening it from disk
        response = requests.get(self.prefix + f"/{i+1}.jpg")
        img = Image.open(BytesIO(response.content))
        return self.transform(img)
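The transform and dataloader themselves are not shown in the post; wiring the class up would look roughly like this, where the resize-and-to-tensor transform is my assumption about suitable preprocessing for resnet18:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resnet18-friendly input size (assumed)
    transforms.ToTensor(),
])

ds = Data(prefix=None, transform=transform)  # prefix is hard-coded inside the class
dl = DataLoader(ds, batch_size=32, num_workers=2, prefetch_factor=8, pin_memory=True)
```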
As shown in the experiments done in this Kaggle kernel, the prefetch_factor
flag did not speed things up in a meaningful way. The results are summarised below. For each iteration the following code snippet was run, where model is simply a resnet18.
with torch.inference_mode():
    for img_batch in tqdm(dl):
        out = model(img_batch.to(device))
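For completeness, the model and device assumed by that loop can be set up along these lines (a sketch; details such as pretrained weights are not specified in the original kernel):

```python
import torch
from torchvision.models import resnet18
from tqdm import tqdm  # used in the loop above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Weights are irrelevant for a pure throughput test, so a randomly
# initialised resnet18 is fine here.
model = resnet18().to(device).eval()
```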
| Settings | Time Elapsed |
|---|---|
| num_workers=2 | 04:02 |
| num_workers=2, prefetch_factor=8 | 03:57 |
| num_workers=8 | 01:01 |
| num_workers=8, prefetch_factor=8 | 01:01 |
All other parameters, such as batch_size=32
and pin_memory=True
, were held constant across all iterations.
Note that the reason we had 2 workers in the baseline is that this is the number given by multiprocessing.cpu_count()
. However, going past that number and setting it to 8 in the later iterations gave the following ugly warning, repeated many times:
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__
Downloading then reading in
As the title suggests, this was by far the fastest way to go. However, downloading can itself take a long time, which would cancel out the advantage over streaming through the PyTorch dataloader. The trick here is to use multiprocessing for the download as well. In the following case it took 55 seconds to download the data.
import multiprocessing as mp
import requests
from tqdm import tqdm

# PREFIX is the same S3 URL prefix as above; DATA is a local output folder,
# both defined earlier in the kernel.
def download_file(i):
    image_url = PREFIX + f"/{i+1}.jpg"
    img_data = requests.get(image_url).content
    with open(DATA + f"{i+1}.jpg", "wb") as handler:
        handler.write(img_data)
    return DATA + f"{i+1}.jpg"

with mp.Pool(8) as pool:
    file_paths = list(tqdm(pool.imap(download_file, range(999)), total=999))
Note how I set the number of workers / threads (I confess I don’t know the difference) to 8, which is 4x greater than mp.cpu_count().
Using a simple Dataset
class where we do Image.open
to read each downloaded image, and setting num_workers=mp.cpu_count()
(2 cores), we were able to run through the data in 6 seconds. Setting prefetch_factor=4
in this scenario actually slowed the dataloader down slightly, to 7 seconds.
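That disk-reading dataset isn't shown in the post either; a sketch of what it could look like, reusing file_paths from the download step and the same assumed transform, with a hypothetical class name:

```python
import multiprocessing as mp

from PIL import Image
from torch.utils.data import DataLoader, Dataset


class LocalData(Dataset):  # hypothetical name, not from the original kernel
    def __init__(self, file_paths, transform):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, i):
        img = Image.open(self.file_paths[i])  # read from local disk instead of S3
        return self.transform(img)


dl = DataLoader(LocalData(file_paths, transform),
                batch_size=32, num_workers=mp.cpu_count(), pin_memory=True)
```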
Conclusion
If only because of the ugly warnings, I would say that downloading and then reading in is the safest and fastest way to go. In a scenario where you do not have access to that much disk space, you would need to design a download-evaluate-delete-repeat cycle.
The disclaimer here is that in either case I had to go beyond the available cores to make things fast. I’m not sure whether this is safe, and it would be great if someone who understands threads / processes could comment on the safety of doing this.
Shameless Self Promotion
If you enjoyed the tutorial, buy me a coffee, or better yet buy my course (usually 90% off).