Fine Tuning GPT2 for Grammar Correction

Deep Learning
LLM
Fine-tuning GPT2 for Sequence to Sequence tasks
Author

Sachin Abeywardana

Published

September 25, 2022

Jimi Hendrix fine-tuning a guitar

Introduction

GPT2 is well known for its text-generation capabilities. While we could always use the existing model from huggingface and hope it generates a sensible answer, it is far more profitable to fine-tune it to our own task. In this example I show how to correct grammar using GPT2. While the results aren't perfect, given enough time and (compute) resources this could be a possible replacement for Chrome's default grammar correction. If you wish to run this yourself, a working example can be found in this kaggle kernel.

GPT2 Model Architecture

As a quick primer on GPT2, note that GPT2 is a decoder-only transformer. This means that GPT2 is only allowed to pay attention to the current token and the tokens that came before it. This is in contrast to encoder-only transformers such as BERT, which attend to the entire sequence at once.

The reason this architecture matters is that at generation time, the only tokens that ought to be visible are the previous ones. During training, this effect is achieved by making the attention matrix lower triangular, as sketched below.
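To make this concrete, here is a minimal sketch in plain PyTorch (not GPT2's actual implementation) of how a triangular mask hides future tokens from the attention weights:

import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)  # raw attention scores for a 5-token sequence

# lower-triangular mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float("-inf"))

# after the softmax, the masked (future) positions get exactly zero weight
attention = scores.softmax(dim=-1)
print(attention)  # everything above the diagonal is zero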

Tokenizer

For some odd reason GPT2's tokenizer does not ship with all the special tokens we need here: there is no dedicated padding or separator token, and the same <|endoftext|> token doubles as both beginning and end of text. Therefore, we need to add these to our tokenizer. As a result of this change, we also need to resize the embedding matrix of the GPT2 model, hence language_model.resize_token_embeddings(len(tokenizer)). This randomly initialises embeddings for just the new tokens while keeping the previously trained embeddings for everything else.

There are two cases to handle when tokenizing:
1. During training we have both the input_sentence and the corrected output_sentence. We prepend a bos token, separate the two with a sep token and append an eos token.
2. At inference time we only have access to input_sentence, so we prepend the bos token and end the sentence with the sep token, prompting the model to generate the correction.
This logic is captured in the __call__ method below.

Code
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class Tokenizer:
    def __init__(self, tokenizer, max_len: int):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.bos = tokenizer.bos_token
        self.eos = tokenizer.eos_token
        self.sep = tokenizer.sep_token
        self.num_special_tokens = len(self.tokenizer.all_special_tokens)
        
    def __getattr__(self, attribute: str):
        if hasattr(self.tokenizer, attribute):
            return getattr(self.tokenizer, attribute)
        else:
            raise AttributeError(f"{attribute} not found")

    def __call__(
        self,
        input_sentences: List[str],
        output_sentences: Optional[List[str]] = None,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        if output_sentences is None:
            sentences = [self.bos + x + self.sep for x in input_sentences]
        else:
            sentences = [self.bos + x + self.sep + y + self.eos for x, y in zip(input_sentences, output_sentences)]
        
        tokenized = self.tokenizer(
            sentences, 
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=self.max_len,
        )
        if device is not None:
            return {key: tensor.to(device) for key, tensor in tokenized.items()}
        return tokenized

    def decode(self, x: Dict[str, torch.LongTensor]) -> List[str]:
        # decode each sentence only up to its true (unpadded) length
        return [self.tokenizer.decode(sentence[:sentence_len]) for sentence, sentence_len in 
                zip(x["input_ids"], x["attention_mask"].sum(axis=-1))]
    
    def batch_decode(self, encoded_outputs: torch.LongTensor) -> List[str]:
        return self.tokenizer.batch_decode(encoded_outputs.cpu(), skip_special_tokens=True)
    
    def __len__(self):
        return len(self.tokenizer)


# get text base and transform
# LANGUAGE_MODEL (e.g. "gpt2") and MAX_LEN are constants defined earlier in the notebook
language_model = AutoModelForCausalLM.from_pretrained(LANGUAGE_MODEL)
tokenizer = Tokenizer(
    AutoTokenizer.from_pretrained(
        LANGUAGE_MODEL, 
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>", 
        pad_token="<|pad|>", 
        sep_token="<|sep|>"
    ),
    MAX_LEN,
)
language_model.resize_token_embeddings(len(tokenizer))
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
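As a quick sanity check, here is how the wrapper behaves in the two modes described above (the example sentences are made up):

# training mode: input and corrected output joined by the sep token
train_batch = tokenizer(["he go to school"], ["he goes to school"])
# inference mode: the input alone, ending in the sep token as a prompt
infer_batch = tokenizer(["he go to school"])

print(tokenizer.decode(train_batch))
# roughly: ['<|startoftext|>he go to school<|sep|>he goes to school<|endoftext|>']
print(tokenizer.decode(infer_batch))
# roughly: ['<|startoftext|>he go to school<|sep|>']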

(Huggingface) Datasets

Huge kudos to HF's datasets API: thanks to streaming support we can train on datasets that are far too large to fit in memory. In the following block we use the c4_200m dataset, which contains grammar correction pairs. We keep the first 100,000 examples as a validation set and the rest for training. The group_batch function (copied from a HF tutorial) simply wraps each field in a list so that, together with batched=True and batch_size=BATCH_SIZE, iterating over the mapped dataset yields whole batches rather than single examples.

import datasets

# BATCH_SIZE is a constant defined earlier in the notebook
data = datasets.load_dataset("liweili/c4_200m", cache_dir="/kaggle/working/", streaming=True, split="train")\
        .shuffle(seed=42, buffer_size=10_000)
c4_valid = data.take(100000)
c4_train = data.skip(100000)
def group_batch(batch):
    return {k: [v] for k, v in batch.items()}
train_dl = c4_train.map(group_batch, batched=True, batch_size=BATCH_SIZE)
valid_dl = c4_valid.map(group_batch, batched=True, batch_size=BATCH_SIZE)
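To get a feel for what the streaming pipeline yields, here is an illustrative peek at one batch; the input and output field names are those of the c4_200m dataset used above:

batch = next(iter(train_dl))
print(batch.keys())         # dict_keys(['input', 'output'])
print(len(batch["input"]))  # BATCH_SIZE sentences with grammatical errors
print(batch["output"][0])   # the corrected version of batch["input"][0]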

Training

Let's break down the following LightningModule.

Freezing parameters

I am personally not a fan of training the embeddings. The reason is that during training we only see a fraction of all possible tokens, and some tokens appear far more frequently than others. It seems unfair to update some embeddings while others never get a chance to be updated. Therefore, in the interest of keeping the model resilient to unseen tokens, we should freeze the embeddings.

However, given that we have three new tokens (bos, pad and sep; the eos token <|endoftext|> already exists in GPT2's vocabulary), what we do instead is reset the embeddings of the pre-existing tokens back to their pretrained values every few batches:

# every 100 batches, reset the embeddings of all pre-existing (non-special) tokens
# back to their pretrained values; write through weight.data so the parameter is
# actually modified in-place
if (batch_idx + 1) % 100 == 0:
    self.model.transformer.wte.weight.data[:-self.tokenizer.num_special_tokens] = self.original_embed_weights

In the same spirit, I believe it is beneficial to freeze the bottom 2 layers (out of 12) of the transformer. This again is a step to avoid overfitting to our training data.

How we use the data

The dataset defined above returns a batch, which is a dictionary with the keys input and output. input contains the sentences with incorrect grammar, while output contains the corrected sentences. While we can match input to output, it is also important for the model to understand when not to do anything, i.e. to return the input unchanged when it sees a good sentence. Therefore, in common_step you will see input matched with output as well as output matched with output.

Calculating Loss

HF transformers luckily takes care of most of the loss calculation for us. The loss is simply: given the tokens seen so far, the cross-entropy of the next token over all possible tokens.

However, there are two cases where we need to ignore tokens. To ignore a token you simply set its label to -100, the special ignore_index outlined in the torch cross-entropy docs.
1. When some sentences are shorter than others in the batch, the padded positions should be ignored; these are given to us by the tokenizer's attention_mask.
2. The second case, which is not strictly necessary, is that we do not need to calculate the loss before the sep token. The model is always given the input sentence, so we do not need to burden it with learning the structure of the incoming sentence as well. This is why we generate a mask defined by mask = (good_grammar_labels == self.tokenizer.sep_token_id).roll(shifts=1, dims=-1).cumsum(dim=-1) == 0, which is True for every position up to and including the sep token.
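To see what the roll/cumsum trick actually does, here is a toy example on a fake row of labels (the sep token id 50258 is made up for illustration):

import torch

SEP_ID = 50258     # hypothetical id of the <|sep|> token
LABEL_MASK = -100  # ignore_index of torch's cross-entropy loss

labels = torch.tensor([[10, 11, SEP_ID, 12, 13, 14]])

is_sep = labels == SEP_ID                # [[False, False, True, False, False, False]]
rolled = is_sep.roll(shifts=1, dims=-1)  # [[False, False, False, True, False, False]]
mask = rolled.cumsum(dim=-1) == 0        # True up to and including the sep token
labels[mask] = LABEL_MASK
print(labels)  # tensor([[-100, -100, -100, 12, 13, 14]])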

Code
from typing import Any, Dict

import pytorch_lightning as pl
import torch
import wandb
from torch import nn

# FREEZE_LAYERS, LABEL_MASK (-100) and `adam` (a fused Adam implementation,
# e.g. deepspeed.ops.adam) are defined earlier in the notebook
class LightningModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        generation_kwargs: Dict[str, Any],
        lr: float = 1e-3,
    ) -> None:
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.lr = lr
        self.generation_kwargs = generation_kwargs
        # snapshot of the pretrained embeddings for all non-special tokens
        self.original_embed_weights = self.model.transformer.wte.weight[:-self.tokenizer.num_special_tokens].clone().detach()
        
        for layer in self.model.transformer.h[:FREEZE_LAYERS]:
            layer.eval()
            for p in layer.parameters():
                p.requires_grad = False
        
        self.table_logging = 0
        
    def common_step(self, batch: Dict[str, torch.LongTensor]) -> torch.Tensor:
        # "good grammar" pass: the corrected sentence is mapped to itself, so the
        # model also learns to leave already-correct sentences untouched
        good_grammar_batch = self.tokenizer(batch["output"], batch["output"], self.device)
        good_grammar_labels = good_grammar_batch["input_ids"].clone()
        # ignore padding, and everything up to and including the sep token
        good_grammar_labels[good_grammar_batch["attention_mask"] == 0] = LABEL_MASK
        mask = (good_grammar_labels == self.tokenizer.sep_token_id).roll(shifts=1, dims=-1).cumsum(dim=-1) == 0
        good_grammar_labels[mask] = LABEL_MASK
        
        # "bad grammar" pass: the incorrect sentence is mapped to its correction
        bad_grammar_batch = self.tokenizer(batch["input"], batch["output"], self.device)
        bad_grammar_labels = bad_grammar_batch["input_ids"].clone()
        bad_grammar_labels[bad_grammar_batch["attention_mask"] == 0] = LABEL_MASK
        mask = (bad_grammar_labels == self.tokenizer.sep_token_id).roll(shifts=1, dims=-1).cumsum(dim=-1) == 0
        bad_grammar_labels[mask] = LABEL_MASK

        good_grammar_out = self.model(
            **good_grammar_batch,
            labels=good_grammar_labels,
        )
        bad_grammar_out = self.model(
            **bad_grammar_batch,
            labels=bad_grammar_labels,
        )
        return good_grammar_out.loss + bad_grammar_out.loss
        
    def training_step(
        self, batch: Dict[str, torch.LongTensor], batch_idx: int,
    ) -> torch.Tensor:
        if (batch_idx + 1) % 100 == 0:
            # reset non-special token embeddings to their pretrained values
            # (write through .data so the parameter is modified in-place)
            self.model.transformer.wte.weight.data[:-self.tokenizer.num_special_tokens] = self.original_embed_weights
            
        loss = self.common_step(batch)     
        self.log("training_loss", loss, on_step=True, on_epoch=True, batch_size=len(batch["input"]))
             
        return loss

    def validation_step(
        self, batch: Dict[str, torch.LongTensor], batch_idx: int,
    ) -> None:
        loss = self.common_step(batch)
        self.log("validation_loss", loss, on_step=False, on_epoch=True, batch_size=len(batch["input"]))
        
        if batch_idx == 0:
            self.log_examples(batch)
            
    def log_examples(self, batch):
        good_grammar_batch = self.tokenizer(batch["output"], device=self.device)
        bad_grammar_batch = self.tokenizer(batch["input"], device=self.device)
        encoded_good_outputs = self.model.generate(**good_grammar_batch, **self.generation_kwargs)
        encoded_bad_outputs = self.model.generate(**bad_grammar_batch, **self.generation_kwargs)
        generated_good_sentences = self.tokenizer.batch_decode(encoded_good_outputs)
        generated_bad_sentences = self.tokenizer.batch_decode(encoded_bad_outputs)
        
        data = list(map(list, zip(batch["output"] + batch["input"], generated_good_sentences + generated_bad_sentences)))
        columns = ["Actual Sentence", "Generated Sentence"]
        # generate() echoes the prompt, so split on the original sentence to keep only the newly generated text
        data = [[x, y.split(x)[1]] for x, y in data]
        table = wandb.Table(data=data, columns=columns)
        if self.logger is not None:
            self.table_logging += 1
            self.logger.experiment.log({f"epoch {self.table_logging} results": table})

    def configure_optimizers(self) -> torch.optim.Optimizer:
        caption_params = [
            {"params": self.model.transformer.ln_f.parameters() , "lr": self.lr},
            {"params": self.model.transformer.h[FREEZE_LAYERS:].parameters() , "lr": self.lr},
            {"params": self.model.transformer.wte.parameters() , "lr": self.lr},
        ]
        return adam.FusedAdam(caption_params)
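For completeness, here is a minimal sketch of how the module might be wired up and trained; the generation settings, wandb project name and trainer options below are assumptions rather than the exact values used in the kaggle kernel:

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

generation_kwargs = {  # assumed settings, not the kernel's exact values
    "max_new_tokens": 64,
    "num_beams": 3,
    "pad_token_id": tokenizer.pad_token_id,
}

module = LightningModule(language_model, tokenizer, generation_kwargs)
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=1,
    precision=16,
    logger=WandbLogger(project="gpt2-grammar-correction"),  # hypothetical project name
)
trainer.fit(module, train_dl, valid_dl)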

Results

To show that the model is actually learning, the following results show the generated text at the outset of training, which is just gibberish. This is to be expected, since at that point the model does not understand what the sep token is or what to do with it.

results of epoch 1

The following are the results after 10 epochs, which clearly show a great improvement, though they are still not perfect. For instance, the model doesn't seem to understand that you only capitalise at the beginning of a sentence. However, as seen in row 23, it is intelligent enough to copy across names such as Conor and nouns such as British.

results of epoch 10
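As a final note on usage, correcting a new sentence at inference time is just the inference-mode tokenization followed by generate. Here is a minimal sketch; the helper function, example sentence and generation settings are made up rather than taken from the original kernel:

def correct(sentence: str) -> str:
    # "<|startoftext|>" + sentence + "<|sep|>" prompts the model for a correction
    batch = tokenizer([sentence], device=language_model.device)
    output_ids = language_model.generate(
        **batch,
        max_new_tokens=64,
        num_beams=3,
        pad_token_id=tokenizer.pad_token_id,
    )
    # generate() echoes the prompt; batch_decode strips special tokens, so drop
    # the original sentence to keep only the generated correction
    decoded = tokenizer.batch_decode(output_ids)[0]
    return decoded[len(sentence):].strip()

print(correct("he go to school every days"))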

Summary

To summarise the main points made in this article:
1. Freeze the lower layers, and only train the new token embeddings.
2. Calculate the loss only where it is necessary.

Shameless Self Promotion

If you enjoyed this tutorial, buy my course (30 days money back).