Introduction 
This post describes how to use the COCO dataset for semantic segmentation. Kudos to this blog for giving me the necessary hints to create this.
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as TF
from pycocotools.coco import COCO
from torch.utils.data import DataLoader, Dataset
from torchvision import io, transforms
from tqdm import tqdm

ROOT_PATH = Path("coco")   # placeholder; point this at your COCO download
IMAGE_SIZE = (128, 128)    # inferred from the tensor shapes printed below
BATCH_SIZE = 64            # inferred from the tensor shapes printed below

train_annotations = COCO(ROOT_PATH / "annotations/instances_train2017.json")
valid_annotations = COCO(ROOT_PATH / "annotations/instances_val2017.json")

cat_ids = train_annotations.getCatIds(supNms=["person", "vehicle"])

train_img_ids = []
for cat in cat_ids:
    train_img_ids.extend(train_annotations.getImgIds(catIds=cat))
train_img_ids = list(set(train_img_ids))
print(f"Number of training images: {len(train_img_ids)}")

valid_img_ids = []
for cat in cat_ids:
    valid_img_ids.extend(valid_annotations.getImgIds(catIds=cat))
valid_img_ids = list(set(valid_img_ids))
print(f"Number of validation images: {len(valid_img_ids)}")
Number of training images: 74152
Number of validation images: 3125 
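To sanity-check what those two supercategories expanded to, the COCO API can map the ids in cat_ids back to human-readable names; a small sketch using the objects defined above:

# Print id, supercategory, and name for each selected category.
for cat in train_annotations.loadCats(cat_ids):
    print(cat["id"], cat["supercategory"], cat["name"])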
 
class ImageData(Dataset):
    def __init__(
        self,
        annotations: COCO,
        img_ids: List[int],
        cat_ids: List[int],
        root_path: Path,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.annotations = annotations
        self.img_data = annotations.loadImgs(img_ids)
        self.cat_ids = cat_ids
        self.files = [str(root_path / img["file_name"]) for img in self.img_data]
        self.transform = transform

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.LongTensor]:
        ann_ids = self.annotations.getAnnIds(
            imgIds=self.img_data[i]["id"],
            catIds=self.cat_ids,
            iscrowd=None,
        )
        anns = self.annotations.loadAnns(ann_ids)
        # Build a single label mask: each instance mask is scaled by its
        # category id, and overlaps are resolved with an element-wise max.
        mask = torch.LongTensor(
            np.max(
                np.stack([
                    self.annotations.annToMask(ann) * ann["category_id"]
                    for ann in anns
                ]),
                axis=0,
            )
        ).unsqueeze(0)

        img = io.read_image(self.files[i])
        # Replicate greyscale images across three channels.
        if img.shape[0] == 1:
            img = torch.cat([img] * 3)

        if self.transform is not None:
            return self.transform(img, mask)
        return img, mask
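The mask construction in __getitem__ deserves a note: annToMask turns each annotation into a binary mask, each mask is scaled by its category id, and overlapping instances are resolved by an element-wise maximum, so the higher category id wins ties. A toy numpy sketch of that reduction (made-up 3x3 masks, not real COCO data):

import numpy as np

# Two fake binary instance masks, for category ids 1 (person) and 3 (car).
person = np.array([[1, 1, 0],
                   [1, 0, 0],
                   [0, 0, 0]])
car = np.array([[0, 1, 1],
                [0, 0, 1],
                [0, 0, 0]])

# Scale each mask by its category id, then reduce with an element-wise max,
# exactly as in __getitem__ above.
combined = np.max(np.stack([person * 1, car * 3]), axis=0)
print(combined)
# [[1 3 3]
#  [1 0 3]
#  [0 0 0]]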
 
Image Augmentations 
When using augmentations we need to be careful to apply the same transformation to both the image and the mask. For example, when doing a random crop as below, the crop has to be deterministic across the pair. The way to do that in torch is to sample the transformation parameters once with get_params and then apply them through torchvision.transforms.functional, whose functions are deterministic.
def train_transform(
    img1: torch.LongTensor, img2: torch.LongTensor
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    # Sample the crop parameters once, then apply the same crop to both.
    params = transforms.RandomResizedCrop.get_params(
        img1, scale=(0.5, 1.0), ratio=(0.75, 1.33)
    )
    img1 = TF.resized_crop(img1, *params, size=IMAGE_SIZE)
    img2 = TF.resized_crop(img2, *params, size=IMAGE_SIZE)

    # Random horizontal flipping
    if random.random() > 0.5:
        img1 = TF.hflip(img1)
        img2 = TF.hflip(img2)

    return img1, img2
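The same recipe extends to any other random transform: sample the parameters once, then apply them to both tensors through the functional API. As a sketch (the rotation range here is an arbitrary choice of mine, not from the original post), a paired random rotation could look like this:

import random
import torch
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from typing import Tuple

def paired_random_rotation(
    img: torch.Tensor, mask: torch.Tensor, max_angle: float = 15.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sample the angle once so image and mask stay aligned.
    angle = random.uniform(-max_angle, max_angle)
    img = TF.rotate(img, angle)
    # Nearest-neighbour interpolation keeps the mask labels integral.
    mask = TF.rotate(mask, angle, interpolation=InterpolationMode.NEAREST)
    return img, mask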
train_data = ImageData(train_annotations, train_img_ids, cat_ids, ROOT_PATH / "train2017", train_transform)
valid_data = ImageData(valid_annotations, valid_img_ids, cat_ids, ROOT_PATH / "val2017", train_transform)

train_dl = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,  # assumed; the shapes printed below imply 64
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)
valid_dl = DataLoader(
    valid_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=True,
)
The following demonstrates a single example from the dataset.
img, mask = train_data[22]

plt.figure(figsize=(12, 5))
plt.subplot(121)
plt.imshow(img.permute(1, 2, 0))
plt.subplot(122)
plt.imshow(mask.squeeze())
plt.show()
Run the following commented-out section to see how long data loading takes. It takes approximately 10 minutes to run through an epoch without any modelling.
x, y = next(iter(train_dl))
print(x.shape, y.shape)

# for x, y in tqdm(train_dl):
#     continue
torch.Size([64, 3, 128, 128]) torch.Size([64, 1, 128, 128])
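To put a number on that loading cost yourself, wrap the epoch loop in a timer; a minimal sketch, assuming tqdm is installed and train_dl is built as above:

import time
from tqdm import tqdm

start = time.perf_counter()
for x, y in tqdm(train_dl):
    continue  # no model, pure data-loading cost
print(f"One epoch of data loading took {time.perf_counter() - start:.1f}s")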