Training an unknown label detecting classifier using fastai

Unknown label detection: detecting out-of-domain data, i.e. teaching the classifier to recognise only the classes it was trained on and to reject everything else.

Photo by Sean Pollock on Unsplash

Motivation

In my previous blog post, I discussed how to train an image classifier that can detect out-of-domain inputs using a multi-label classification approach. In this blog post, let's discuss how to train a regular multi-class classifier but make it a bit more intelligent, i.e. give it the ability to flag data it doesn't know about. We are using the fast.ai library to create and train the model.

Explore Dataset

Let's approach the problem using the Oxford-IIIT Pet (PETS) dataset. It has 37 categories of different pet breeds, with roughly 200 images per class. More details about this dataset can be found here.

Let’s get started by installing and importing the required libraries.
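The setup is roughly the following (the notebook-style install line is illustrative):

```python
# In a notebook cell, install fastai first: !pip install -Uq fastai

# fastai's vision module re-exports everything we need (DataBlock, learners, etc.)
from fastai.vision.all import *
```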

Now let’s download the PETs dataset,
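fastai ships download URLs for common datasets, so fetching PETS is a one-liner; `untar_data` downloads and extracts the archive once and caches it locally:

```python
from fastai.vision.all import *

# Download and extract the Oxford-IIIT Pet dataset; returns the local path
path = untar_data(URLs.PETS)
path.ls()  # shows the 'images' and 'annotations' folders
```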

Output

Now that we have downloaded the dataset, let's build a DataBlock object and the dataloaders.

Getting the data ready

DataBlock is a high-level API used to easily build and load the data, i.e. to create dataloaders that can be served directly to the model. This makes data loading a lot easier and more customizable. Now, let's create a datablock object for our dataset.

Now since the datablock object is ready, let’s see in some detail what it is doing,

When the size and min_scale arguments are given, aug_transforms randomly crops each image to the given size on every epoch, retaining at least the fraction of the original image specified by the min_scale argument.

This datablock object accepts the path to the directory containing images/data as a parameter while loading dataloaders. So now let’s create dataloaders and display some image samples,
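Continuing from the datablock above, building the dataloaders and previewing a batch might look like this:

```python
# Build dataloaders from the images directory and preview a few samples
dls = pets.dataloaders(path/'images')
dls.show_batch(max_n=9)
```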

Output

Let’s also look at some of the targets,
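One way to inspect the targets is to grab a single batch from the training dataloader:

```python
# x is a batch of images, y the dependent variable: a one-hot tensor
# with one column per class (37 here)
x, y = dls.train.one_batch()
y.shape, y[0]
```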

Output

We can see that the target is a one-hot-encoded array with a total of 37 classes. You can check all the labels with dls.vocab.

List of all the 37 classes

Now let’s use a pre-trained model and fine-tune it with our data.

Model Training

Let's create a learner object using the pre-trained resnet50 model. Since we are using the multi-label classification method to detect unknown labels, we need to change the loss function as well: we use fastai's BCEWithLogitsLossFlat(), a binary cross-entropy loss with a built-in sigmoid, with its default threshold. As the metric we use accuracy_multi with a higher threshold, so that only labels the model is highly confident about count as predictions.

To reduce the model size and training time, we use half-precision training, converting the model's weights to 16-bit floats. to_fp16() takes care of this.

Now let’s look at the code,
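A sketch of the learner setup (vision_learner is the current name for the pre-trained-model constructor; older fastai versions call it cnn_learner):

```python
from fastai.vision.all import *

# Pre-trained resnet50 with a multi-label loss and a strict metric threshold,
# converted to half precision with to_fp16()
learn = vision_learner(
    dls, resnet50,
    loss_func=BCEWithLogitsLossFlat(),             # sigmoid + binary cross-entropy
    metrics=partial(accuracy_multi, thresh=0.95),  # count only confident labels
).to_fp16()

learn.lr_find()  # plot loss vs. learning rate to pick a good base_lr
```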

Output- Learning rate finder

Now that we have found a good learning rate, let's fine-tune the model for 3 epochs.
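fine_tune trains the new head for one epoch, then unfreezes and trains the whole model; the base_lr value below is illustrative, taken from the learning-rate finder plot:

```python
# One frozen epoch for the head, then 3 epochs on the full network
learn.fine_tune(3, base_lr=1e-3)
```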

Output-Model training

We can see that within 3 epochs of training, our model reached about 99% accuracy. This is a pretty impressive result, thanks to the pre-trained resnet50 model. Let's also look at the loss plot,
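The learner's recorder keeps the per-batch losses, so plotting them is one call:

```python
# Plot training and validation loss recorded during fine-tuning
learn.recorder.plot_loss()
```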

Output-Loss Plot

Okay, we can see that the training and validation losses have decreased gradually and the model is not overfitting.

Now let’s proceed to model inference, where we can test the unknown label detection in action…

Inference

Before inference, let's update the threshold of our BCEWithLogitsLossFlat() loss function to 0.95. Since the loss function applies a sigmoid activation to each class independently, raising the threshold means a label is predicted only when the model is highly confident about it.
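Swapping in a stricter threshold only changes how predictions are decoded, not the trained weights:

```python
# Predict a class only when its sigmoid output is at least 0.95
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.95)
```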

Now let’s check some results,
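show_results displays validation images with ground-truth and predicted labels side by side:

```python
# Compare predictions against targets on a few validation images
learn.show_results(max_n=6)
```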

Output: Inference Results

We can see from the above results that our model is doing well on the validation set.

Now let’s do the model inference manually i.e. download some images from the internet and check if the results are matching...

To test the positive case, I downloaded an image of a Bombay cat from Google Images to the notebook location and created a PILImage object,
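Loading the downloaded file with fastai's PILImage (the file name here is an assumption):

```python
# 'bombay.jpg' is the test image saved next to the notebook
img = PILImage.create('bombay.jpg')
img.to_thumb(224)  # display a small version
```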

Output: Bombay Cat Image.

Now let's see what our model predicts,
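For a multi-label learner, predict returns the decoded label list, a boolean one-hot tensor (thresholded at 0.95), and the per-class probabilities:

```python
preds, one_hot, probs = learn.predict(img)
# preds should be ['Bombay']; probs.max() gives its probability
preds, probs.max()
```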

Output for prediction

Yay! Our model correctly predicted the class 'Bombay', and the probability for this class is 0.99, which is pretty high.

Now let's try to get predictions for some out-of-domain data. I downloaded an image of the Eiffel Tower from Google Images and used it for prediction.

Output: Eiffel Tower Image

Now if we try to find predictions for this image, our model should return negative for all the classes. Let’s find the prediction and check…
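With the 0.95 threshold, no class should clear the bar for an out-of-domain image, so the decoded label list should come back empty ('eiffel.jpg' is an assumed file name):

```python
preds, one_hot, probs = learn.predict(PILImage.create('eiffel.jpg'))
preds  # expect an empty list: no pet breed is confidently detected
```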

Output: Prediction for tower image

Great! We can see that our model understood that this image doesn't belong to any of the 37 cat or dog breeds it was trained on, and returned False for all the available classes.

Conclusion

So from the above results, we can see that the multi-label classification technique is an effective way to train a multi-class image classifier that can also flag inputs it was never trained on.

Thank you

Sources:

  1. Deep Learning for Coders with fastai and PyTorch, Book by Howard & Gugger.
  2. Walk with fastai

You can connect with me on LinkedIn here. GitHub link for the notebook can be accessed here.
