print(caption)
transforms.ToPILImage()(descale(img))
A lone zebra grazing in some green grass.
Sachin Abeywardana
December 28, 2021
There has been a lot of hype about generating images from text. However, I had not seen many things in the caption generation domain. This is obviously the easier of the two problems, and perhaps it has been mostly solved, but I thought I'd get my hands dirty trying to do this almost from scratch. Before we get going, HF does have VisionEncoderDecoderModels which do exactly what we are doing today, but I wanted to try and build this mostly from scratch.
Vision Transformers (ViT) were used to classify images in the ImageNet problem, and GPT-2 is a language model that can be used to generate text. So the question is: can we combine the two? The answer is yes, thanks to EncoderDecoderModels from HF. In the original Attention Is All You Need paper, using attention was the game changer. Not many people are aware, however, that there were two kinds of attention: 1. self-attention, which most people are familiar with, and 2. cross-attention, which allows the decoder to retrieve information from the encoder.
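For intuition, cross-attention is the same scaled dot-product attention, except the queries come from the decoder (the text generated so far) while the keys and values come from the encoder (the image patches). A minimal single-head sketch, not the actual HF implementation:

import torch
import torch.nn.functional as F

def cross_attention(decoder_hidden, encoder_hidden, w_q, w_k, w_v):
    # queries from the decoder (text), keys/values from the encoder (image patches)
    q = decoder_hidden @ w_q                                # (batch, text_len, d)
    k = encoder_hidden @ w_k                                # (batch, num_patches, d)
    v = encoder_hidden @ w_v                                # (batch, num_patches, d)
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5   # (batch, text_len, num_patches)
    return F.softmax(scores, dim=-1) @ v                    # each text position attends to all patches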
By default, GPT-2 does not have these cross-attention layers pre-trained. This paper by Google Research demonstrated that you can simply randomly initialise these cross-attention layers and train the system. And this is exactly what we will be doing in this blog, using the COCO dataset. An executable version of this post can be found here on Kaggle.
The COCO dataset provides us with an image and 5 possible captions. We choose one at random during each epoch.
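As a rough sketch of what that looks like (the actual Dataset in the notebook also handles image preprocessing, and the names here are hypothetical):

import random
from torch.utils.data import Dataset

class CocoCaptionDataset(Dataset):
    def __init__(self, image_tensors, captions_per_image):
        self.image_tensors = image_tensors            # preprocessed image tensors
        self.captions_per_image = captions_per_image  # list of 5 reference captions per image

    def __len__(self):
        return len(self.image_tensors)

    def __getitem__(self, idx):
        # pick one of the five reference captions at random every time the item is drawn
        caption = random.choice(self.captions_per_image[idx])
        return self.image_tensors[idx], caption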
As mentioned earlier, we will use the EncoderDecoderModel, which will initialise the cross-attention layers for us and use pretrained weights from the Vision Transformer and (distil) GPT-2. We only use the distil version for the sake of quick training, and as you will see soon, it is good enough.
The tokenizer requires a bit more preprocessing than what you'd be used to with a BERT tokenizer. The following tokenizer code is something I copied (sorry, I don't remember where), but the important bit is that a padding token had to be introduced, which I thought was strange. Mostly because, how would GPT-2 have been trained without padding?
# model
vit2gpt2 = EncoderDecoderModel.from_encoder_decoder_pretrained(VIT_MODEL, DISTIL_GPT2)

# tokenizer
# make sure GPT2 adds BOS at the beginning and EOS at the end
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs

GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(DISTIL_GPT2)
# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token
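As a quick sanity check (not from the original notebook), the patched tokenizer should now wrap a caption in BOS/EOS tokens:

ids = gpt2_tokenizer("A lone zebra grazing in some green grass")["input_ids"]
assert ids[0] == gpt2_tokenizer.bos_token_id and ids[-1] == gpt2_tokenizer.eos_token_id
print(gpt2_tokenizer.decode(ids))
# roughly: '<|endoftext|>A lone zebra grazing in some green grass<|endoftext|>'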
At the time of writing, the EncoderDecoderModel does not seem to have the generate method that GPT-2 etc. use to generate text. Hence the following code.
Sampling the next token/word is not simply a matter of taking the token with the highest likelihood. This is because there is no guarantee that the (log) likelihood of the entire sequence is maximised by taking the maximum at each step, which can lead to a sub-optimal answer. Beam search is an alternative method where you keep the top k tokens and iterate to the end, in the hope that one of the k beams contains the solution we are after.
In the code below we use a sampling-based method named nucleus sampling, which has been shown to give superior results and to minimise common pitfalls such as repetition when generating text. The algorithm is as follows:
import torch
import torch.nn.functional as F
from typing import Optional, Union

def top_k_top_p_filtering(
    next_token_logits: torch.FloatTensor,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    device: Union[str, torch.device] = "cpu",
) -> torch.FloatTensor:
    if top_k is None:
        top_k = next_token_logits.shape[-1]
    if top_p is None:
        top_p = 1.0
    # keep only the top_k most likely tokens, sorted by probability
    p, largest_p_idx = F.softmax(next_token_logits, dim=-1).topk(top_k, dim=-1)
    cumulative_p = p.cumsum(dim=-1)
    # per-row cut-off: the first position where the cumulative probability exceeds top_p
    threshold_repeated = top_p + torch.zeros((len(p), 1)).to(device)
    idx = torch.searchsorted(cumulative_p, threshold_repeated).clip(max=top_k - 1).squeeze()
    cutoffs = cumulative_p[torch.arange(len(cumulative_p)), idx]
    # zero out everything beyond the nucleus and renormalise
    censored_p = (cumulative_p <= cutoffs[:, None]) * p
    renormalized_p = censored_p / censored_p.sum(dim=-1, keepdim=True)
    # scatter the filtered probabilities back to their positions in the full vocabulary
    final_p = torch.zeros_like(next_token_logits)
    row_idx = torch.arange(len(p)).unsqueeze(1).repeat(1, top_k).to(device)
    final_p[row_idx, largest_p_idx] = renormalized_p.to(final_p.dtype)
    return final_p
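As a toy example of what the filtering does (the numbers are illustrative, not from the post): with top_p=0.9, only the highest-probability tokens up to the point where their cumulative probability crosses 0.9 survive, and their probabilities are renormalised:

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])        # one row = a tiny "vocabulary" of 5 tokens
filtered = top_k_top_p_filtering(logits, top_k=5, top_p=0.9)
print(filtered)                                              # tail tokens zeroed out, rest renormalised
next_token = torch.multinomial(filtered, num_samples=1)      # sample from the truncated distribution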
In order to generate the actual sequence we need: 1. the image representation according to the encoder (ViT), and 2. the tokens generated so far. Note that the first token is always going to be a beginning-of-sentence token (<BOS>). We pass the generated tokens in iteratively, for a predefined length or until the end of the sentence is reached. In the following, since we are using a batch, we ignore the <EOS> token and simply generate for the full length.
from typing import List
from tqdm.auto import tqdm

def generate_sentence_from_image(model, encoder_outputs, tokenizer, max_text_length: int, device) -> List[str]:
    # start every sequence in the batch with the BOS token
    generated_so_far = torch.LongTensor(
        [[tokenizer.bos_token_id]] * len(encoder_outputs.last_hidden_state)
    ).to(device)
    with torch.no_grad():
        for _ in tqdm(range(max_text_length)):
            attention_mask = torch.ones_like(generated_so_far)
            decoder_out = model(
                decoder_input_ids=generated_so_far,
                decoder_attention_mask=attention_mask,
                encoder_outputs=encoder_outputs,
            )
            # sample the next token from the nucleus-filtered distribution
            next_token_logits = decoder_out["logits"][:, -1, :]
            filtered_p = top_k_top_p_filtering(next_token_logits, top_k=TOP_K, top_p=TOP_P, device=device)
            next_token = torch.multinomial(filtered_p, num_samples=1)
            generated_so_far = torch.cat((generated_so_far, next_token), dim=1)
    return [tokenizer.decode(coded_sentence) for coded_sentence in generated_so_far]
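Putting the pieces together, generating captions for a batch of preprocessed images looks roughly like this (the constants TOP_K, TOP_P and MAX_TEXT_LENGTH, and the images tensor, are assumed to be defined elsewhere):

vit2gpt2.eval().to(device)
with torch.no_grad():
    encoder_outputs = vit2gpt2.encoder(pixel_values=images.to(device))
captions = generate_sentence_from_image(
    vit2gpt2, encoder_outputs, gpt2_tokenizer, MAX_TEXT_LENGTH, device
)
for caption in captions:
    print(caption)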
Expand the section below to see the full PyTorch Lightning code. There are a few things to note in the training step. First, we only train the (randomly initialised) cross-attention layers and freeze everything else:
for name, param in self.model.named_parameters():
    if "crossattention" not in name:
        param.requires_grad = False
Second, we run the encoder once and pass its outputs to the decoder together with the tokenized captions as labels, which lets HF compute the loss for us:
encoder_outputs = self.model.encoder(pixel_values=images)
outputs = self.model(
    encoder_outputs=encoder_outputs,
    decoder_input_ids=tokenized_captions["input_ids"],
    decoder_attention_mask=tokenized_captions["attention_mask"],
    labels=labels,
    return_dict=True,
)
return outputs["loss"]
I also logged the intermediate results to wandb during training; the wandb.Table element especially was a godsend.

class LightningModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        lr: float,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.lr = lr
        # only the randomly initialised cross-attention layers are trained
        for name, param in self.model.named_parameters():
            if "crossattention" not in name:
                param.requires_grad = False

    def common_step(self, batch: Tuple[torch.FloatTensor, List[str]]) -> torch.FloatTensor:
        images, captions = batch
        tokenized_captions = {
            k: v.to(self.device) for k, v in
            self.tokenizer(
                captions,
                max_length=MAX_TEXT_LENGTH,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).items()
        }
        # ignore the padded positions when computing the loss
        labels = tokenized_captions["input_ids"].clone()
        labels[tokenized_captions["attention_mask"] == 0] = LABEL_MASK
        encoder_outputs = self.model.encoder(pixel_values=images)
        outputs = self.model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=tokenized_captions["input_ids"],
            decoder_attention_mask=tokenized_captions["attention_mask"],
            labels=labels,
            return_dict=True,
        )
        return outputs["loss"]

    def training_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(batch)
        self.log(name="Training loss", value=loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int):
        loss = self.common_step(batch)
        self.log(name="Validation loss", value=loss, on_step=True, on_epoch=True)
        images, actual_sentences = batch
        # log a table of images with actual vs generated captions for the first validation batch
        if batch_idx == 0:
            encoder_outputs = self.model.encoder(pixel_values=images.to(self.device))
            generated_sentences = generate_sentence_from_image(
                self.model,
                encoder_outputs,
                self.tokenizer,
                MAX_TEXT_LENGTH,
                self.device,
            )
            images = [wandb.Image(transforms.ToPILImage()(descale(image))) for image in images]
            data = list(map(list, zip(images, actual_sentences, generated_sentences)))
            columns = ["Images", "Actual Sentence", "Generated Sentence"]
            table = wandb.Table(data=data, columns=columns)
            self.logger.experiment.log({f"epoch {self.current_epoch} results": table})

    def on_after_backward(self):
        if self.trainer.global_step % 50 == 0:  # don't make the tf file huge
            for name, param in self.model.named_parameters():
                if "weight" in name and not "norm" in name and param.requires_grad:
                    self.logger.experiment.log(
                        {f"{name}_grad": wandb.Histogram(param.grad.detach().cpu())}
                    )
                    self.logger.experiment.log(
                        {f"{name}": wandb.Histogram(param.detach().cpu())}
                    )

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
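Training is then the usual Lightning routine; roughly like this (the dataloaders, LEARNING_RATE and N_EPOCHS here are placeholders):

from pytorch_lightning.loggers import WandbLogger

module = LightningModule(vit2gpt2, gpt2_tokenizer, lr=LEARNING_RATE)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=N_EPOCHS,
    logger=WandbLogger(project="image-captioning"),
)
trainer.fit(module, train_dataloader, valid_dataloader)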
The full set of results can be seen here, but here are some highlights:

Actual caption: A teddy bear and a metronome sitting on a table top.
Generated caption: <|endoftext|>A stuffed teddy bear is being laid down next to a table.<|endoftext|>

Actual caption: A white cat is sitting on a laptop.
Generated caption: <|endoftext|>A cat sitting on a stool with an empty laptop.<|endoftext|>

Actual caption: The cows are laying down in a straight row.
Generated caption: <|endoftext|>A horse is drinking milk and taking pictures of cattle inside.<|endoftext|>
At least the last one captured cattle; not sure about the shifty horse.
The biggest problem I had training this model was gradient explosion. When I chose the learning rate too high (1e-3), the weights quickly hit infinity; too small (1e-5), and the generated text wasn't so good. So it was possibly worth unfreezing the weights after one epoch and training with a smaller learning rate (a possible way to do this is sketched below). The following is a cross section of the histograms of one set of weights. As you can see, it goes into the 1000s, which is concerning, even though I managed to train the model successfully:
Perhaps putting some sort of normalization in would have fixed the gradient explosion, but unfortunately that is something I don't have access to, as it is handled for me by the HF EncoderDecoderModel API.
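One way to try the unfreeze-after-one-epoch idea would be a small Lightning callback along these lines (a hypothetical sketch, not what was actually run for this post):

class UnfreezeAfterFirstEpoch(pl.Callback):
    """Unfreeze all weights after the first epoch and drop the learning rate."""
    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch == 1:
            for param in pl_module.model.parameters():
                param.requires_grad = True
            for optimizer in trainer.optimizers:
                for group in optimizer.param_groups:
                    group["lr"] = 1e-5  # smaller learning rate for full fine-tuning

# passed to the trainer via pl.Trainer(callbacks=[UnfreezeAfterFirstEpoch()], ...)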
If you enjoyed the tutorial, buy my course (usually 90% off).