This implementation is based on this paper by FAIR. I'm excited about this paper because:

1. I was able to implement it just by reading the paper (don't underestimate how convoluted this step usually is).
2. It works in a setting where we ignore labels altogether, i.e. it is completely self-supervised.
3. We don't need negative examples.
In a nutshell, the paper takes two different augmentations of the same image and pushes their embeddings closer together. Most (and possibly all) other approaches to this task use a triplet loss, where the image is compared against a similar image (positive example) and different image(s) (negative examples). What's amazing about this paper is that it ignores negative examples altogether and still arrives at a meaningful embedding space for a given image.
If you wish to run this live, please visit this kaggle kernel instead of the colab link above.
The following shows how a given image is passed through two transforms (RandomResizedCrop), with one crop guaranteed to be at least half as large as the original image. A collate function then breaks each image up into 16x16 patches and stacks those patches into a sequence, so that we can feed it into a transformer. The collate function is a minor detail in the overall picture, but if you wish to read about it you can do so here.
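Very roughly, the augmentation and collate step might look something like this. This is a minimal sketch, not the exact code in the notebook: the image size (256, which conveniently gives 16x16 = 256 patches), the scale bounds of the second crop, and the function names are my own assumptions.

```python
import torch
from torchvision import transforms

# Two random crops of the same image; the first is forced to keep at
# least half of the original image (scale lower bound of 0.5).
transform_a = transforms.Compose([
    transforms.RandomResizedCrop(256, scale=(0.5, 1.0)),
    transforms.ToTensor(),
])
transform_b = transforms.Compose([
    transforms.RandomResizedCrop(256),  # default scale=(0.08, 1.0)
    transforms.ToTensor(),
])

def patchify(x: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Break a (B, C, H, W) batch into a (B, N, C*patch_size*patch_size) sequence of flattened patches."""
    b, c, _, _ = x.shape
    patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # (B, C, H/ps, W/ps, ps, ps) -> (B, N, C*ps*ps)
    return patches.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * patch_size * patch_size)

def collate_fn(batch):
    # batch: a list of PIL images. Return the two augmented, patchified views.
    view_a = torch.stack([transform_a(img) for img in batch])
    view_b = torch.stack([transform_b(img) for img in batch])
    return patchify(view_a), patchify(view_b)
```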
The model is simply the 256 image patches passed through an encoder transformer with a CLS token. However, a few things that I think newcomers to the transformer/ PyTorch field ought to know are as follows (also, I made these mistakes 😅):

- Make sure trainable tensors are actually registered as parameters. Assigning a plain tensor (e.g. `self.myvar = torch.randn(...)`) is not enough; wrap it in `nn.Parameter` or use `self.register_parameter` so that the optimizer sees it.
- I have used LayerNorm everywhere possible to keep gradient sizes reasonable for an optimizer with a constant learning rate.
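For example, a CLS token and positional encoding declared as trainable parameters might look like this (a minimal sketch; the class name and shapes are my own, not the notebook's):

```python
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self, num_patches: int = 256, dim: int = 128):
        super().__init__()
        # Both of these show up in model.parameters() and get updated by the optimizer.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.register_parameter(
            "pos_embedding", nn.Parameter(torch.randn(1, num_patches, dim))
        )
        # A plain tensor attribute like the one below would NOT be tracked:
        # self.bad_pos_embedding = torch.randn(1, num_patches, dim)
```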
The walkthrough of the model operation, ignoring normalisations, is as follows:

1. Take the flattened 3x16x16 image patches and append a positional encoding to them.
2. Pass this through a linear layer and prepend the CLS token embedding (so that it sits at position 0).
3. Pass this through the transformer and take the 0th token, since this corresponds to the CLS token.
4. Normalize this vector to unit length so that you have a final image embedding.
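Putting those four steps together, the encoder might look roughly like this. This is a minimal sketch with made-up hyperparameters; the real model also sprinkles in LayerNorms as mentioned above.

```python
import torch
from torch import nn
import torch.nn.functional as F

class ImageEmbedder(nn.Module):
    def __init__(self, patch_dim: int = 3 * 16 * 16, pos_dim: int = 32,
                 dim: int = 128, num_patches: int = 256,
                 n_heads: int = 4, n_layers: int = 4):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, pos_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Projects (patch ++ positional encoding) into the model dimension.
        self.to_embedding = nn.Linear(patch_dim + pos_dim, dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=n_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (B, 256, 3*16*16) flattened image patches
        b = patches.shape[0]
        # Step 1: append the positional encoding to each patch.
        x = torch.cat([patches, self.pos_embedding.expand(b, -1, -1)], dim=-1)
        # Step 2: project and prepend the CLS token.
        x = self.to_embedding(x)
        x = torch.cat([self.cls_token.expand(b, -1, -1), x], dim=1)
        # Step 3: transformer encoder; keep the CLS (0th) token.
        x = self.transformer(x)[:, 0]
        # Step 4: normalize to unit length for the final image embedding.
        return F.normalize(x, dim=-1)
```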
We need a loss function that pushes the output vectors of the above model towards each other for the two augmented images. The way this paper does it is by treating the vector as (the log of) a histogram and trying to line it up with its augmented counterpart. Why is this interesting? Usually, authors simply take the dot product and maximise it to make the vectors line up (be close). This paper achieves the same result by asking histograms to line up instead.
The cross entropy loss takes the form \(-\sum_{j=1}^J y_j\log p(x_j)\), which is minimised when \(p(x_j)\) tends towards \(y_j\). Similarly, if we replace \(y_j\) with a probability distribution \(q_j\in[0, 1]\text{ s.t. }\sum_j q_j = 1\), the cross entropy is minimised when the two distributions line up (i.e. \(p_j = q_j\)).
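As a quick sanity check (my own illustration, not from the paper), the cross entropy \(-\sum_j q_j\log p_j\) for a fixed \(q\) is smallest when \(p\) matches \(q\):

```python
import torch

q = torch.tensor([0.7, 0.2, 0.1])

def cross_entropy(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    return -(q * p.log()).sum()

print(cross_entropy(q, torch.tensor([0.7, 0.2, 0.1])))  # ~0.80, p == q: the minimum
print(cross_entropy(q, torch.tensor([0.4, 0.4, 0.2])))  # ~0.99, distributions disagree
```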
My personal intuition as to why this might be better is as follows (feel free to correct me if I am wrong). When maximising dot products, we are asking two points on a sphere to move towards each other. There is an obvious shortest path, which you can visualise by drawing a line between the two points. However, moving away from the other point will also eventually get you there, just by going the long way around.
By minimising the following loss we are asking the two histograms to line up, which gives a simpler gradient and is therefore more likely to reach a local minimum faster: \[ -\frac{1}{N}\sum_{n=1}^N\sum_{j=1}^J q_{nj}\log p_{nj}\]
Note that in the implementation below we have two extra concepts: centering and a temperature. I'm not entirely sure about the centering, but the temperature (which is positive but less than one) sharpens the histogram, making the peaks more prominent.
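A minimal sketch of how the loss with temperature and centering might be implemented is shown below; the function name and temperature values are my own, not necessarily what the notebook uses. Roughly following the paper, centering and a small (sharpening) temperature are applied on the teacher side, and a larger temperature on the student side.

```python
import torch
import torch.nn.functional as F

def dino_like_loss(
    student_out: torch.Tensor,   # (B, D) embeddings from the student
    teacher_out: torch.Tensor,   # (B, D) embeddings from the teacher
    center: torch.Tensor,        # (D,) running center of teacher outputs
    student_temp: float = 0.1,
    teacher_temp: float = 0.04,
) -> torch.Tensor:
    # Teacher histogram: centered, then sharpened by a small temperature.
    # No gradients flow through the teacher.
    q = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
    # Student (log-)histogram.
    log_p = F.log_softmax(student_out / student_temp, dim=-1)
    # Cross entropy between the two histograms, averaged over the batch.
    return -(q * log_p).sum(dim=-1).mean()
```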
The LightningModule below goes through the training step. The main steps are:

1. Create two copies of the model with the exact same parameters. One is the teacher (whose gradients are not calculated during backprop) and the other is the student.
2. Pass both augmentations through both the student and the teacher.
3. Calculate the loss function shown above for the two augmentations, but with one embedding coming from the teacher and the other from the student.
4. Update the teacher parameters as an exponentially weighted average of the corresponding student parameters.
5. Update the (exponentially weighted) center parameter from the embeddings passed through the teacher only.

Steps 1-4 are the most important aspects IMO; the 5th step feels more like a trick needed to stabilize optimization. A rough sketch of these steps follows.
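To make steps 1-5 concrete, here is a minimal sketch of what the training step might look like. The class name, hyperparameters, and the `dino_like_loss` helper from the previous sketch are my own simplifications rather than the exact notebook code.

```python
import copy
import torch
import pytorch_lightning as pl

class DinoLike(pl.LightningModule):
    def __init__(self, model: torch.nn.Module, out_dim: int = 128,
                 ema_decay: float = 0.99, center_decay: float = 0.9):
        super().__init__()
        # Step 1: two copies of the same model; the teacher is never trained directly.
        self.student = model
        self.teacher = copy.deepcopy(model)
        for p in self.teacher.parameters():
            p.requires_grad = False
        # Running center of teacher outputs (a buffer, not a parameter).
        self.register_buffer("center", torch.zeros(out_dim))
        self.ema_decay = ema_decay
        self.center_decay = center_decay

    def training_step(self, batch, batch_idx):
        view_a, view_b = batch  # the two augmentations of the same images
        # Step 2: pass both augmentations through both networks.
        s_a, s_b = self.student(view_a), self.student(view_b)
        with torch.no_grad():
            t_a, t_b = self.teacher(view_a), self.teacher(view_b)
        # Step 3: cross the views; teacher embedding of one view vs student of the other.
        loss = 0.5 * (
            dino_like_loss(s_a, t_b, self.center)
            + dino_like_loss(s_b, t_a, self.center)
        )
        with torch.no_grad():
            # Step 4: exponentially weighted update of the teacher parameters.
            for ps, pt in zip(self.student.parameters(), self.teacher.parameters()):
                pt.mul_(self.ema_decay).add_(ps, alpha=1 - self.ema_decay)
            # Step 5: exponentially weighted update of the center from teacher outputs.
            batch_center = torch.cat([t_a, t_b]).mean(dim=0)
            self.center.mul_(self.center_decay).add_(batch_center, alpha=1 - self.center_decay)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.student.parameters(), lr=1e-4)
```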
The following gif from the official repo is an overview of the algorithm:
The non-standard (and important to note) things I've done in the LightningModule are as follows:

- Set all parameters in the teacher model to non-trainable.
- Register a buffer (not a parameter) called center to track the output of the teacher.
- At each validation_epoch_end, randomly pick an image from the validation set and find its 5 closest images. Push these results to Weights and Biases as a table of images.
- Every 50th batch, save a histogram of gradients to Weights and Biases. This ensures that there is no gradient explosion or degradation as training evolves.
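The gradient-histogram logging could look something like the following sketch (the hook placement and the 50-step interval are from the description above, but the exact code in the notebook may differ):

```python
import torch
import wandb
import pytorch_lightning as pl

class DinoLike(pl.LightningModule):
    # ... __init__ and training_step as sketched above ...

    def on_after_backward(self):
        # Every 50th step, log a histogram of the student's gradients to W&B.
        if self.global_step % 50 == 0:
            grads = torch.cat([
                p.grad.flatten()
                for p in self.student.parameters()
                if p.grad is not None
            ])
            self.logger.experiment.log(
                {"gradients": wandb.Histogram(grads.detach().cpu().numpy())}
            )
```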
Thanks to PyTorch Lightning and WandB, I can easily do half-precision training and log the results to a beautiful dashboard (linked below).
Finally, let's look at some random images together with their closest images according to the model. Note that at this point we throw away the student and simply keep the teacher, even though it is only the student that used gradient information directly. The following image shows the Weights and Biases table that I created during training using only the validation dataset. The results that follow use the entire set of images and their corresponding closest images. Considering that this is a self-supervised task, this is not "cheating".
```python
teacher = teacher.eval().to(device)

embedding = []
with torch.no_grad():
    for x in tqdm(image_orig_dl):
        out = teacher(x.to(device))
        embedding.append(out.cpu())
embedding = torch.cat(embedding, dim=0)
```
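For reference, a `plot_closest_pairs` helper could look roughly like this; it is my own sketch and the actual helper in the notebook may differ. Since the embeddings are unit length, a plain dot product is the same as cosine similarity.

```python
import matplotlib.pyplot as plt
from PIL import Image

def plot_closest_pairs(embedding, i, files, k=5):
    # Cosine similarity of image i against all embeddings (they are unit length).
    sims = embedding @ embedding[i]
    closest = sims.topk(k + 1).indices.tolist()  # the closest match is the image itself
    fig, axes = plt.subplots(1, k + 1, figsize=(3 * (k + 1), 3))
    for ax, idx in zip(axes, closest):
        ax.imshow(Image.open(files[idx]))
        ax.set_title(f"sim: {sims[idx].item():.2f}")
        ax.axis("off")
    plt.show()
```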
```python
i = 64
plot_closest_pairs(embedding, i, files)

i = 42
plot_closest_pairs(embedding, i, files)

i = 21
plot_closest_pairs(embedding, i, files)
```
Shameless Self Promotion
If you enjoyed the tutorial, buy my course (usually 90% off).