Transformer Model Compression (Attempt)

Deep Learning
A failed attempt at model compression using student-teacher learning
Author

Sachin Abeywardana

Published

August 7, 2022

[Image: a transformer being crushed]

Introduction

So straight off the bat, let me warn you that this is a failed experiment. However, I do think that the method I have used here is interesting enough to warrant a read.

As deep learning architectures keep spitting out more and more amazing results, whether it be GPT-3 or DALL-E 2, they all rely on gigantic scale, which remains unaffordable for most small-scale startups.

While the original models remain out of reach, compression methods have also been getting popular. These advances come from both the hardware side (e.g. float16, quantization) and the software side. Our focus here is on the latter.
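
Just to make the hardware side concrete (this is not part of the experiment below), here is a minimal sketch of the two options mentioned above using standard PyTorch calls; the actual memory and speed savings depend on your hardware and model.

import torch
from torch import nn
from transformers import BertModel

bert = BertModel.from_pretrained("bert-base-uncased")

# Option 1: half precision (float16) for GPU inference
if torch.cuda.is_available():
    bert_fp16 = BertModel.from_pretrained("bert-base-uncased").half().to("cuda")

# Option 2: dynamic int8 quantization of the linear layers for CPU inference
bert_int8 = torch.quantization.quantize_dynamic(bert, {nn.Linear}, dtype=torch.qint8)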

Current Distillation Method

DistilBERT is one of the popular models in the Hugging Face model hub, and is a distilled version of the larger BERT model. If I understand the distillation training method correctly, this is the (distillation part of the) training loss: \[ \mathcal{L}_{ce} = -\sum_i t_i \log(s_i) \] where \(t_i\) are the teacher model (BERT) probabilities (a softmax over its logits), and \(s_i\) are the corresponding predicted probabilities of the student model (DistilBERT).
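
To make that concrete, here is a minimal sketch of such a loss in PyTorch (my own illustration, not the DistilBERT training code); the temperature T used to soften the distributions is an extra knob I am adding for clarity.

import torch
import torch.nn.functional as F

def distillation_loss(
    student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 2.0
) -> torch.Tensor:
    # soften both distributions with temperature T
    t = F.softmax(teacher_logits / T, dim=-1)          # teacher probabilities t_i
    log_s = F.log_softmax(student_logits / T, dim=-1)  # student log-probabilities log(s_i)
    # cross entropy between teacher and student: -sum_i t_i * log(s_i)
    return -(t * log_s).sum(dim=-1).mean()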

from transformers import BertModel

bert_model = BertModel.from_pretrained("bert-base-uncased")

Hypothesis

My hypothesis was that in order to make a model smaller, we don’t need to imitate the entire network, but can instead focus on sub-layers (or a group of them). Therefore, as you will see below in the Teacher network, I’ve chosen two encoder layers stacked on top of each other as the teacher. Each layer is a standard attention layer followed by a few linear layers.

class Teacher(nn.Module):
    def __init__(self, bert_model: nn.Module, layers: int):
        super().__init__()
        self.layers = bert_model.encoder.layer[:layers]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)[0] # 0 because HF returns as a tuple

        return x
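
As a quick sanity check (my addition, not from the original notebook), we can push a random batch of BERT-sized hidden states through the teacher; BERT_DIM = 768 is the hidden size of bert-base-uncased.

import torch

BERT_DIM = 768  # hidden size of bert-base-uncased

teacher = Teacher(bert_model, layers=2)
with torch.no_grad():
    hidden_states = torch.randn(4, 128, BERT_DIM)  # (batch, sequence length, hidden dim)
    print(teacher(hidden_states).shape)  # torch.Size([4, 128, 768])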

In the student network (ApproxBertAttention) there are a few twists (which may or may not have helped):

1. I have a similar network to the teacher, except that it is only one attention-like layer, on top of which there is a single dense layer with a residual connection.
2. I squeeze the number of dimensions to be less than 768 (the dimensionality of BERT) in most layers, except for the final output layer.
3. The query, key and value layers are stacked linear layers, instead of the single linear layer used in standard attention layers.
4. The dimensionality of the query and value layers is not necessarily the same. The only actual requirement is that the query and key dimensionality must match in order to calculate the softmax step. Despite them being equal in the code below, I tried a slightly lower attention_hidden_size and the drop in performance was slim.

Code
def linear_projector(input_dim: int, output_dim: int, layers: int = 1) -> nn.Module:
    # `layers` blocks of (Linear, GELU), followed by a final projection to output_dim
    blocks = sum(
        [[nn.Linear(input_dim, input_dim), nn.GELU()] for _ in range(layers)], []
    )
    blocks.append(nn.Linear(input_dim, output_dim))
    return nn.Sequential(*blocks)

class ApproxBertAttention(nn.Module):
    def __init__(
        self,
        num_attention_heads=12,
        attention_hidden_size=32,
        value_dim=384,
        input_dim=BERT_DIM,
        output_dim=BERT_DIM,
        qkv_layers=2,
        p=0.1,
        eps=1e-12,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_hidden_size
        self.query_head_size = (num_attention_heads, attention_hidden_size)
        self.key_head_size = (num_attention_heads, attention_hidden_size)
        self.value_head_size = (num_attention_heads, value_dim // num_attention_heads)
        self.value_dim = value_dim

        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_denominator = self.attention_head_size ** 0.5

        self.query = linear_projector(input_dim, self.all_head_size, qkv_layers)
        self.key = linear_projector(input_dim, self.all_head_size, qkv_layers)
        self.value = linear_projector(input_dim, value_dim, qkv_layers)
        self.dropout_1 = nn.Dropout(p)

        self.dense = linear_projector(value_dim, output_dim, qkv_layers)
        self.LayerNorm = nn.LayerNorm(output_dim, eps=eps)
        self.dropout_2 = nn.Dropout(p)

    def transpose_for_scores(self, x, reshape_size):
        new_x_shape = x.size()[:-1] +  reshape_size
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
    ):
        query_layer = self.transpose_for_scores(self.query(hidden_states), self.query_head_size)
        key_layer = self.transpose_for_scores(self.key(hidden_states), self.key_head_size)
        value_layer = self.transpose_for_scores(self.value(hidden_states), self.value_head_size)
        
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / self.attention_denominator
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout_1(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.value_dim,)
        context_layer = context_layer.view(*new_context_layer_shape)

        context_projection_layer = self.dense(context_layer)
        context_projection_layer = self.dropout_2(context_projection_layer)
        # residual (skip) connection back to the input
        return self.LayerNorm(context_projection_layer + hidden_states)
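
Similarly (again my addition), a random input of the same shape confirms that the student maps BERT-sized hidden states back to BERT-sized outputs, which is what lets us regress it directly onto the teacher’s activations.

student = ApproxBertAttention()
with torch.no_grad():
    hidden_states = torch.randn(4, 128, BERT_DIM)
    print(student(hidden_states).shape)  # torch.Size([4, 128, 768])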

Below you can see that the number of parameters in the student network is roughly a third of the teacher’s.

def get_num_elements(module) -> int:
    return sum(
        [torch.prod(torch.tensor(p.shape)) for p in module.parameters()]
    )
    
teacher = Teacher(bert_model, 2)
student = ApproxBertAttention()

print(get_num_elements(teacher))
print(get_num_elements(student))
tensor(14175744)
tensor(5022336)

Data

I believe that most distillation methods require you to pass in data of the correct format in order to train. For example, a CNN would require images and BERT would require tokenized text. In our case I take quite a different approach and simply push random numbers through.

The intuition is that:

1. It will be faster.
2. Since the student network is still a universal approximator, we should be able to simply treat the outputs of the teacher model (given random inputs) as the true values that we are trying to approximate.

class DummyData(Dataset):
    def __init__(self, seq_length:int, dim: int, batches_per_epoch: int):
        self.seq_length = seq_length
        self.dim = dim
        self.batches_per_epoch = batches_per_epoch

    def __len__(self) -> int:
        return self.batches_per_epoch

    def __getitem__(self, idx):
        return torch.randn(self.seq_length, self.dim)

train_ds = DummyData(SEQ_LEN, BERT_DIM, 1000 * BATCH_SIZE)
valid_ds = DummyData(SEQ_LEN, BERT_DIM, 100 * BATCH_SIZE)

train_dl = DataLoader(train_ds, BATCH_SIZE, num_workers=mp.cpu_count(), pin_memory=True)
valid_dl = DataLoader(valid_ds, BATCH_SIZE, num_workers=mp.cpu_count(), pin_memory=True)
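
A quick look at one batch (my addition) shows that the loader already yields tensors in BERT’s hidden-state shape, so they can be fed straight into both networks without any embedding step.

x = next(iter(train_dl))
print(x.shape)  # (BATCH_SIZE, SEQ_LEN, BERT_DIM)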

Training

Finally, we use pytorch-lightning along with l1_loss to train our model. We get a loss of 0.22. For comparison, if we use an exponentially smoothed average of the teacher model’s output as the estimate, the error we see is 0.59, so the trained student is almost a 3x improvement. The code snippet for this baseline comparison is shown below, after the training output.

Code
class LightningModule(pl.LightningModule):
    def __init__(self, teacher: nn.Module, student: nn.Module, learning_rate: float, loss_fn: nn.Module):
        super().__init__()
        self.teacher = teacher.eval()
        self.student = student
        self.learning_rate = learning_rate
        self.loss_fn = loss_fn

    def common_step(self, x: torch.FloatTensor) -> torch.FloatTensor:
        y = self.teacher(x).detach()
        y_est = self.student(x)
        return self.loss_fn(y, y_est)

    def training_step(self, x: torch.FloatTensor, batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(x)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, x: torch.FloatTensor, batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(x)
        self.log("valid_loss", loss, on_step=True, on_epoch=True)
    
    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.student.parameters(), lr=self.learning_rate, weight_decay=1e-4)
    
loss_fn = nn.L1Loss()
lightning_module = LightningModule(teacher, student, 1e-3, loss_fn)
num_gpus = torch.cuda.device_count()
trainer = pl.Trainer(
    fast_dev_run=False,
    max_epochs=2,
    gpus=num_gpus,
    gradient_clip_val=1.0,
    num_sanity_val_steps=0,
    precision=16 if num_gpus > 0 else 32,
)
trainer.fit(lightning_module, train_dl, valid_dl)
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                | Params
------------------------------------------------
0 | teacher | Teacher             | 14.2 M
1 | student | ApproxBertAttention | 5.0 M 
2 | loss_fn | L1Loss              | 0     
------------------------------------------------
19.2 M    Trainable params
0         Non-trainable params
19.2 M    Total params
38.396    Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
# Baseline: use an exponential moving average of the teacher's output as the estimate
with torch.inference_mode():
    mean = torch.zeros(128, 768).to(DEVICE)
    alpha = 0.01
    progress_bar = tqdm(train_dl)
    teacher = teacher.eval().to(DEVICE)
    for x in progress_bar:
        y = teacher(x.to(DEVICE))[0]
        batch_mean = y.mean(dim=0)  # average over the sequence positions
        mean = alpha * batch_mean + (1- alpha) * mean
        error = loss_fn(y, mean)
        progress_bar.set_description(f"Current error {error:.4f}")

Conclusion

In conclusion, there is clearly more to be done. However, I do believe that we could potentially use this for pretraining distilled networks. This is due to the simple fact that it is faster to generate synthetic random data than it is to preprocess raw text or images.

Furthermore, in this experiment I only stacked 2 layers for the query, key and value modules. Increasing this, along with the dimensionality of attention_hidden_size, could lead to further gains.
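
For instance, an untested sketch of a slightly larger student (the specific numbers are guesses on my part, not configurations I actually ran):

student_larger = ApproxBertAttention(
    attention_hidden_size=64,  # wider query/key heads
    value_dim=576,             # must stay divisible by num_attention_heads (12)
    qkv_layers=3,              # deeper query/key/value stacks
)
print(get_num_elements(student_larger))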

Shameless Self Promotion

If you enjoyed the tutorial, buy my course (usually 90% off).