Introduction

There has been a lot of hype about generating images from text. However, I had not seen much in the caption generation domain. This is obviously the easier of the two problems, and perhaps it has been mostly solved, but I thought I'd get my hands dirty trying to do it almost from scratch. Before we get going: HF (Hugging Face) does have VisionEncoderDecoderModel, which does exactly what we are doing today, but I wanted to build this mostly from scratch.

The Vision Transformer (ViT) was used to classify images on ImageNet, and GPT-2 is a language model that can be used to generate text. So the question is: can we combine the two? The answer is yes, thanks to HF's EncoderDecoderModel. In the original Attention Is All You Need paper, attention was the game changer. Not many people are aware, however, that the paper uses two kinds of attention: 1. self-attention, which most people are familiar with, and 2. cross-attention, which allows the decoder to retrieve information from the encoder.

[Figure: Cross Attention]
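To make the distinction concrete, here is a minimal sketch using PyTorch's generic nn.MultiheadAttention (not GPT-2's own attention code). In cross-attention the queries come from the decoder while the keys and values come from the encoder, so the two sequences can have different lengths. The sizes are illustrative assumptions (197 = 196 patches + [CLS] for a ViT-Base/16 at 224×224).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads = 64, 4
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

encoder_states = torch.randn(2, 197, d_model)  # e.g. ViT patch embeddings plus [CLS]
decoder_states = torch.randn(2, 12, d_model)   # hidden states for 12 caption tokens

# queries come from the decoder; keys and values come from the encoder
out, weights = cross_attn(query=decoder_states, key=encoder_states, value=encoder_states)
print(out.shape)      # torch.Size([2, 12, 64]): one output per caption token
print(weights.shape)  # torch.Size([2, 12, 197]): each token attends over all image patches
```

In self-attention, query, key and value would all be the same sequence; here every caption token gets a distribution over the 197 encoder positions.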

By default, GPT-2 does not come with pretrained cross-attention layers, since it is a decoder-only model. This paper by Google Research demonstrated that you can simply initialise these cross-attention layers randomly and train the system. This is exactly what we will do in this blog, using the COCO dataset. An executable version of this post can be found on Kaggle.

Data

The COCO dataset provides each image with 5 possible captions. We pick one of them at random each time an image is loaded, so the caption can change from epoch to epoch.
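The data loading code isn't shown in the post. As a minimal sketch, assuming the images are already loaded as tensors and each one has its list of captions, a Dataset can draw the caption inside __getitem__ so every epoch may see a different one. RandomCaptionDataset is a hypothetical name, and the captions below are dummy strings.

```python
import random
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset

class RandomCaptionDataset(Dataset):
    """Hypothetical sketch: one item per image, paired with one of its captions at random."""

    def __init__(self, images: List[torch.Tensor], captions: List[List[str]]):
        assert len(images) == len(captions), "need one caption list per image"
        self.images = images
        self.captions = captions

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        # a fresh draw on every access, so each epoch can pair the image with a different caption
        return self.images[idx], random.choice(self.captions[idx])

# usage with dummy data (real COCO images would be loaded and preprocessed instead)
images = [torch.zeros(3, 32, 32) for _ in range(2)]
captions = [["zebra caption 1", "zebra caption 2"], ["cat caption 1", "cat caption 2"]]
ds = RandomCaptionDataset(images, captions)
batch_images, batch_captions = next(iter(DataLoader(ds, batch_size=2)))
print(batch_images.shape)  # torch.Size([2, 3, 32, 32])
```

PyTorch's default collate function stacks the image tensors and groups the strings, which matches the (images, captions) batch that the training step below unpacks.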

print(caption)
transforms.ToPILImage()(descale(img))

Printed caption: A lone zebra grazing in some green grass.
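descale isn't defined in the post. A minimal guess, assuming the images were normalised per channel as x_norm = (x - mean) / std with mean = std = 0.5 (the default of HF's ViT feature extractor; adjust if your preprocessing differs), simply inverts that formula and clamps to [0, 1] so ToPILImage receives a valid float image.

```python
import torch

def descale(img: torch.Tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) -> torch.Tensor:
    """Undo per-channel normalisation on a (C, H, W) tensor: x = x_norm * std + mean."""
    mean_t = torch.tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0.0, 1.0)

normalised = torch.tensor([-1.0, 0.0, 1.0]).view(3, 1, 1)  # one pixel, three channels
print(descale(normalised).flatten().tolist())  # [0.0, 0.5, 1.0]
```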

GPT2 Tokenizer and Model

As mentioned earlier, we will use EncoderDecoderModel, which initialises the cross-attention layers for us and loads pretrained weights for the Vision Transformer and DistilGPT2. We use the distilled version only to keep training quick, and, as you will soon see, it is good enough.

The GPT-2 tokenizer needs a bit more setup than you may be used to from a BERT tokenizer. I copied the following tokenizer code from somewhere (sorry, I don't remember where), but the important bit is that a padding token has to be introduced, which I thought was strange: how would GPT-2 have been trained without padding?

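from transformers import EncoderDecoderModel, GPT2Tokenizer

# model: pretrained ViT encoder + DistilGPT2 decoder; the decoder's cross-attention
# layers are added and randomly initialised
# (VIT_MODEL and DISTIL_GPT2 hold the Hugging Face checkpoint names, defined elsewhere)
vit2gpt2 = EncoderDecoderModel.from_encoder_decoder_pretrained(VIT_MODEL, DISTIL_GPT2)

# tokenizer
# make GPT-2 wrap every caption as <BOS> caption <EOS>
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs

# patch the (slow, Python) GPT2Tokenizer class so every instance uses the method above
GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(DISTIL_GPT2)
# set pad_token to unk_token -> be careful here as unk_token_id == eos_token_id == bos_token_id
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token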
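GPT-2's tokenizer ships without a pad token (GPT-2 was most likely trained on fixed-length chunks of running text, which need no padding), hence the workaround above. Because pad, BOS and EOS then all share the id of <|endoftext|> (50256), padding has to be identified by the attention mask rather than by token id. The sketch below uses a hypothetical pad_captions helper to mimic what the patched tokenizer plus the label masking in the training step produce, assuming LABEL_MASK is -100, the target index PyTorch's cross-entropy ignores by default.

```python
from typing import List, Tuple

import torch

IGNORE_INDEX = -100  # assumed value of LABEL_MASK

def pad_captions(token_ids: List[List[int]], special_id: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: wrap each caption in BOS/EOS, right-pad, and build labels."""
    wrapped = [[special_id] + ids + [special_id] for ids in token_ids]
    max_len = max(len(seq) for seq in wrapped)
    input_ids = torch.full((len(wrapped), max_len), special_id, dtype=torch.long)
    attention_mask = torch.zeros((len(wrapped), max_len), dtype=torch.long)
    for row, seq in enumerate(wrapped):
        input_ids[row, : len(seq)] = torch.tensor(seq)
        attention_mask[row, : len(seq)] = 1
    labels = input_ids.clone()
    # mask padding via the attention mask; masking by token id would also hide the
    # real BOS/EOS tokens, since pad, BOS and EOS share one id here
    labels[attention_mask == 0] = IGNORE_INDEX
    return input_ids, attention_mask, labels

EOT = 50256  # GPT-2's <|endoftext|> id
input_ids, attention_mask, labels = pad_captions([[10, 11], [12]], EOT)
print(labels.tolist())  # [[50256, 10, 11, 50256], [50256, 12, 50256, -100]]
```

Only the single padded position is masked; the EOS that ends the shorter caption is still a training target, so the model learns when to stop.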

Nucleus Sampling

At the time of writing, EncoderDecoderModel does not seem to have the generate method that GPT-2 and other models use to generate text, hence the following code.

Choosing the next token (word) is not simply a matter of picking the most likely one. Taking the maximum at each step does not guarantee that the (log) likelihood of the entire sequence is maximised, so this greedy approach can lead to a sub-optimal answer. Beam search is an alternative: at each step you keep the k most likely partial sequences (beams), extend them to the end, and hope that one of the k beams contains the solution we are after.
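A toy example makes the greedy-versus-beam point concrete. The two-step "model" below is a made-up lookup table of next-token probabilities, not a real language model: greedy decoding commits to the likelier first token and ends up with a less likely sequence than a width-2 beam search finds.

```python
import math
from typing import Dict, List, Tuple

# made-up next-token probabilities, keyed by the prefix generated so far
TOY_MODEL: Dict[Tuple[str, ...], Dict[str, float]] = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.3, "y": 0.3, "z": 0.4},
    ("B",): {"x": 0.9, "y": 0.1},
}

def greedy(steps: int) -> Tuple[Tuple[str, ...], float]:
    seq: Tuple[str, ...] = ()
    logp = 0.0
    for _ in range(steps):
        token, p = max(TOY_MODEL[seq].items(), key=lambda kv: kv[1])  # best token right now
        seq += (token,)
        logp += math.log(p)
    return seq, logp

def beam_search(steps: int, k: int) -> Tuple[Tuple[str, ...], float]:
    beams: List[Tuple[Tuple[str, ...], float]] = [((), 0.0)]
    for _ in range(steps):
        # extend every beam by every token, then keep the k best partial sequences
        candidates = [
            (seq + (tok,), logp + math.log(p))
            for seq, logp in beams
            for tok, p in TOY_MODEL[seq].items()
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return beams[0]

g_seq, g_logp = greedy(2)
b_seq, b_logp = beam_search(2, k=2)
print(g_seq, round(math.exp(g_logp), 2))  # ('A', 'z') 0.24
print(b_seq, round(math.exp(b_logp), 2))  # ('B', 'x') 0.36
```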

In the code below we use a sampling-based method called nucleus (top-p) sampling, which has been shown to give better results and to reduce common pitfalls such as repetition when generating text. The algorithm is as follows:

  • Keep the smallest set of most probable tokens whose probabilities sum to at least a given threshold (top_p).
  • Set all other token probabilities to zero, and renormalize the probability distribution.
  • Sample from the resulting distribution.
  • In the code below, apart from the threshold on the most probable tokens, we also limit the number of candidate tokens (top_k), which we set to a large number (1000).
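from typing import Optional, Union

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(
    next_token_logits: torch.FloatTensor,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    device: Union[str, torch.device] = "cpu",
) -> torch.FloatTensor:
    """Turn (batch, vocab) logits into probabilities restricted to the top-k / top-p nucleus."""
    if top_k is None:
        top_k = next_token_logits.shape[-1]  # no top-k limit: consider the whole vocabulary
    if top_p is None:
        top_p = 1.0

    # probabilities of the top_k most likely tokens, sorted in descending order
    p, largest_p_idx = F.softmax(next_token_logits, dim=-1).topk(top_k, dim=-1)
    cumulative_p = p.cumsum(dim=-1)
    # per row, the first position where the cumulative probability reaches top_p
    threshold = torch.full((len(p), 1), top_p, dtype=cumulative_p.dtype, device=device)
    idx = torch.searchsorted(cumulative_p, threshold).clamp(max=top_k - 1).squeeze(-1)
    cutoffs = cumulative_p[torch.arange(len(cumulative_p), device=device), idx]
    # keep tokens up to and including the one that crosses the threshold, then renormalise
    censored_p = (cumulative_p <= cutoffs[:, None]) * p
    renormalized_p = censored_p / censored_p.sum(dim=-1, keepdim=True)

    # scatter the kept probabilities back to their positions in the vocabulary
    final_p = torch.zeros_like(next_token_logits)
    final_p.scatter_(-1, largest_p_idx, renormalized_p.to(final_p.dtype))

    return final_p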
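Because the vectorised version above is easy to get subtly wrong, a plain unbatched reference is handy for cross-checking it. This is a minimal sketch (nucleus_filter_1d is a hypothetical name) that keeps the smallest set of most probable tokens whose mass reaches top_p, renormalises, and samples.

```python
import torch

def nucleus_filter_1d(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Reference top-p filter for a single probability vector."""
    sorted_p, order = probs.sort(descending=True)
    cumulative = sorted_p.cumsum(dim=0)
    # tokens needed for the cumulative mass to reach top_p (at least one, at most all)
    n_keep = min(int((cumulative < top_p).sum().item()) + 1, probs.numel())
    kept = torch.zeros_like(probs)
    kept[order[:n_keep]] = sorted_p[:n_keep]
    return kept / kept.sum()

torch.manual_seed(0)
probs = torch.tensor([0.05, 0.5, 0.15, 0.3])
filtered = nucleus_filter_1d(probs, top_p=0.7)  # keeps 0.5 and 0.3, i.e. roughly [0, 0.625, 0, 0.375]
next_token = torch.multinomial(filtered, num_samples=1)  # only index 1 or 3 can be drawn
```

Feeding the same distribution (as logits, e.g. probs.log()) through top_k_top_p_filtering with top_p=0.7 should give matching probabilities, up to floating-point ties.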

To generate a caption we need (1) the image representation from the encoder (ViT) and (2) the tokens generated so far. Note that the first token is always the beginning-of-sentence token (<BOS>). We feed the generated tokens back in iteratively, either for a predefined length or until the end-of-sentence token is reached. Since we generate a whole batch at once, the code below ignores <EOS> and always runs for the full length.

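from typing import List

import torch
from tqdm.auto import tqdm

def generate_sentence_from_image(model, encoder_outputs, tokenizer, max_text_length: int, device) -> List[str]:
    batch_size = len(encoder_outputs.last_hidden_state)
    # every sequence in the batch starts with the <BOS> token
    generated_so_far = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in tqdm(range(max_text_length)):
            attention_mask = torch.ones_like(generated_so_far)
            # re-run the decoder on the whole prefix; the encoder output is reused at every step
            decoder_out = model(
                decoder_input_ids=generated_so_far,
                decoder_attention_mask=attention_mask,
                encoder_outputs=encoder_outputs,
            )

            # logits for the next token only (TOP_K and TOP_P are settings defined elsewhere)
            next_token_logits = decoder_out["logits"][:, -1, :]
            filtered_p = top_k_top_p_filtering(next_token_logits, top_k=TOP_K, top_p=TOP_P, device=device)
            next_token = torch.multinomial(filtered_p, num_samples=1)
            generated_so_far = torch.cat((generated_so_far, next_token), dim=1)

    # no early stop at <EOS>, so a decoded string may contain text after the first <|endoftext|>
    return [tokenizer.decode(coded_sentence) for coded_sentence in generated_so_far]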
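Since the loop above never stops at <EOS>, each row keeps sampling after its caption has ended. A minimal post-processing sketch (truncate_at_eos is a hypothetical helper, and the token ids are made up) drops the leading <BOS> and cuts each sequence at the first <EOS>; inside generate_sentence_from_image you could apply it to generated_so_far.tolist() before calling tokenizer.decode.

```python
from typing import List

def truncate_at_eos(token_ids: List[int], eos_token_id: int) -> List[int]:
    """Drop the leading BOS and cut the sequence at the first EOS after it."""
    body = token_ids[1:]  # position 0 is always BOS (the same id as EOS for GPT-2)
    if eos_token_id in body:
        body = body[: body.index(eos_token_id)]
    return body

EOS = 50256  # GPT-2's <|endoftext|> id, used here as BOS, EOS and PAD
batch = [
    [EOS, 32, 3797, 5586, 13, EOS, 318, 257],  # finished early, then kept sampling
    [EOS, 32, 8223, 13, 318],                  # never produced EOS
]
print([truncate_at_eos(seq, EOS) for seq in batch])
# [[32, 3797, 5586, 13], [32, 8223, 13, 318]]
```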

Training Module (PyTorch Lightning)

The full PyTorch Lightning module is below. There are a few things to note in the training step:

  1. Train only the cross-attention weights. This was a design decision based on the time available, not a necessity, but it meant I could train an epoch of 110,000 image-caption pairs in 30 minutes.
    for name, param in self.model.named_parameters():
        if "crossattention" not in name:
            param.requires_grad = False
    
  2. The loss is calculated for you by HF. If you want to understand exactly how HF calculates it, see this discussion that I authored.
encoder_outputs = self.model.encoder(pixel_values=images)
outputs = self.model(
    encoder_outputs=encoder_outputs,
    decoder_input_ids=tokenized_captions["input_ids"],
    decoder_attention_mask=tokenized_captions["attention_mask"],
    labels=labels,
    return_dict=True,
)

return outputs["loss"]
  3. Weights & Biases is amazing, and you should log all your experiments. The wandb.Table element in particular was a godsend.
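On point 2: a causal language-model loss compares the logits at position t with the token at position t + 1 and skips label positions set to -100. GPT2LMHeadModel does this shift internally when given labels. Be aware that, as far as I can tell, from transformers v4.12 onward EncoderDecoderModel computes the loss itself without shifting, and expects decoder_input_ids to be the labels shifted right (it builds them that way if you pass only labels). On such versions, passing identical decoder_input_ids and labels would train the decoder to copy its input, so it is worth checking which behaviour your installed version has. The toy tensors below only illustrate the two alignments.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # F.cross_entropy skips these targets (also its default)
vocab_size = 5

# toy caption: <BOS>=0, tokens 3 and 1, <EOS>=4, then one padded position (masked out)
labels = torch.tensor([[0, 3, 1, 4, IGNORE_INDEX]])
# toy logits that confidently predict the *next* token at every position
logits = F.one_hot(torch.tensor([[3, 1, 4, 4, 4]]), vocab_size).float() * 20.0

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # position t predicts token t + 1: drop the last logit and the first label
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=IGNORE_INDEX)

def unshifted_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # compares position t with token t itself, i.e. asks the model to copy its input
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=IGNORE_INDEX)

print(causal_lm_loss(logits, labels).item())  # close to 0
print(unshifted_loss(logits, labels).item())  # about 15: targets are misaligned
```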

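from typing import List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from torchvision import transforms

class LightningModule(pl.LightningModule):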
    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        lr: float,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.lr = lr
        
        for name, param in self.model.named_parameters():
            if "crossattention" not in name:
                param.requires_grad = False
        
    def common_step(self, batch: Tuple[torch.FloatTensor, List[str]]) -> torch.FloatTensor:
        images, captions = batch
        tokenized_captions = {
            k: v.to(self.device) for k, v in 
            self.tokenizer(
                captions,
                max_length=MAX_TEXT_LENGTH,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).items()
        }
        labels = tokenized_captions["input_ids"].clone()
        labels[tokenized_captions["attention_mask"]==0] = LABEL_MASK
        encoder_outputs = self.model.encoder(pixel_values=images)
        outputs = self.model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=tokenized_captions["input_ids"],
            decoder_attention_mask=tokenized_captions["attention_mask"],
            labels=labels,
            return_dict=True,
        )
        
        return outputs["loss"]
    
    def training_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(batch)
        self.log(name="Training loss", value=loss, on_step=True, on_epoch=True)
        
        return loss
        
    def validation_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int):
        loss = self.common_step(batch)
        self.log(name="Validation loss", value=loss, on_step=True, on_epoch=True)

        images, actual_sentences = batch
        
        if batch_idx == 0:
            encoder_outputs = self.model.encoder(pixel_values=images.to(self.device))
            generated_sentences = generate_sentence_from_image(
                self.model, 
                encoder_outputs, 
                self.tokenizer, 
                MAX_TEXT_LENGTH,
                self.device
            )
            images = [wandb.Image(transforms.ToPILImage()(descale(image))) for image in images]
            data = list(map(list, zip(images, actual_sentences, generated_sentences)))
            columns = ["Images", "Actual Sentence", "Generated Sentence"]
            table = wandb.Table(data=data, columns=columns)
            self.logger.experiment.log({f"epoch {self.current_epoch} results": table})
                        
    def on_after_backward(self):
        if self.trainer.global_step % 50 == 0:  # log histograms only every 50 steps to keep the logs small
            for name, param in self.model.named_parameters():
                if "weight" in name and "norm" not in name and param.requires_grad:
                    self.logger.experiment.log(
                        {f"{name}_grad": wandb.Histogram(param.grad.detach().cpu())}
                    )
                    self.logger.experiment.log(
                        {f"{name}": wandb.Histogram(param.detach().cpu())}
                    )

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
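It is worth checking which weights the name filter in __init__ actually leaves trainable. In Hugging Face's GPT2Block, the cross-attention sub-layer is called crossattention, but its freshly initialised LayerNorm is called ln_cross_attn, so the "crossattention" filter leaves that LayerNorm frozen at its initial values (check named_parameters() on your own model to confirm). The toy module below is hypothetical, not the real GPT2Block; it only mimics those attribute names.

```python
import torch.nn as nn

class ToyGPT2Block(nn.Module):
    """Hypothetical stand-in whose attribute names mirror Hugging Face's GPT2Block."""

    def __init__(self, d: int = 8):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d)
        self.attn = nn.Linear(d, d)            # stands in for pretrained self-attention
        self.ln_cross_attn = nn.LayerNorm(d)   # new, randomly initialised
        self.crossattention = nn.Linear(d, d)  # new, randomly initialised
        self.mlp = nn.Linear(d, d)

def freeze_except(model: nn.Module, keyword: str) -> None:
    # only parameter names containing `keyword` stay trainable
    for name, param in model.named_parameters():
        param.requires_grad = keyword in name

def trainable_names(model: nn.Module) -> list:
    return [name for name, p in model.named_parameters() if p.requires_grad]

block = ToyGPT2Block()
freeze_except(block, "crossattention")
only_cross_attn = trainable_names(block)
print(only_cross_attn)  # ['crossattention.weight', 'crossattention.bias']
freeze_except(block, "cross")
with_ln = trainable_names(block)
print(with_ln)
# ['ln_cross_attn.weight', 'ln_cross_attn.bias', 'crossattention.weight', 'crossattention.bias']
```

Matching on a broader keyword such as "cross" would also train the new LayerNorm; whether that helps here is untested.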
            

Results

The full set of results can be seen here, but here are some highlights:

Actual caption: A teddy bear and a metronome sitting on a table top.
Generated caption: <|endoftext|>A stuffed teddy bear is being laid down next to a table.<|endoftext|>

Actual caption: A white cat is sitting on a laptop.
Generated caption: <|endoftext|>A cat sitting on a stool with an empty laptop.<|endoftext|>

Actual caption: The cows are laying down in a straight row.
Generated caption: <|endoftext|>A horse is drinking milk and taking pictures of cattle inside.<|endoftext|>

At least the last one captured cattle; not sure about the shifty horse 🤔.

Gotchas and Potential Improvements

The biggest problem I had training this model was gradient explosion. With too high a learning rate (1e-3) the weights quickly blew up to infinity; with too small a one (1e-5) the generated text wasn't as good. It might have been worth unfreezing the rest of the weights after one epoch and continuing with a smaller learning rate. The following is a cross-section of the histograms of one set of weights. As you can see, the values go into the 1000s, which is concerning even though I managed to train the model successfully:

Perhaps adding some form of normalization would have fixed the gradient explosion, but unfortunately that isn't something I have control over, since the architecture is assembled for me by the HF EncoderDecoderModel API.
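The two-phase idea mentioned above (train cross-attention first, then unfreeze everything with a smaller learning rate) can be sketched on a hypothetical toy model rather than vit2gpt2. The sketch also applies gradient-norm clipping with torch.nn.utils.clip_grad_norm_, a standard guard against exploding gradients that this post did not use; PyTorch Lightning exposes the same idea through the Trainer's gradient_clip_val argument. The learning rates and clipping norm are illustrative.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

def set_trainable(model: nn.Module, keyword: str) -> None:
    """Only parameters whose name contains `keyword` get gradients ("" matches all)."""
    for name, param in model.named_parameters():
        param.requires_grad = keyword in name

def make_optimizer(model: nn.Module, lr: float) -> torch.optim.Optimizer:
    return torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

def train_step(model: nn.Module, optimizer, x, y, max_grad_norm: float = 1.0) -> float:
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)  # stand-in for the captioning loss
    loss.backward()
    # rescale gradients if their global L2 norm exceeds max_grad_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()

torch.manual_seed(0)
# toy stand-in: "attn" plays the pretrained part, "crossattention" the new layers
model = nn.Sequential(OrderedDict([("attn", nn.Linear(4, 4)), ("crossattention", nn.Linear(4, 4))]))
x, y = torch.randn(8, 4), torch.randn(8, 4)

# phase 1: cross-attention only
set_trainable(model, "crossattention")
opt = make_optimizer(model, lr=1e-4)
attn_before = model.attn.weight.detach().clone()
train_step(model, opt, x, y)
attn_after_phase1 = model.attn.weight.detach().clone()

# phase 2 (e.g. after the first epoch): unfreeze everything with a smaller learning rate
set_trainable(model, "")
opt = make_optimizer(model, lr=1e-5)
train_step(model, opt, x, y)
```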

Shameless Self Promotion

If you enjoyed this tutorial, buy my course (usually 90% off).