Transformer Model Compression (Attempt)
Introduction
So straight off the bat, let me warn you that this is a failed experiment. However, I do think the method I have used here is interesting enough to warrant a read.
As deep learning architectures keep spitting out more and more amazing results, whether it be GPT-3 or DALL-E 2, they all rely on gigantic scale, which remains unaffordable for most small-scale startups.
With the original models remaining out of reach, compression methods have been getting popular. These advances come from both the hardware side (e.g. float16, quantization) and the software side. Our focus is on the latter.
Current Distillation Method
DistilBERT is one of the most popular models on the Hugging Face model hub; it is a distilled version of the larger BERT model. If I understand the distillation training method correctly, this is the training loss: \[ \mathcal{L}_{ce} = -\sum_i t_i \log(s_i) \] where \(t_i\) are the soft target probabilities of the teacher model (BERT) and \(s_i\) are the predicted probabilities of the student model (DistilBERT).
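As a rough sketch of that loss (mine, not DistilBERT's actual training code), it might look like the following in PyTorch; the temperature parameter is an extra assumption borrowed from the usual knowledge-distillation recipe:
import torch
import torch.nn.functional as F

def distillation_loss(teacher_logits: torch.Tensor, student_logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    # soft targets t_i from the teacher, log-probabilities log(s_i) from the student
    t = F.softmax(teacher_logits / temperature, dim=-1)
    log_s = F.log_softmax(student_logits / temperature, dim=-1)
    # cross-entropy against the soft targets, averaged over the batch
    return -(t * log_s).sum(dim=-1).mean()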
Hypothesis
My hypothesis was that in order to make a model smaller, we don't need to imitate the entire network; we can instead focus on sub-layers (or groups of them). Therefore, as you will see below, I've chosen two encoder layers stacked on top of each other as the Teacher network. Each layer is a standard attention layer followed by a few linear layers.
import torch
from torch import nn
from transformers import BertModel

BERT_DIM = 768  # hidden size of bert-base-uncased
bert_model = BertModel.from_pretrained("bert-base-uncased")

class Teacher(nn.Module):
    def __init__(self, bert_model: nn.Module, layers: int):
        super().__init__()
        # keep only the first `layers` encoder layers of the pretrained model
        self.layers = bert_model.encoder.layer[:layers]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)[0]  # 0 because HF returns a tuple
        return x
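One detail that is easy to miss: this teacher consumes hidden states directly rather than token ids. A minimal check of my own (the batch and sequence sizes are arbitrary):
teacher_probe = Teacher(bert_model, 2)
hidden = torch.randn(2, 128, BERT_DIM)  # (batch, seq_len, hidden)
print(teacher_probe(hidden).shape)      # torch.Size([2, 128, 768])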
In the student network (ApproxBertAttention) there are a few twists (which may or may not have helped):
1. The network is similar to the teacher, except that it has only one attention-like layer, followed by a single dense layer with a residual connection.
2. I squeeze the number of dimensions down below 768 (the dimensionality of BERT) in most layers, except for the final output layer.
3. The query, key and value layers are stacked linear layers, instead of the single linear layer used in standard attention layers.
4. The dimensionalities of the query and value layers are not necessarily the same. The only hard requirement is that the query and key dimensionalities match, so that the softmax step can be computed. Despite keeping them the same in the code below, I also tried a slightly lower attention_hidden_size and the drop in performance was slim.
Code
def linear_projector(input_dim: int, output_dim: int, layers: int = 1) -> nn.Module:
    # `layers` blocks of (Linear, GELU) at the input width, followed by a final projection to `output_dim`
    modules = sum([[nn.Linear(input_dim, input_dim), nn.GELU()] for _ in range(layers)], []) + [nn.Linear(input_dim, output_dim)]
    return nn.Sequential(*modules)

class ApproxBertAttention(nn.Module):
    def __init__(
        self,
        num_attention_heads=12,
        attention_hidden_size=32,
        value_dim=384,
        input_dim=BERT_DIM,
        output_dim=BERT_DIM,
        qkv_layers=2,
        p=0.1,
        eps=1e-12,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_hidden_size
        # query and key head sizes must match so that the softmax step can be computed
        self.query_head_size = (num_attention_heads, attention_hidden_size)
        self.key_head_size = (num_attention_heads, attention_hidden_size)
        self.value_head_size = (num_attention_heads, value_dim // num_attention_heads)
        self.value_dim = value_dim
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_denominator = self.attention_head_size ** 0.5
        # stacked linear layers instead of the single linear layer of standard attention
        self.query = linear_projector(input_dim, self.all_head_size, qkv_layers)
        self.key = linear_projector(input_dim, self.all_head_size, qkv_layers)
        self.value = linear_projector(input_dim, value_dim, qkv_layers)
        self.dropout_1 = nn.Dropout(p)
        self.dense = linear_projector(value_dim, output_dim, qkv_layers)
        self.LayerNorm = nn.LayerNorm(output_dim, eps=eps)
        self.dropout_2 = nn.Dropout(p)

    def transpose_for_scores(self, x, reshape_size):
        new_x_shape = x.size()[:-1] + reshape_size
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
    ):
        query_layer = self.transpose_for_scores(self.query(hidden_states), self.query_head_size)
        key_layer = self.transpose_for_scores(self.key(hidden_states), self.key_head_size)
        value_layer = self.transpose_for_scores(self.value(hidden_states), self.value_head_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / self.attention_denominator
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout_1(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.value_dim,)
        context_layer = context_layer.view(*new_context_layer_shape)

        context_projection_layer = self.dense(context_layer)
        context_projection_layer = self.dropout_2(context_projection_layer)
        # residual (skip) connection back to the input
        return self.LayerNorm(context_projection_layer + hidden_states)
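As a quick sanity check of my own (not part of the original notebook), a random batch can be pushed through the student to confirm that the output keeps BERT's hidden size, so the residual connection lines up:
approx_layer = ApproxBertAttention()
dummy = torch.randn(4, 128, BERT_DIM)  # (batch, seq_len, hidden), sizes chosen arbitrarily
out = approx_layer(dummy)
print(out.shape)  # expected: torch.Size([4, 128, 768])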
Below you can see that the number of parameters in the student network is almost a third of the teacher's:
def get_num_elements(module) -> int:
    return sum(
        [torch.prod(torch.tensor(p.shape)) for p in module.parameters()]
    )

teacher = Teacher(bert_model, 2)
student = ApproxBertAttention()

print(get_num_elements(teacher))
print(get_num_elements(student))
tensor(14175744)
tensor(5022336)
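To make "almost a third" concrete, the counts above correspond to a ratio of roughly 0.35 (a quick check of my own):
ratio = get_num_elements(student) / get_num_elements(teacher)
print(f"{float(ratio):.3f}")  # 5022336 / 14175744 ≈ 0.354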
Data
Most distillation methods require you to pass in data of the correct format in order to train: a CNN would require images, and BERT would require tokenized text. In our case I take quite a different approach and simply push random numbers through.
The intuition is twofold: 1. It will be faster. 2. Since the student network is still a universal approximator, we should be able to simply treat the outputs of the teacher model (on random inputs) as the ground truth that we are trying to approximate.
import multiprocessing as mp
from torch.utils.data import DataLoader, Dataset

class DummyData(Dataset):
    def __init__(self, seq_length: int, dim: int, batches_per_epoch: int):
        self.seq_length = seq_length
        self.dim = dim
        self.batches_per_epoch = batches_per_epoch

    def __len__(self) -> int:
        return self.batches_per_epoch

    def __getitem__(self, idx):
        # every "example" is just a random (seq_length, dim) tensor
        return torch.randn(self.seq_length, self.dim)

train_ds = DummyData(SEQ_LEN, BERT_DIM, 1000 * BATCH_SIZE)
valid_ds = DummyData(SEQ_LEN, BERT_DIM, 100 * BATCH_SIZE)

train_dl = DataLoader(train_ds, BATCH_SIZE, num_workers=mp.cpu_count(), pin_memory=True)
valid_dl = DataLoader(valid_ds, BATCH_SIZE, num_workers=mp.cpu_count(), pin_memory=True)
Training
Finally, we use pytorch-lightning along with l1_loss to train our model. We get a loss of 0.22. For comparison, if we instead estimate the teacher's output with a smoothed exponential moving average, the error is 0.59, so the student is almost a 3x improvement. The code snippet for this comparison is shown after the training output below.
Code
import pytorch_lightning as pl

class LightningModule(pl.LightningModule):
    def __init__(self, teacher: nn.Module, student: nn.Module, learning_rate: float, loss_fn: nn.Module):
        super().__init__()
        self.teacher = teacher.eval()
        self.student = student
        self.learning_rate = learning_rate
        self.loss_fn = loss_fn

    def common_step(self, x: torch.FloatTensor) -> torch.FloatTensor:
        y = self.teacher(x).detach()  # teacher output is the regression target
        y_est = self.student(x)
        return self.loss_fn(y, y_est)

    def training_step(self, x: torch.FloatTensor, batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(x)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, x: torch.FloatTensor, batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(x)
        self.log("valid_loss", loss, on_step=True, on_epoch=True)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.student.parameters(), lr=self.learning_rate, weight_decay=1e-4)

loss_fn = nn.L1Loss()  # L1 loss between teacher and student activations (shown as L1Loss in the log below)
lightning_module = LightningModule(teacher, student, 1e-3, loss_fn)
num_gpus = torch.cuda.device_count()
trainer = pl.Trainer(
    fast_dev_run=False,
    max_epochs=2,
    gpus=num_gpus,
    gradient_clip_val=1.0,
    num_sanity_val_steps=0,
    precision=16 if num_gpus > 0 else 32,
)
trainer.fit(lightning_module, train_dl, valid_dl)
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:446: LightningDeprecationWarning: Setting `Trainer(gpus=1)` is deprecated in v1.7 and will be removed in v2.0. Please use `Trainer(accelerator='gpu', devices=1)` instead.
f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
------------------------------------------------
0 | teacher | Teacher | 14.2 M
1 | student | ApproxBertAttention | 5.0 M
2 | loss_fn | L1Loss | 0
------------------------------------------------
19.2 M Trainable params
0 Non-trainable params
19.2 M Total params
38.396 Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
from tqdm.auto import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Baseline: estimate the teacher's output with an exponential moving average
with torch.inference_mode():
    mean = torch.zeros(128, 768).to(DEVICE)
    alpha = 0.01
    progress_bar = tqdm(train_dl)
    teacher = teacher.eval().to(DEVICE)
    for x in progress_bar:
        y = teacher(x.to(DEVICE))[0]
        # batch_mean = torch.cat([tensor for tensor in y]).mean(dim=0)
        batch_mean = y.mean(dim=0)
        mean = alpha * batch_mean + (1 - alpha) * mean
        error = loss_fn(y, mean)
        progress_bar.set_description(f"Current error {error:.4f}")
Conclusion
In conclusion, there is clearly more to be done. However, I do believe this approach could potentially be used for pretraining distilled networks, for the simple reason that it is faster to generate synthetic random data than to preprocess raw text or images.
Furthermore, in this experiment I only stacked 2 layers for the query, key and value modules. Increasing this, along with the dimensionality of attention_hidden_size, could lead to further gains.
Shameless Self Promotion
If you enjoyed the tutorial, buy my course (usually 90% off).