Collate function tutorial

Published June 5, 2021

Suppose we have the following hypothetical dataset.

import random

from torch.utils.data import DataLoader

class Dataset:
    """Toy map-style dataset: 32 samples of (string, random label in 0..3)."""

    def __len__(self):
        return 32
    
    def __getitem__(self, idx):
        return f"hello {idx}", random.randint(0, 3)
    
rand_ds = Dataset()
rand_dl = DataLoader(rand_ds, batch_size=4)

If we print the first batch, notice that the first element is just a tuple of strings, while the second has automagically been converted into a tensor (the labels are random, so your values will differ).

next(iter(rand_dl))
[('hello 0', 'hello 1', 'hello 2', 'hello 3'), tensor([0, 1, 1, 0])]
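Under the hood, when no collate_fn is passed, the DataLoader calls PyTorch's default_collate on the list of samples in each batch. A minimal sketch calling it directly on two hand-written samples (assuming PyTorch 1.11 or newer, where default_collate is exported from torch.utils.data) reproduces the structure above: the (string, int) pairs are transposed, the strings are left as a tuple, and the integers are stacked into an int64 tensor.

```python
import torch
from torch.utils.data import default_collate  # public API since PyTorch 1.11

# Two samples shaped like the ones our Dataset returns: (text, label).
batch = [("hello 0", 2), ("hello 1", 0)]

# This is what DataLoader does for each batch when collate_fn is not given.
collated = default_collate(batch)
print(collated)  # [('hello 0', 'hello 1'), tensor([2, 0])]
```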

The Collate Function

With a collate function we can convert these strings to tensors as well. This keeps data preprocessing out of the model code, which makes for cleaner code. In my case it also led to a slightly faster run time per epoch, although I'm not entirely sure why.

The following code takes a list of length batch_size, where each element is a string and its corresponding label. It then passes the strings through the Hugging Face tokenizer, which converts them into numerical token IDs. More importantly, note that you now have to convert y to a torch.LongTensor yourself, as otherwise it would remain a tuple of ints. This is an extra step that PyTorch's default collate function was taking care of for you.

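from typing import Dict, List, Tuple

import torch
from transformers import AutoTokenizer

class CollateFn:
    def __init__(self):
        # Load the tokenizer once, not on every batch.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(
        self, batch: List[Tuple[str, int]]
    ) -> Tuple[Dict[str, List[List[int]]], torch.LongTensor]:
        # batch is a list of (text, label) pairs; unzip it into two tuples.
        x, y = zip(*batch)
        # Tokenize all strings at once; the labels must be converted by hand.
        return self.tokenizer(list(x)), torch.LongTensor(y)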

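We can pass an instance of the above class to our DataLoader as collate_fn, which gives the following results. Note that the token IDs are still nested Python lists, not tensors: the tokenizer only returns tensors when called with return_tensors="pt" (together with padding=True if the strings tokenize to different lengths).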

collate_fn = CollateFn()
rand_dl = DataLoader(rand_ds, batch_size=4, collate_fn=collate_fn)
next(iter(rand_dl))
({'input_ids': [[101, 19082, 121, 102], [101, 19082, 122, 102], [101, 19082, 123, 102], [101, 19082, 124, 102]], 'token_type_ids': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]},
 tensor([2, 1, 3, 1]))
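The BERT tokenizer needs a model download, so here is a minimal, self-contained sketch of a collate function that pads variable-length sequences and returns real tensors. The character-level "tokenizer" (each character's Unicode code point as its ID, 0 as padding) and the names CharCollateFn and pairs are hypothetical stand-ins, not part of the original code.

```python
from typing import Dict, List, Tuple

import torch
from torch.utils.data import DataLoader

class CharCollateFn:
    """Hypothetical collate_fn: a character-level stand-in for a real tokenizer.

    Each character is mapped to its Unicode code point; pad_id (0) marks padding.
    """

    def __init__(self, pad_id: int = 0):
        self.pad_id = pad_id

    def __call__(
        self, batch: List[Tuple[str, int]]
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        texts, labels = zip(*batch)
        ids = [[ord(ch) for ch in text] for text in texts]
        max_len = max(len(seq) for seq in ids)

        # Pre-fill with padding, then copy each sequence into its own row.
        input_ids = torch.full((len(ids), max_len), self.pad_id, dtype=torch.long)
        attention_mask = torch.zeros((len(ids), max_len), dtype=torch.long)
        for row, seq in enumerate(ids):
            input_ids[row, : len(seq)] = torch.tensor(seq, dtype=torch.long)
            attention_mask[row, : len(seq)] = 1  # 1 = real token, 0 = padding

        # As before, the labels must be converted by hand.
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        return inputs, torch.tensor(labels, dtype=torch.long)


# A plain list of (text, label) pairs also works as a map-style dataset.
pairs = [("hi", 0), ("hello", 1), ("hey", 2)]
loader = DataLoader(pairs, batch_size=3, collate_fn=CharCollateFn())
inputs, labels = next(iter(loader))
print(inputs["input_ids"].shape)  # torch.Size([3, 5])
print(inputs["attention_mask"])
```

Padding to the longest sequence in each batch, rather than to a fixed length, keeps every batch as small as possible; this is the same idea as calling the Hugging Face tokenizer with padding=True.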

Summary

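  • A collate function turns a list of individual samples into a batch, which makes it a natural place for preprocessing such as tokenization.
  • Once you pass your own collate_fn, the default collation no longer runs, so you need to transform every output (e.g. the labels), not just the one you care about.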
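If you only want to customise one field, you don't have to re-implement the rest yourself. A minimal sketch, assuming PyTorch 1.11 or newer (where default_collate is public), that handles the string field and hands every remaining field to default_collate. The three-field samples, the function name text_length_collate, and the "string length" transformation are hypothetical, chosen only so the example runs without a tokenizer.

```python
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader, default_collate  # public since PyTorch 1.11

def text_length_collate(batch: List[Tuple[str, int, float]]):
    """Hypothetical collate_fn: transforms the strings itself and delegates
    every other field to PyTorch's default_collate."""
    texts = [sample[0] for sample in batch]
    # Stand-in "preprocessing" for the strings: their character lengths.
    lengths = torch.tensor([len(t) for t in texts], dtype=torch.long)
    # On a list of tuples, default_collate returns one batched tensor per field.
    others = default_collate([sample[1:] for sample in batch])
    return (lengths, *others)


samples = [("hi", 0, 0.5), ("hello", 1, 1.0), ("hey", 2, 2.0)]
loader = DataLoader(samples, batch_size=3, collate_fn=text_length_collate)
lengths, labels, weights = next(iter(loader))
```

Here the integer labels become an int64 tensor and the float weights a float64 tensor, exactly as they would without a custom collate_fn.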
