from typing import List
import matplotlib.pyplot as plt
from torchvision import io, transforms
from PIL import Image, ImageDraw
from torchvision.transforms.functional import to_pil_image
%matplotlib inline
PyTorch Image Patches
Introduction
Getting the 16x16 patches required for the Vision Transformer (ViT) is not that straightforward. This tutorial demonstrates how to use the unfold function in combination with reshape to get the required shape of data.
Let’s break up our image of size 256 x 256 into 64 x 64 patches. We should end up with 4 rows and 4 columns of these patches.
IMG_SIZE = 256
PATCH_SIZE = 64

resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
img = resize(io.read_image("../images/autobot.jpg"))
The actual image looks like so:
to_pil_image(img)
The unfold function can be used to grab patches of a certain size and stride. Unfortunately, you need to apply it twice, once along each relevant dimension, to get what we are after.
patches = img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
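To see what the double unfold produces, here is a minimal shape check on a random tensor standing in for the resized image (hypothetical data, not the actual photo):

```python
import torch

IMG_SIZE, PATCH_SIZE = 256, 64
img = torch.rand(3, IMG_SIZE, IMG_SIZE)

# First unfold windows dim 1 (height), second windows dim 2 (width):
# (3, 256, 256) -> (3, 4, 256, 64) -> (3, 4, 4, 64, 64)
patches = img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
print(tuple(patches.shape))  # (3, 4, 4, 64, 64)
```

The first two unfolded dimensions index the 4 x 4 grid of patches, and the last two hold each patch's pixels, so `patches[:, i, j]` selects one 3 x 64 x 64 patch.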
fig, ax = plt.subplots(4, 4)
for i in range(4):
    for j in range(4):
        sub_img = patches[:, i, j]
        ax[i][j].imshow(to_pil_image(sub_img))
        ax[i][j].axis('off')
And finally, we can line up the patches and plot them using reshape.
patches = patches.reshape(3, -1, PATCH_SIZE, PATCH_SIZE)
patches.transpose_(0, 1)

fig, ax = plt.subplots(1, 16, figsize=(12, 12))
for i in range(16):
    ax[i].imshow(to_pil_image(patches[i]))
    ax[i].axis('off')
Putting it all together
Before sending it through to a transformer, we need to reshape our images from (batch_size, channels, img_height, img_width) to (batch_size, number_patches, pixels), where pixels in the above example would be 64 x 64 x 3 = 12288.
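That reshape can be sketched at the batch level like so, using the sizes from this tutorial on a hypothetical batch of two random images:

```python
import torch

IMG_SIZE, PATCH_SIZE = 256, 64
batch = torch.rand(2, 3, IMG_SIZE, IMG_SIZE)

n = IMG_SIZE // PATCH_SIZE  # 4 patches per side, 16 per image
patches = batch.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
# (2, 3, 4, 4, 64, 64): merge the 4x4 grid, then move patches before channels
patches = patches.reshape(2, 3, n * n, PATCH_SIZE * PATCH_SIZE)
patches = patches.transpose(1, 2).reshape(2, n * n, -1)
print(tuple(patches.shape))  # (2, 16, 12288)
```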
Therefore, an example Dataset to read in the images would look like:
from torch.utils.data import Dataset

class ImageData(Dataset):
    def __init__(self, files: List[str]):
        self.files = files
        self.resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
        # a 4 x 4 grid of patches per image
        self.num_patches = (IMG_SIZE // PATCH_SIZE) ** 2

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        img = self.resize(io.read_image(self.files[i]))
        patches = img\
            .unfold(1, PATCH_SIZE, PATCH_SIZE)\
            .unfold(2, PATCH_SIZE, PATCH_SIZE)
        patches = patches.reshape(3, -1, PATCH_SIZE, PATCH_SIZE)
        patches.transpose_(0, 1)
        return patches.reshape(self.num_patches, -1)
Shameless self promotion
If you enjoyed the tutorial, buy me a coffee, or better yet, buy my course (usually 90% off).