Collate function tutorial
Suppose we have the following hypothetical dataset.
import random

from torch.utils.data import DataLoader

class Dataset:
    def __init__(self):
        super().__init__()

    def __len__(self):
        return 32

    def __getitem__(self, idx):
        # Each sample is a (text, label) pair with a random label in 0..3.
        return f"hello {idx}", random.randint(0, 3)

rand_ds = Dataset()
rand_dl = DataLoader(rand_ds, batch_size=4)
Printing out the first batch, we can see that the first element is just a tuple of strings, while the second has automagically been converted into a tensor.
next(iter(rand_dl))
[('hello 0', 'hello 1', 'hello 2', 'hello 3'), tensor([0, 1, 1, 0])]
The Collate Function
With the collate function we can convert these strings to numerical token IDs as well. This leads to cleaner code, in that data preprocessing is kept away from model code. In my case it actually led to a slightly faster run time per epoch, but I’m not entirely sure why.
The following code takes in a list whose length is the batch size, where each element is a string and its corresponding label. It then passes the strings through the Hugging Face tokenizer, which converts them into numerical values. More importantly, note that you now have to convert y to torch.LongTensor yourself, as otherwise it would remain a tuple. This is an extra step that PyTorch was previously taking care of for you internally.
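from typing import List, Tuple

import torch
from transformers import AutoTokenizer, BatchEncoding

class CollateFn:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(
        self, batch: List[Tuple[str, int]]
    ) -> Tuple[BatchEncoding, torch.LongTensor]:
        x, y = zip(*batch)
        # The tokenizer fields stay plain lists of ints; the labels become a tensor.
        return self.tokenizer(list(x)), torch.LongTensor(y)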
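The zip(*batch) call is what regroups the batch: the DataLoader hands the collate function one (text, label) pair per sample, and zip(*batch) turns that into one tuple of texts and one tuple of labels. A small pure-Python sketch of just that step, using made-up sample values:

```python
from typing import List, Tuple

# A batch as the DataLoader hands it to collate_fn: one (text, label) pair per sample.
# The values here are made up for illustration.
batch: List[Tuple[str, int]] = [
    ("hello 0", 2),
    ("hello 1", 1),
    ("hello 2", 3),
    ("hello 3", 1),
]

# zip(*batch) regroups the pairs column-wise: all texts first, then all labels.
x, y = zip(*batch)

print(x)  # ('hello 0', 'hello 1', 'hello 2', 'hello 3')
print(y)  # (2, 1, 3, 1)
```

Both x and y come out as plain tuples, which is why the labels still need an explicit conversion to a tensor afterwards.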
We can add an instance of the above class to our dataloader, which leads us to the following results:
collate_fn = CollateFn()
rand_dl = DataLoader(rand_ds, batch_size=4, collate_fn=collate_fn)
next(iter(rand_dl))
({'input_ids': [[101, 19082, 121, 102], [101, 19082, 122, 102], [101, 19082, 123, 102], [101, 19082, 124, 102]], 'token_type_ids': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]},
tensor([2, 1, 3, 1]))
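In the batch above, all four strings happen to tokenize to the same length. When lengths differ, the collate function is also a natural place to pad the batch into a single tensor. Here is a minimal sketch of that idea. To keep it runnable without downloading a pretrained model, it swaps the Hugging Face tokenizer for a hypothetical character-level one; ToyCollateFn, its vocabulary, and the sample data are all assumptions made for illustration, not part of the tutorial's code.

```python
from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class ToyCollateFn:
    """Hypothetical stand-in for the tokenizer-based collate function."""

    def __init__(self) -> None:
        # Toy character vocabulary, built on the fly; id 0 is reserved for padding.
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str) -> torch.Tensor:
        # Give each unseen character the next free id, starting at 1.
        ids = [self.vocab.setdefault(ch, len(self.vocab) + 1) for ch in text]
        return torch.tensor(ids, dtype=torch.long)

    def __call__(
        self, batch: List[Tuple[str, int]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, y = zip(*batch)
        encoded = [self.encode(text) for text in x]
        # Pad every sequence with 0 up to the longest one in this batch.
        input_ids = pad_sequence(encoded, batch_first=True, padding_value=0)
        return input_ids, torch.tensor(y, dtype=torch.long)


# A plain list works as a map-style dataset for this small example.
samples = [("hello 0", 1), ("hello 10", 3), ("hi", 0), ("hello 3", 2)]
toy_dl = DataLoader(samples, batch_size=4, collate_fn=ToyCollateFn())
input_ids, labels = next(iter(toy_dl))

print(input_ids.shape)  # torch.Size([4, 8])
print(labels)           # tensor([1, 3, 0, 2])
```

The longest string, "hello 10", has 8 characters, so every row is padded to length 8. With the real tokenizer, a comparable effect should be achievable by passing padding=True and return_tensors="pt" when calling it.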
Summary
- Collate functions turn a list of samples into a batch, which makes them a convenient place to transform data.
- A custom collate function has to handle every output of the dataset, not just the one you want to transform.