Omitting corrupted images inside a PyTorch DataLoader

Sometimes our data is not as clean as it should be. Maybe some images are corrupted, or the entries in the database are not what we expected. In that case, iterating through the data with the PyTorch DataLoader can be a hassle. Recently I learned that you can clean your data after it has been loaded, but before you iterate over it in your training loop. In this post I am going to give an example with images, but the same idea can be applied to any kind of data.

First, we need to create our __getitem__ function:

import cv2


def read_image(self, path):
    """
    Try to read the image. If it is not possible, return None.
    """
    try:
        # cv2.imread already returns None if the file cannot be decoded,
        # and the try/except also covers any unexpected error while reading.
        image = cv2.imread(path)
    except Exception:
        image = None

    return image

def __getitem__(self, idx):
    image = self.read_image(self.paths[idx])
    label = self.labels[idx]
    return image, label
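
For context, both methods live inside a regular torch.utils.data.Dataset subclass. A minimal sketch could look like the following; the ImageDataset name and its constructor arguments are assumptions for illustration:

from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, paths, labels):
        # paths: list of image file paths, labels: one label per path
        self.paths = paths
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    # read_image and __getitem__ are the methods defined above.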

We are reading the image and handling any error with a try/except, so the corrupted images end up as None in our dataset. What we need next is a collate_fn function that omits these Nones and cleans up the batch.

import torch.utils.data


def collate_fn(batch):
    # Drop every sample whose image is None before collating the rest.
    batch = list(filter(lambda x: x[0] is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)
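
If you want to see the filter in isolation, here is a quick toy check (the fake image and the labels are made up for illustration):

import numpy as np

fake_image = np.zeros((32, 32, 3), dtype=np.uint8)
batch = [(fake_image, 0), (None, 1), (fake_image, 2)]

images, labels = collate_fn(batch)
# The sample with a None image is dropped, so only two samples remain.
print(images.shape, labels.shape)  # torch.Size([2, 32, 32, 3]) torch.Size([2])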

The collate function simply removes from the batch all the images that we returned as None. Then, we just need to pass it as an argument to our DataLoader.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset, collate_fn=collate_fn, **kwargs)
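
One thing to keep in mind is that batches containing corrupted images will simply come out with fewer samples than batch_size. Assuming all readable images share the same size (otherwise you would also need a resize transform), you can iterate the loader as usual:

for images, labels in train_dataloader:
    # Only the samples whose files could actually be read end up here,
    # so a batch may occasionally be smaller than batch_size.
    print(images.shape, labels.shape)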

I like PyTorch for its flexibility, and for how easy it is to clean up the mess you are making. However, if the data is not changing over time, it is better to find those corrupted images once and delete them from your dataset. Happy coding!

Written on January 31, 2022