Creating a Caption Model from Scratch
Introduction
Before we dive into the details, I want to give you a heads up that this was just an experiment. The results leave a lot to be desired, and you can see for yourself by checking out the results section. However, if you stick around, I’ll share some interesting techniques that I learned along the way.
You can find the code for this blog post in this Kaggle kernel.
Pretrained Models
To create our captioning model, I started with the hypothesis that we could take a powerful image model and plug it into a pretrained LLM. That’s exactly what I did, but with a twist: I purposely chose a non-transformer image model. This meant that we couldn’t use HF transformers’ EncoderDecoderModels.
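To make this concrete, the two pretrained pieces might be loaded as below. The exact checkpoints (resnet50 via timm and google/flan-t5-base) are my assumption for illustration, not necessarily what the kernel uses; the point is that the two feature dimensions generally differ, which the projection layers described later have to bridge.

import timm
import transformers

# Assumed checkpoints for illustration -- the kaggle kernel may use different ones.
image_model = timm.create_model("resnet50", pretrained=True)
language_model = transformers.T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

print(image_model.num_features)   # CNN feature dimension, e.g. 2048 for resnet50
print(language_model.model_dim)   # Flan-T5 hidden size, e.g. 768 for flan-t5-base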
For the LLM, we used the decoder of Flan-T5. We chose this model because its decoder already contains cross-attention layers, so we didn’t need to introduce any ourselves. Cross-attention (CA) layers are how we connect two transformer models. In a nutshell, if the decoder states are of shape batch_size x decoder_seq_len x dim and the encoder states are of shape batch_size x encoder_seq_len x dim, the CA layer creates a batch_size x decoder_seq_len x encoder_seq_len attention matrix. This attention matrix mixes the encoder sequence into the decoder’s representation space, allowing “instructions” from the image model to mix in with the LLM decoder to create a caption.
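As a rough illustration of those shapes (a generic scaled-dot-product sketch, not T5's actual multi-head implementation):

import torch

batch_size, decoder_seq_len, encoder_seq_len, dim = 2, 7, 49, 768
decoder_states = torch.randn(batch_size, decoder_seq_len, dim)  # queries come from the decoder
encoder_states = torch.randn(batch_size, encoder_seq_len, dim)  # keys/values come from the image "encoder"

# One attention weight per decoder/encoder token pair.
attention = torch.softmax(decoder_states @ encoder_states.transpose(1, 2) / dim ** 0.5, dim=-1)
print(attention.shape)  # torch.Size([2, 7, 49]) -> batch_size x decoder_seq_len x encoder_seq_len

# Encoder information mixed into each decoder position.
mixed = attention @ encoder_states
print(mixed.shape)      # torch.Size([2, 7, 768])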
You might be wondering how we can get a CNN architecture to mimic a transformer output. To do this, we need to do some dimension-transposing gymnastics, but einops makes things easier. Here’s an example of how we use einops to rearrange the image features:
image_features = self.image_model.forward_features(images)
image_features = einops.rearrange(image_features, "bs num_features w h -> bs (w h) num_features")
encoder_outputs = self.projector(image_features)
forward_features is used by all timm models to get the features just before the (1000-class) classification layer. In the example above, this gives a batch_size x num_features x width x height shape. num_features in this case is analogous to the dimensionality in a transformer model, while width and height can be thought of as the tokens. We combine the width and height tokens and move the num_features dimension to the end. Finally, there’s no guarantee that this dimension size is the same as that of the LLM, so we have a few linear layers to project it to the correct size.
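The projection_layers helper isn't shown in the post. A minimal sketch of what it might look like is below, together with the shapes involved; my assumption is a stack of Linear + GELU blocks where the first layer bridges the CNN feature size to the LLM dimension. The real kernel may build it differently.

import einops
import torch
from torch import nn


def projection_layers(num_features: int, model_dim: int, num_projections: int) -> list[nn.Module]:
    # Hypothetical helper: the first layer maps the CNN feature size to the LLM
    # dimension, the remaining layers keep that dimension fixed.
    layers: list[nn.Module] = [nn.Linear(num_features, model_dim), nn.GELU()]
    for _ in range(num_projections - 1):
        layers += [nn.Linear(model_dim, model_dim), nn.GELU()]
    return layers


# Illustrative shapes: a ResNet-style model at 224x224 input gives a 7x7 feature map.
image_features = torch.randn(4, 2048, 7, 7)  # bs x num_features x w x h
image_tokens = einops.rearrange(image_features, "bs num_features w h -> bs (w h) num_features")
print(image_tokens.shape)  # torch.Size([4, 49, 2048])

projector = nn.Sequential(*projection_layers(2048, 768, 5))
print(projector(image_tokens).shape)  # torch.Size([4, 49, 768])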
In order to train Flan-T5, we generally need to provide four things: input_ids and attention_mask on the encoder side, and decoder_input_ids and decoder_attention_mask on the decoder side. However, we can bypass the encoder inputs by providing encoder_outputs. In our case, this is the output from the image model.
The final model can be seen below.

class Model(nn.Module):
    def __init__(self, image_model, language_model, num_projections: int):
        super().__init__()
        language_model.encoder.block = None
        self.image_model = image_model
        self.language_model = language_model
        self.projector = nn.Sequential(
            *projection_layers(image_model.num_features, language_model.model_dim, num_projections)
        )
        self.start_token_id = language_model.config.decoder_start_token_id

    def project_image_features(self, images: torch.Tensor):
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs num_features w h -> bs (w h) num_features")
        encoder_outputs = self.projector(image_features)
        return transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=encoder_outputs,
        )

    def forward(self, images: torch.Tensor, tokenized_text: dict[str, torch.Tensor]):
        encoder_outputs = self.project_image_features(images)
        return self.language_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=tokenized_text["input_ids"],
            decoder_attention_mask=tokenized_text["attention_mask"],
        )

    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        encoder_outputs = self.project_image_features(images)
        return self.language_model.generate(
            encoder_outputs=encoder_outputs,
            **generator_kwargs
        )
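As a usage sketch (assuming model, a batch of preprocessed images, and the matching T5 tokenizer are already set up, and with made-up generation settings), captions can then be produced with:

generator_kwargs = {"max_length": 32, "num_beams": 4}  # assumed settings
with torch.no_grad():
    generated_ids = model.generate(images, generator_kwargs)
captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)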
Training
There are a few compulsory things to do when training, and also a few tricks that I used but am not sure helped. I will go through them all.

- We need to prepend a starting token to the decoder input ids. This is because we want to predict the next token, and we need a token to start from. It also means prepending a 1 to the decoder attention mask (a sketch of this preparation is shown after the loss function below).
- The loss function predicts the next token. This means that if we have the caption “A dog is running”, we want to predict “dog” when we see “A”. Therefore, we need to shift the decoder input ids by one: the target token ids are the same as the input token ids, shifted by one. loss_fn is simply nn.CrossEntropyLoss():
def calculate_loss_fn(loss_fn, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    shift_logits = logits[:, :-1, :].contiguous()  # (batch_size, seq_len - 1, dim)
    shift_labels = labels[:, 1:].contiguous()      # (batch_size, seq_len - 1)
    return loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
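And a sketch of the decoder-input preparation from the first bullet; prepare_decoder_inputs is a name I made up, the real collate code lives in the kaggle kernel:

import torch

def prepare_decoder_inputs(tokenized_text: dict[str, torch.Tensor], start_token_id: int) -> dict[str, torch.Tensor]:
    input_ids = tokenized_text["input_ids"]            # (batch_size, seq_len)
    attention_mask = tokenized_text["attention_mask"]  # (batch_size, seq_len)
    start = torch.full((input_ids.size(0), 1), start_token_id, dtype=input_ids.dtype)
    return {
        "input_ids": torch.cat([start, input_ids], dim=1),                             # prepend the start token
        "attention_mask": torch.cat([torch.ones_like(start), attention_mask], dim=1),  # prepend a 1
    }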
- Keep in mind that the projector in the above model is the only part that is untrained. Usually we would freeze the pretrained models; however, in my experiments I found that it was better to train the entire model. I do not, however, train any of the token embeddings. This is because the COCO dataset only covers a limited number of tokens, and I do not want the embeddings to overfit to it. If you look at the LightningModule you can see (in __init__) how I leave it as an option to freeze the image encoder and the LLM.
- Finally, in configure_optimizers I set the learning rate of the LLM to a quarter of the projection layer’s, and the image model’s to half of the projection layer’s. This is because of my assumption that the LLM has seen far more training data than the image model, and I do not want to lose that information. Oh, and I have made it a habit to use the OneCycleLR scheduler; I find it more effective than using an optimizer on its own.
def configure_optimizers(self) -> torch.optim.Optimizer:
    params = [
        {"params": self.model.language_model.decoder.block.parameters(), "lr": self.lr / 4},
        {"params": self.model.language_model.decoder.final_layer_norm.parameters(), "lr": self.lr / 4},
        {"params": self.model.image_model.parameters(), "lr": self.lr / 2},
        {"params": self.model.projector.parameters(), "lr": self.lr},
    ]
    optimizer = torch.optim.Adam(params)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[param_group["lr"] for param_group in optimizer.param_groups],
        total_steps=self.trainer.estimated_stepping_batches,
    )
    return [optimizer], [scheduler]
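One caveat, which is my reading of the PyTorch Lightning defaults rather than something stated in the post: returned as a bare list, the scheduler is stepped once per epoch, while OneCycleLR with total_steps=self.trainer.estimated_stepping_batches is sized for one step per training batch. If per-batch stepping is the intent, the scheduler can be returned as a config dict instead:

scheduler_config = {"scheduler": scheduler, "interval": "step"}  # step after every training batch
return [optimizer], [scheduler_config]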
The updated results are also shown below.
Update (10/09/2023): Training with resetting cross attention layers
I recently realised that I was training the whole decoder with the same learning rate. However, the cross attention layers were pretrained to expect incoming signals from a language encoder, which is not what they receive here. Therefore, I reset all the weights in the cross attention layers (the EncDecAttention modules) instead of keeping the pretrained weights. This is done as shown below:
for block in self.model.language_model.decoder.block:
    cross_attention_layer = block.layer[1].EncDecAttention
    nn.init.xavier_uniform_(cross_attention_layer.q.weight)
    nn.init.xavier_uniform_(cross_attention_layer.k.weight)
    nn.init.xavier_uniform_(cross_attention_layer.v.weight)
    nn.init.xavier_uniform_(cross_attention_layer.o.weight)
Results
Let’s take a look at the results of our captioning model. In the first set of results, the second and fifth captions are correct, but the second to last caption is a complete miss. This was when we had only five linear layers in the projection layer.
When we increased the number of linear layers to 12, the validation loss improved (as expected). In the second set of results, the third, fourth, and fifth captions are correct, but the first caption is a complete miss, talking about tennis. This is likely due to the model overfitting, as there are quite a few tennis images in the COCO dataset.
Update (10/09/2023): The following results are from when the cross attention layer weights were reset. Again, while not perfect, it seems to work better than the two models demonstrated above.
Conclusion
This model will clearly not be winning any awards, but nonetheless it was a fun experiment. I hope you learnt something new. If you have any questions or comments, please leave them below.