Creating a Caption Model from Scratch

Categories: LLM, multimodal-models

Published: August 30, 2023

[Image: ConvNext - LLM meme]

Introduction

Before we dive into the details, I want to give you a heads up that this was just an experiment. The results leave a lot to be desired, and you can see for yourself by checking out the results section. However, if you stick around, I’ll share some interesting techniques that I learned along the way.

You can find the code for this blog post in this Kaggle kernel.

Pretrained Models

To create our captioning model, I started with the hypothesis that we could take a powerful image model and plug it into a pretrained LLM. That’s exactly what I did, but with a twist: I purposely chose a non-transformer image model. This meant that we couldn’t use the ready-made EncoderDecoderModel classes from Hugging Face Transformers.

For the LLM, we used the decoder of Flan-T5. We chose this model because its decoder already contains cross-attention layers, so we didn’t need to introduce any. Cross-attention (CA) layers are how we connect two transformer models. In a nutshell, if the decoder hidden states have shape batch_size x decoder_seq_len x dim and the encoder outputs have shape batch_size x encoder_seq_len x dim, the CA layer computes a batch_size x decoder_seq_len x encoder_seq_len attention matrix (one per attention head): the queries come from the decoder, while the keys and values come from the encoder. Multiplying this matrix by the encoder values gives every decoder position a weighted mix of the encoder sequence, allowing “instructions” from the image model to blend into the LLM decoder to create a caption.
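To make these shapes concrete, here is a minimal single-head cross-attention sketch in PyTorch with made-up sizes (49 image tokens and a 7-token caption). It uses the standard scaled dot-product form; T5’s own attention is multi-head and skips the 1/sqrt(dim) scaling, so treat this as an illustration of the shapes rather than Flan-T5’s exact computation.

```python
import torch
import torch.nn as nn

# Illustrative sizes only: 49 image tokens (e.g. a 7 x 7 feature map) and a 7-token caption
batch_size, decoder_seq_len, encoder_seq_len, dim = 2, 7, 49, 64
decoder_hidden = torch.randn(batch_size, decoder_seq_len, dim)  # caption side: provides the queries
encoder_hidden = torch.randn(batch_size, encoder_seq_len, dim)  # image side: provides keys and values

# One attention head with its own query/key/value projections
q_proj = nn.Linear(dim, dim, bias=False)
k_proj = nn.Linear(dim, dim, bias=False)
v_proj = nn.Linear(dim, dim, bias=False)
q, k, v = q_proj(decoder_hidden), k_proj(encoder_hidden), v_proj(encoder_hidden)

# Standard scaled dot-product attention (T5 itself skips the 1 / sqrt(dim) factor)
scores = q @ k.transpose(1, 2) / dim ** 0.5  # (batch_size, decoder_seq_len, encoder_seq_len)
attention = scores.softmax(dim=-1)  # each caption position's weights over the image tokens sum to 1
mixed = attention @ v  # (batch_size, decoder_seq_len, dim): image information per caption position
print(attention.shape, mixed.shape)
```

Note that the output keeps the decoder’s sequence length: the image tokens only decide what gets mixed into each caption position.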

You might be wondering how we can get a CNN architecture to mimic a transformer output. This takes some dimension-transposing gymnastics, which einops makes much easier. Here’s how we use einops to rearrange the image features:

    image_features = self.image_model.forward_features(images)
    image_features = einops.rearrange(image_features, "bs num_features w h -> bs (w h) num_features")
    encoder_outputs = self.projector(image_features)

forward_features is available on all timm models and returns the features just before the pooling and (1000-class) classification head. For the CNN used here, this is a feature map of shape batch_size x num_features x height x width. num_features plays the role of the hidden dimension in a transformer model, while each spatial position can be thought of as a token. We flatten the two spatial axes into a single token axis and move the num_features dimension to the end. Finally, there’s no guarantee that num_features matches the LLM’s hidden size, so we use a few linear layers to project it to the correct size.
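As a sanity check on the rearrange above, the sketch below uses plain PyTorch with illustrative sizes (768 channels, a 7 x 7 feature map and a hypothetical LLM width of 512), with a single nn.Linear standing in for the projector. flatten(2).transpose(1, 2) produces the same tensor as the einops pattern.

```python
import torch
import torch.nn as nn

# Stand-in for forward_features output; all sizes are illustrative
batch_size, num_features, height, width = 2, 768, 7, 7
image_features = torch.randn(batch_size, num_features, height, width)

# Same result as einops.rearrange(x, "bs num_features w h -> bs (w h) num_features"):
# merge the two spatial axes into one token axis, then move the channels last
tokens = image_features.flatten(2).transpose(1, 2)  # (bs, height * width, num_features)

projector = nn.Linear(num_features, 512)  # 512 is a hypothetical LLM hidden size
encoder_outputs = projector(tokens)  # (bs, 49, 512): one "token" per spatial position
print(tokens.shape, encoder_outputs.shape)
```

Spatial position (row 0, column 1) becomes token 1, so the tokens are read in row-major order across the feature map.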

In order to train Flan-T5, we generally need to provide four things: input_ids and attention_mask on the encoder side, and decoder_input_ids and decoder_attention_mask on the decoder side. However, we can bypass the encoder inputs by passing encoder_outputs directly, which skips the T5 encoder entirely. In our case, these are the projected outputs of the image model.
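A minimal sketch of feeding encoder_outputs straight into the decoder, assuming a recent version of Hugging Face Transformers. To avoid downloading Flan-T5, it builds a tiny, randomly initialised T5ForConditionalGeneration from a T5Config with arbitrary sizes, and random tensors stand in for the projected image features.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# Tiny random T5 so nothing is downloaded; every size here is arbitrary
config = T5Config(vocab_size=100, d_model=32, d_kv=8, d_ff=64, num_layers=2, num_heads=4)
model = T5ForConditionalGeneration(config).eval()

# Stand-in for the projected image features: (batch, num_image_tokens, d_model)
image_tokens = torch.randn(2, 49, config.d_model)

# Made-up caption ids; 0 is T5's pad id and also its decoder start token
decoder_input_ids = torch.tensor([[0, 5, 6, 7], [0, 8, 9, 0]])
decoder_attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

with torch.no_grad():
    out = model(
        encoder_outputs=BaseModelOutput(last_hidden_state=image_tokens),  # T5 encoder never runs
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )
print(out.logits.shape)  # torch.Size([2, 4, 100])
```

Because the encoder is never called, its weights are dead weight, which is why the model below can discard the encoder blocks.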

The final model can be seen below.

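import einops
import torch
import torch.nn as nn
from typing import Union
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions


class Model(nn.Module):
    def __init__(self, image_model, language_model, num_projections: int):
        super().__init__()
        # The image model replaces the T5 encoder, so drop the encoder blocks to save memory
        language_model.encoder.block = None
        self.image_model = image_model  # timm image model
        self.language_model = language_model  # Flan-T5 (T5ForConditionalGeneration)
        # projection_layers (not shown in this post) builds the linear layers that map
        # image_model.num_features to the LLM hidden size, language_model.model_dim
        self.projector = nn.Sequential(
            *projection_layers(image_model.num_features, language_model.model_dim, num_projections)
        )
        # Token prepended to every caption (0, i.e. the pad token, for T5)
        self.start_token_id = language_model.config.decoder_start_token_id

    def project_image_features(self, images: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
        # (bs, num_features, height, width) feature map from before the classification head
        image_features = self.image_model.forward_features(images)
        # Flatten the spatial grid into a token axis: (bs, num_image_tokens, num_features)
        image_features = einops.rearrange(image_features, "bs num_features w h -> bs (w h) num_features")
        encoder_outputs = self.projector(image_features)  # (bs, num_image_tokens, model_dim)
        # Wrap the tensor as if the T5 encoder had produced it
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=encoder_outputs,
        )

    def forward(self, images: torch.Tensor, tokenized_text: dict[str, torch.Tensor]):
        encoder_outputs = self.project_image_features(images)
        # tokenized_text is expected to already start with the decoder start token
        return self.language_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=tokenized_text["input_ids"],
            decoder_attention_mask=tokenized_text["attention_mask"],
        )

    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        encoder_outputs = self.project_image_features(images)
        # generate() starts from decoder_start_token_id and attends to encoder_outputs at every step
        return self.language_model.generate(
            encoder_outputs=encoder_outputs,
            **generator_kwargs,
        )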
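The projection_layers helper is not shown in this post. Purely as a hypothetical reconstruction, it could look like the sketch below: the first linear layer maps num_features to the LLM width, and the remaining num_projections - 1 layers keep that width. The GELU activations between the layers are my assumption, not necessarily what the kernel does.

```python
import torch
import torch.nn as nn


def projection_layers(in_dim: int, out_dim: int, num_projections: int) -> list[nn.Module]:
    # Hypothetical helper: num_projections linear layers with GELU in between (assumed activation)
    layers: list[nn.Module] = [nn.Linear(in_dim, out_dim)]  # num_features -> LLM width
    for _ in range(num_projections - 1):
        layers += [nn.GELU(), nn.Linear(out_dim, out_dim)]  # stay at the LLM width
    return layers


# Illustrative sizes: 768 image channels projected to a 512-wide LLM with five linear layers
projector = nn.Sequential(*projection_layers(768, 512, num_projections=5))
x = torch.randn(2, 49, 768)  # (batch, image tokens, num_features)
print(projector(x).shape)  # torch.Size([2, 49, 512])
```

Some nonlinearity between the layers matters: without it, a stack of linear layers collapses into a single linear map, so adding more of them would not add capacity.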

Training

There are a few things we must do when training, as well as a few tricks I used without being sure they helped. I will go through them all.
  • We need to prepend a start token to the decoder input ids. The model predicts each token from the tokens before it, so it needs a start token to predict the first word of the caption from. This also means prepending a 1 to the decoder attention mask.
  • The loss function trains the model to predict the next token. If the caption is “A dog is running”, we want to predict “dog” when we see “A”. Therefore, the targets are the decoder input ids shifted by one: the input token ids and the output token ids are the same sequence, offset by one position. loss_fn is simply nn.CrossEntropyLoss().

import torch
import torch.nn as nn

def calculate_loss_fn(loss_fn: nn.Module, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits[:, t] predicts labels[:, t + 1]: drop the last prediction and the first (start) token
    shift_logits = logits[:, :-1, :].contiguous()  # (batch_size, seq_len - 1, vocab_size)
    shift_labels = labels[:, 1:].contiguous()  # (batch_size, seq_len - 1)
    return loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  • Keep in mind that the projector in the above model is the only part that starts from random weights. Usually we would freeze the pretrained models; however, in my experiments I found that it was better to train the entire model. I do not, however, train any of the token embeddings: the COCO captions only cover a limited set of tokens, and I do not want the embeddings to overfit to the COCO dataset. If you look at the LightningModule in the Kaggle kernel, you can see (in __init__) how I leave it as an option to freeze the image encoder and the LLM.
  • Finally, in configure_optimizers I set the learning rate of the LLM to a quarter of the projector’s, and that of the image model to half of the projector’s. My assumption is that the LLM has seen far more training data than the image model, and I do not want to lose that information. Oh, and I have made it a habit to use the OneCycleLR scheduler; I find it more effective than using the optimizer on its own.
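    def configure_optimizers(self) -> tuple[list, list]:
        # Token embeddings (and any other parameters not listed here) are never passed to
        # the optimizer, so they keep their pretrained values
        params = [
            {"params": self.model.language_model.decoder.block.parameters(), "lr": self.lr / 4},
            {"params": self.model.language_model.decoder.final_layer_norm.parameters(), "lr": self.lr / 4},
            {"params": self.model.image_model.parameters(), "lr": self.lr / 2},
            {"params": self.model.projector.parameters(), "lr": self.lr},
        ]
        optimizer = torch.optim.Adam(params)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=[param_group["lr"] for param_group in optimizer.param_groups],  # one peak per group
            total_steps=self.trainer.estimated_stepping_batches,
        )
        # total_steps counts batches, so Lightning must step the scheduler every batch;
        # returning the bare scheduler would step it only once per epoch
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]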
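The first two training steps above can be sketched on a toy example. prepend_start_token is a hypothetical helper (the post does not show how the start token is added), the token ids are made up, and 0 is used as the start token because that is T5’s decoder_start_token_id. A hand-built “perfect” set of logits, where position t puts all its mass on the token at t + 1, gets a loss of essentially zero once calculate_loss_fn applies the shift.

```python
import torch
import torch.nn as nn


def prepend_start_token(input_ids: torch.Tensor, attention_mask: torch.Tensor, start_token_id: int):
    # Hypothetical helper: put the start token in front of every caption and mark it as attended
    batch_size = input_ids.size(0)
    start = torch.full((batch_size, 1), start_token_id, dtype=input_ids.dtype)
    ones = torch.ones((batch_size, 1), dtype=attention_mask.dtype)
    return torch.cat([start, input_ids], dim=1), torch.cat([ones, attention_mask], dim=1)


def calculate_loss_fn(loss_fn, logits, labels):
    # Same shift as in the post: logits at position t are scored against the label at t + 1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


# Made-up token ids standing in for a tokenized "A dog is running"
input_ids = torch.tensor([[5, 6, 7, 8]])
attention_mask = torch.ones_like(input_ids)
decoder_input_ids, decoder_attention_mask = prepend_start_token(input_ids, attention_mask, start_token_id=0)
print(decoder_input_ids.tolist())  # [[0, 5, 6, 7, 8]]
print(decoder_attention_mask.tolist())  # [[1, 1, 1, 1, 1]]

# Hand-built "perfect" logits: position t puts all its mass on the token at t + 1
vocab_size = 10
logits = torch.full((1, decoder_input_ids.size(1), vocab_size), -1e4)
for t in range(decoder_input_ids.size(1) - 1):
    logits[0, t, int(decoder_input_ids[0, t + 1])] = 1e4
loss = calculate_loss_fn(nn.CrossEntropyLoss(), logits, decoder_input_ids)
print(loss.item())  # essentially zero: every shifted target gets all the probability
```

With real padded batches, nn.CrossEntropyLoss also scores the padding positions unless their labels are replaced with -100, its default ignore_index.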
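To see what the per-group max_lr in configure_optimizers does, this sketch replaces the sub-models with hypothetical nn.Linear stand-ins and steps OneCycleLR once per batch for 100 batches. Each group follows the same one-cycle shape: it starts at its own max_lr divided by div_factor (25 by default), peaks at its own max_lr, and the 1:2:4 ratio between the groups holds throughout.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the LLM decoder, the image model and the projector
decoder, image_model, projector = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)
lr = 1e-3
optimizer = torch.optim.Adam([
    {"params": decoder.parameters(), "lr": lr / 4},
    {"params": image_model.parameters(), "lr": lr / 2},
    {"params": projector.parameters(), "lr": lr},
])
total_steps = 100  # stands in for trainer.estimated_stepping_batches
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=[group["lr"] for group in optimizer.param_groups],  # one peak per group
    total_steps=total_steps,
)

# Record every group's learning rate at each batch
history = [[group["lr"] for group in optimizer.param_groups]]
for _ in range(total_steps - 1):
    optimizer.step()  # no gradients here, so the weights do not change; only the schedule matters
    scheduler.step()  # once per batch, as OneCycleLR expects
    history.append([group["lr"] for group in optimizer.param_groups])

print(history[0])  # roughly each group's max_lr / 25
print(max(row[2] for row in history))  # the projector peaks at lr
```

The lower rates for the pretrained parts are therefore preserved at every point of the schedule, not just at the peak.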

The updated results are also shown in the Results section below.

Update (10/09/2023): Training with reset cross-attention layers

I recently realised that I was training the whole decoder with the same learning rate. However, the pretrained cross-attention layers expect their inputs to come from the Flan-T5 encoder, a language model, which is no longer the case. Therefore, instead of keeping the pretrained weights, I reset the weights of every cross-attention layer (the EncDecAttention module in each decoder block). This is done as shown below:

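# Runs inside the LightningModule, where self.model is the Model above and nn is torch.nn
for block in self.model.language_model.decoder.block:
    # Each T5 decoder block: layer[0] = self-attention, layer[1] = cross-attention, layer[2] = feed-forward
    cross_attention_layer = block.layer[1].EncDecAttention
    # Re-initialise the query, key, value and output projections (T5 attention has no biases)
    nn.init.xavier_uniform_(cross_attention_layer.q.weight)
    nn.init.xavier_uniform_(cross_attention_layer.k.weight)
    nn.init.xavier_uniform_(cross_attention_layer.v.weight)
    nn.init.xavier_uniform_(cross_attention_layer.o.weight)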
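To check which modules the reset touches, here is a self-contained version of the loop above run on a tiny, randomly initialised T5ForConditionalGeneration (arbitrary sizes, assuming a recent Hugging Face Transformers) instead of Flan-T5. It confirms that layer[1] of each decoder block is the cross-attention layer and that the self-attention weights are left untouched.

```python
import torch
import torch.nn as nn
from transformers import T5Config, T5ForConditionalGeneration

# Tiny random T5 standing in for Flan-T5; sizes are arbitrary
config = T5Config(vocab_size=100, d_model=32, d_kv=8, d_ff=64, num_layers=2, num_heads=4)
model = T5ForConditionalGeneration(config)

first_block = model.decoder.block[0]
self_attn_before = first_block.layer[0].SelfAttention.q.weight.detach().clone()
cross_attn_before = first_block.layer[1].EncDecAttention.q.weight.detach().clone()

for block in model.decoder.block:
    cross_attention_layer = block.layer[1].EncDecAttention  # T5LayerCrossAttention's attention module
    for projection in (cross_attention_layer.q, cross_attention_layer.k,
                       cross_attention_layer.v, cross_attention_layer.o):
        nn.init.xavier_uniform_(projection.weight)  # fresh weights, pretrained ones discarded

print(type(first_block.layer[1]).__name__)  # T5LayerCrossAttention
```

Only the cross-attention projections are re-drawn; the self-attention and feed-forward layers keep whatever weights they had.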

Results

Let’s take a look at the results of our captioning model. In the first set of results, the second and fifth captions are correct, but the second-to-last caption is a complete miss. These results come from a projector with only five linear layers.

When we increased the number of linear layers to 12, the validation loss improved (as expected). In the second set of results, the third, fourth, and fifth captions are correct, but the first caption, which talks about tennis, is a complete miss. This is likely due to overfitting, as there are quite a few tennis images in the COCO dataset.

Update (10/09/2023): The following results are from when the cross-attention layer weights were reset. Again, while not perfect, this model seems to work better than the two models shown above.

Conclusion

This model will clearly not be winning any awards, but nonetheless it was a fun experiment. I hope you learnt something new. If you have any questions or comments, please leave them below.