Generating captions with ViT and GPT2 using 🤗 Transformers - Part 2

pytorch
huggingface
Using Encoder Decoder models in HF to combine vision and text
Author

Sachin Abeywardana

Published

January 26, 2022

Introduction

If you’ve read my previous post, note that this one is slightly different. Here we will focus on how we read in the data, some training tricks we used, logging, and finally how we pushed the model up to 🤗. The training code can be found on Kaggle.

Data

While we could use the ViTFeatureExtractor directly from HF, it doesn’t allow you to do any augmentations. Digging into ViTFeatureExtractor, all it does is:

1. normalize the pixel values to lie between 0 and 1 (by dividing by 255), and
2. subtract a ‘mean’ value of 0.5 and divide by a ‘standard deviation’ value of 0.5.

Given this, we can apply the following augmentations to our training data:

from torchvision import transforms

# assumed to match what ViTFeatureExtractor does: 224x224 inputs, mean/std of 0.5
IMAGE_SIZE = (224, 224)
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)

tfms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.RandomRotation(degrees=20)], p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)

This states that 50% of the time we flip the image horizontally, and 10% of the time we also apply a random rotation of up to ±20 degrees.
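To show how these transforms plug into reading the data, below is a minimal sketch of a caption dataset. The CaptionDataset class, the (image_path, caption) sample structure and the tokenizer settings are illustrative assumptions, not the exact code from the notebook.

from PIL import Image
from torch.utils.data import Dataset
from transformers import GPT2TokenizerFast


class CaptionDataset(Dataset):
    # hypothetical dataset yielding (pixel_values, input_ids) pairs

    def __init__(self, samples, tfms, max_length=32):
        # samples: list of (image_path, caption) tuples
        self.samples = samples
        self.tfms = tfms
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token  # GPT2 has no pad token
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, caption = self.samples[idx]
        # apply the augmentations defined above to the raw image
        pixel_values = self.tfms(Image.open(image_path).convert("RGB"))
        tokens = self.tokenizer(
            caption,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        return pixel_values, tokens["input_ids"].squeeze(0)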

Training Module

My previous post talked about the actual model and loss. The main difference here is that I used the VisionEncoderDecoderModel along with its built-in generate function, instead of the EncoderDecoderModel for which I had to implement my own generate function. Here I want to focus on how we trained the model and logged results.
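For reference, here is a minimal sketch of how the two pretrained models are wired together and how generate can then be called; the exact checkpoints and generation settings are assumptions rather than the notebook’s exact code.

from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel

# ViT encoder + GPT2 decoder; the cross attention weights are added and
# initialised randomly by from_encoder_decoder_pretrained
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"  # assumed checkpoints
)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT2 has no pad token by default

# generate() needs to know how to start and pad decoder sequences
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# pixel_values: a (batch, 3, 224, 224) tensor produced by the transforms above
# generated_ids = model.generate(pixel_values, max_length=32, num_beams=4)
# captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)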

Freezing/Unfreezing the model

When initialising an encoder-decoder model, the cross attention weights are initialised randomly and do not mean anything yet. We could simply use a low learning rate and train the whole model; however, I opted to freeze the already-trained parts for the first epoch. Kudos to Jeremy Howard and Rachel Thomas for this trick, which I learnt in their fastai course.

When freezing a model it is not enough to simply set requires_grad=False on the parameters; you also need to make sure every frozen submodule is put into .eval() mode. Otherwise the running statistics of normalization layers such as batch norm will continue to update, and any dropout layers will keep dropping out inputs. By inspecting the model we were able to see that we needed to unfreeze any layer that had crossattention in it, along with the ln_cross_attn layers, which are the layer normalizations associated with the cross attention.

# freeze everything first: eval mode and no gradients
model.eval()
for p in model.parameters():
    p.requires_grad = False

# only allow training of the cross attention parameters
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True

Logging metrics/ results/ weights

Logging the loss was the simple part. In PyTorch Lightning you simply call self.log(name="Training loss", value=loss, on_step=True, on_epoch=True) within one of the steps. However, I also wanted to log images alongside their generated text. I only did this for a single batch in my validation step, which allowed me to compare results across epochs:

if batch_idx == 0:
    images = [wandb.Image(transforms.ToPILImage()(descale(image))) for image in images]
    data = list(map(list, zip(images, actual_sentences, generated_sentences)))
    columns = ["Images", "Actual Sentence", "Generated Sentence"]
    table = wandb.Table(data=data, columns=columns)
    self.logger.experiment.log({f"epoch {self.current_epoch} results": table})
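The descale call above undoes the Normalize transform so the logged images render correctly. Its definition isn’t shown in the post, but a minimal sketch, assuming the MEAN and STD constants from the Data section, would be:

import torch


def descale(image: torch.Tensor) -> torch.Tensor:
    # invert Normalize: x * std + mean, then clamp into [0, 1] for display
    mean = torch.tensor(MEAN).view(-1, 1, 1)
    std = torch.tensor(STD).view(-1, 1, 1)
    return (image.detach().cpu() * std + mean).clamp(0, 1)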

Here are the results of the last epoch:

We can (and should) log the gradients too. Since there are many parameter groups, we will only log the ones with cross attention. We can do this using the wandb.Histogram function.

def on_after_backward(self):
    if self.trainer.global_step % 50 == 0:  # only log every 50 steps to keep the logs small
        for name, param in self.model.named_parameters():
            if "crossattention" in name and "norm" not in name and param.requires_grad:
                self.logger.experiment.log(
                    {f"{name}_grad": wandb.Histogram(param.grad.detach().cpu())}
                )
                self.logger.experiment.log(
                    {f"{name}": wandb.Histogram(param.detach().cpu())}
                )

The histograms over time can be seen below. Note how large the gradient distribution is until the end of epoch 1, where we unfreeze everything. Just to be cautious, though, I reduced the learning rate of everything by a factor of 10 at the end of the first epoch.
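The unfreezing step itself isn’t shown above, but a hypothetical sketch of doing both things at the end of the first epoch (in practice this would sit alongside the checkpointing hook shown below) could look like this:

def on_train_epoch_end(self, *args):
    if self.current_epoch == 0:
        # unfreeze everything: back to train mode and trainable parameters
        self.model.train()
        for p in self.model.parameters():
            p.requires_grad = True
        # drop the learning rate of every parameter group by a factor of 10
        for group in self.trainer.optimizers[0].param_groups:
            group["lr"] /= 10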

While training I was a bit worried that the Kaggle kernel would die, since there is a time limit. Therefore I decided to log the weights each epoch. This was achieved by doing the following in PyTorch Lightning:

def on_train_epoch_end(self, *args):
    # save the weights locally, then upload them to W&B as a model artifact
    model_name = f"model-epoch-{self.current_epoch}.ckpt"
    model_path = f"/kaggle/working/{model_name}"
    torch.save(self.model.state_dict(), model_path)
    self.logger.experiment.log_artifact(
        artifact_or_path=model_path,
        name=model_name,
        type="model",
    )
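If the kernel did die, the saved weights could then be pulled back out of W&B and reloaded. Here is a rough sketch, where the project name and the artifact/epoch names are placeholders, not values from the original run:

import torch
import wandb

# placeholder project and artifact names; substitute your own run's values
run = wandb.init(project="vit-gpt2-captioning")
artifact = run.use_artifact("model-epoch-4.ckpt:latest", type="model")
checkpoint_dir = artifact.download()
state_dict = torch.load(f"{checkpoint_dir}/model-epoch-4.ckpt", map_location="cpu")
model.load_state_dict(state_dict)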

Pushing to the HuggingFace Hub

This was the easy part. Go to https://huggingface.co/ and register for an account, then create a token with write access at https://huggingface.co/settings/token. Once you have this you can simply do model.push_to_hub("my-awesome-model", access_token=token), where token is the string you generated.
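Once pushed, the model can be loaded straight back down by anyone. Assuming the repo name above and a placeholder username:

from transformers import VisionEncoderDecoderModel

# "your-username" is a placeholder for whatever account the model was pushed to
model = VisionEncoderDecoderModel.from_pretrained("your-username/my-awesome-model")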

Shameless Self Promotion

If you enjoyed the tutorial, buy my course (usually 90% off).