PyTorch prefetch or rather the lack of it

pytorch
How prefetch_factor did not help in streaming data
Author

Sachin Abeywardana

Published

February 13, 2022

Introduction

I had an issue at work where the question was whether I should stream the data from an S3 bucket via the Dataset class, or download it first and simply read it in from disk. I was hoping that increasing prefetch_factor in the dataloaders would increase the speed when streaming via S3, and possibly even be an alternative to downloading. For anyone who has not come across this flag, it tells each dataloader worker to fetch prefetch_factor batches ahead of time while the GPU is busy, the idea being that there is little to no downtime when the next batch is needed.
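For reference, prefetch_factor is just an argument to DataLoader, and it only takes effect when num_workers > 0. A minimal sketch of where it goes (the dataset here is a placeholder):

from torch.utils.data import DataLoader

# prefetch_factor batches are loaded in advance *per worker*,
# so it has no effect unless num_workers > 0 (the default value is 2).
dl = DataLoader(
    dataset,            # any torch.utils.data.Dataset instance
    batch_size=32,
    num_workers=2,
    prefetch_factor=8,
    pin_memory=True,
)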

Streaming via Dataloaders

In order to stream data instead of opening it from disk (as is done in many tutorials), the dataset class was set up as follows:

from io import BytesIO

import requests
from PIL import Image
from torch.utils.data import Dataset

class Data(Dataset):
    def __init__(self, prefix, transform):
        # e.g. "https://aft-vbi-pds.s3.amazonaws.com/bin-images"
        self.prefix = prefix
        self.transform = transform

    def __len__(self):
        return 999  # only the first 999 images in the bucket are used

    def __getitem__(self, i):
        # stream the JPEG straight from S3 instead of reading it from disk
        response = requests.get(self.prefix + f"/{i+1}.jpg")
        img = Image.open(BytesIO(response.content))
        return self.transform(img)

As shown in the experiments done in this Kaggle kernel, the prefetch_factor flag did not speed things up in a meaningful manner. The results are summarised below. For each iteration the following code snippet was run, where model is simply a resnet18.

with torch.inference_mode():
    for img_batch in tqdm(dl):
        out = model(img_batch.to(device))
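For completeness, here is roughly how model, device and dl were wired up for these runs. The resize transform is an assumption (any transform that yields fixed-size tensors will do), and the DataLoader flags were varied per row of the table below:

import torch
from torch.utils.data import DataLoader
from torchvision import models, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=True).to(device).eval()

# Assumed transform: resize so every image stacks into a fixed-size batch.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dl = DataLoader(
    Data("https://aft-vbi-pds.s3.amazonaws.com/bin-images", transform),
    batch_size=32,
    num_workers=2,      # varied between 2 and 8 across the rows below
    pin_memory=True,    # prefetch_factor=8 was added for the relevant rows
)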
Settings                              Time Elapsed
num_workers = 2                       04:02
num_workers = 2, prefetch_factor=8    03:57
num_workers = 8                       01:01
num_workers = 8, prefetch_factor=8    01:01

All other parameters, such as batch_size=32 and pin_memory=True, were held constant across all iterations.

Note that the reason we used 2 workers is that this is the number given by multiprocessing.cpu_count(). However, going past that number and setting num_workers=8 gave the following ugly (repeated) warning: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__

Downloading then reading in

As the section title suggests, this was the fastest approach. However, the download itself can take a long time, which could negate any speed-up over streaming in the PyTorch dataloaders. The trick here is to use multiprocessing for the download as well. In the following case it took 55 seconds to download the data.

import multiprocessing as mp
import os

import requests
from tqdm import tqdm

PREFIX = "https://aft-vbi-pds.s3.amazonaws.com/bin-images"
DATA = "images/"  # local directory for the downloads (assumed path)
os.makedirs(DATA, exist_ok=True)

def download_file(i):
    image_url = PREFIX + f"/{i+1}.jpg"
    img_data = requests.get(image_url).content
    with open(DATA + f"{i+1}.jpg", "wb") as handler:
        handler.write(img_data)

    return DATA + f"{i+1}.jpg"

with mp.Pool(8) as pool:
    file_paths = list(tqdm(pool.imap(download_file, range(999)), total=999))

Note how I set the number of worker processes (mp.Pool spawns processes rather than threads) to 8, which is 4x greater than mp.cpu_count().

Using a simple Dataset class where we do Image.open to get the image, and setting num_workers=mp.cpu_count() (2 cores), we were able to run through the data in 6 seconds. Setting prefetch_factor=4 in this scenario actually slowed down the dataloader slightly, to 7 seconds.
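A minimal sketch of such a disk-based dataset (DiskData is a hypothetical name, and file_paths and transform are the list returned by the download step and the transform defined earlier):

import multiprocessing as mp

from PIL import Image
from torch.utils.data import Dataset, DataLoader

class DiskData(Dataset):
    def __init__(self, file_paths, transform):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, i):
        # read the already-downloaded image straight from local disk
        img = Image.open(self.file_paths[i])
        return self.transform(img)

dl = DataLoader(DiskData(file_paths, transform), batch_size=32,
                num_workers=mp.cpu_count(), pin_memory=True)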

Conclusion

Simply due to the ugly warnings, I would say that downloading and then reading in is the safest and fastest way to go. In a scenario where you do not have access to that much disk space, you would need to design a download-evaluate-delete-repeat cycle, as sketched below.
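A rough sketch of what such a cycle could look like, reusing download_file, model, device, transform and the hypothetical DiskData class from above (the chunk size is an arbitrary assumption):

import os

CHUNK = 200  # number of images held on disk at any one time (assumption)

with torch.inference_mode(), mp.Pool(8) as pool:
    for start in range(0, 999, CHUNK):
        # download one chunk of images in parallel
        paths = pool.map(download_file, range(start, min(start + CHUNK, 999)))

        # evaluate on the chunk that is now sitting on disk
        dl = DataLoader(DiskData(paths, transform), batch_size=32,
                        num_workers=mp.cpu_count(), pin_memory=True)
        for img_batch in dl:
            out = model(img_batch.to(device))

        # delete the chunk to free disk space before the next download
        for p in paths:
            os.remove(p)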

The disclaimer here is that in either case I had to go beyond the available cores to make this fast. I'm not sure if this is safe, and it would be great if someone who understands threads/processes could comment on the safety of doing this.

Shameless Self Promotion

If you enjoyed the tutorial, buy me a coffee, or better yet, buy my course (usually 90% off).