In this Kaggle competition, we are tasked with predicting the probability that a given comment is toxic. At its core, this is a binary text classification problem. The provided dataset is multilingual, which makes it a bit more challenging than other text classification problems in NLP.
Our dataset looks like so:
The labels relevant to the competition are in the toxic column, where 0 indicates a benign, non-toxic comment and 1 indicates a toxic comment. As you would expect, the dataset is highly imbalanced: non-toxic comments vastly outnumber toxic ones. An amazing EDA on the dataset is available here.
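In this report, I am going to compare three models: DistilBERT with a fully-connected classification head, the same model trained with class weights, and DistilBERT with a CNN-based classification head trained with class weights.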
Hopefully, this will make it easier for you to get started with the competition.
The DistilBERT model comes from the good folks at Hugging Face via their mighty transformers library. It makes it much easier to plug state-of-the-art NLP models into our applications. The library also provides the utility functions needed to prepare text data in the format the corresponding model expects. We will mainly be using the following two components from transformers:
The comments are non-English, so I cannot analyze them without Google Translate. In the next section, when logging the predictions, we will also log their English translations (with the help of the googletrans library). The next iteration of the network will also use class weights during training.
The model was trained for only two epochs and already yields 84.85% accuracy on the validation dataset. Because the dataset suffers from class imbalance, accuracy alone can be misleading, so we should also consider precision and recall for the positive class (toxic comments). For now, though, we can skip them.
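To see why accuracy alone can mislead on an imbalanced dataset, here is a minimal pure-Python sketch that computes accuracy, precision, and recall for the positive (toxic) class. The labels are made-up toy values, not competition data, and `binary_metrics` is a hypothetical helper; in practice you would likely reach for `sklearn.metrics` instead.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Return (accuracy, precision, recall) for the given positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    correct = sum(1 for t, p in pairs if t == p)
    accuracy = correct / len(pairs)
    # Treat precision/recall as 0.0 when their denominator is empty.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall


# Hypothetical toy labels: 9 non-toxic comments and 1 toxic comment.
y_true = [0] * 9 + [1]
# A degenerate "model" that always predicts non-toxic.
y_pred = [0] * 10

acc, prec, rec = binary_metrics(y_true, y_pred)
print(f"accuracy={acc:.2f} precision={prec:.2f} recall={rec:.2f}")
# prints: accuracy=0.90 precision=0.00 recall=0.00
```

A model that never flags a single toxic comment still scores 90% accuracy on these toy labels, which is why recall on the toxic class is worth tracking.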
Training this model takes 1090 seconds on a TPU v3-8 (thanks to Kaggle for making them available). Since I log some demo predictions during training, this training time should not be used for any benchmarks.
To see how the model performs as it trains, I implemented a simple callback that would:
A schematic diagram of the model is as follows -
The input token sequences and their corresponding attention masks are each 500-dimensional vectors. They pass through the DistilBERT model, which uses its pre-trained weights, and the resulting representation then propagates through a fully-connected network. This model was not trained using class weights. Let's see how it performs.
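The fixed-length inputs described above can be sketched in plain Python. This is a hedged illustration, not the code used for this report: it assumes the comment has already been converted to token IDs (the IDs below are hypothetical), that the padding token ID is 0, and that sequences are truncated or right-padded to 500 tokens. In practice, the Hugging Face tokenizer produces these arrays for you.

```python
MAX_LEN = 500  # sequence length used by the model in this report
PAD_ID = 0     # assumption: ID of the padding token


def pad_and_mask(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Truncate/right-pad token IDs to max_len and build the attention mask."""
    ids = list(token_ids[:max_len])   # keep at most the first max_len tokens
    mask = [1] * len(ids)             # 1 marks a real token
    n_pad = max_len - len(ids)
    ids += [pad_id] * n_pad
    mask += [0] * n_pad               # 0 tells the model to ignore padding
    return ids, mask


# Hypothetical token IDs for a short comment.
ids, mask = pad_and_mask([101, 7592, 2088, 102])
print(len(ids), len(mask), sum(mask))  # prints: 500 500 4
```

Both vectors always have the same length, so every comment reaches DistilBERT as a 500-d ID vector plus a 500-d mask, regardless of how long it originally was.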
For the next model, instead of using a fully-connected network, we will use a CNN-based network.
Class weighting did not change much, likely because of the multilingual nature of the dataset. In both of the above cases, the models largely overfitted. Below you can see the enhanced version of the prediction logger.
Scikit-learn provides a utility function to compute class weights, and here's how I used it -
# Account for the class imbalance
import numpy as np
from sklearn.utils import class_weight
# y_train holds the 0/1 labels from the toxic column
class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
If you print out class_weights, you get array([0.55288749, 5.22701553]). This means that the model treats a toxic comment as roughly 9.45 times (5.22701553 / 0.55288749) as important as a non-toxic comment. This way, the under-represented class (in this case, the toxic comments) and the over-represented class (the non-toxic comments) contribute equally to the loss in aggregate, instead of the majority class dominating training.
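Scikit-learn's 'balanced' heuristic sets each class weight to n_samples / (n_classes * n_samples_in_class). The pure-Python sketch below reproduces that formula on hypothetical toy labels (the helper name `balanced_class_weights` is illustrative, not a library function). Because the weights are inversely proportional to the class counts, the ~9.45x ratio above also means non-toxic comments outnumber toxic ones by roughly 9.45 to 1 in the training split; equivalently, about 9.6% of the training comments are toxic (1 / (2 × 5.22701553) ≈ 0.0957).

```python
from collections import Counter


def balanced_class_weights(labels):
    """Mimic the 'balanced' formula: n_samples / (n_classes * count_per_class)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


# Hypothetical toy labels: 90 non-toxic (0) and 10 toxic (1) comments.
labels = [0] * 90 + [1] * 10
weights = balanced_class_weights(labels)
print(weights)                  # roughly {0: 0.556, 1: 5.0}
print(weights[1] / weights[0])  # about 9.0, i.e. the count ratio 90 / 10

# Each class contributes the same total weight (n_samples / n_classes).
for c, w in weights.items():
    print(c, round(w * labels.count(c), 6))  # prints "0 50.0" then "1 50.0"
```

If you train with Keras, Model.fit accepts such a mapping through its class_weight argument, so the array returned by scikit-learn can be passed as dict(enumerate(class_weights)).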
Let's see if it changes anything.
We will use a CNN-based architecture as the classification head, which looks like so -
Here we use 1D convolutions to exploit local patterns in the comments. This model also overfits less, likely because it is better suited to capturing local patterns in the comments and, in turn, to finding the discriminative features that determine whether a comment is toxic.
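To make the locality argument concrete, here is a minimal pure-Python sketch of what a single 1D convolution filter followed by ReLU and global max pooling computes over a sequence of token vectors. The report does not spell out the exact head (kernel sizes, filter counts), so the kernel, the toy embeddings, and all function names below are hypothetical; in Keras, this roughly corresponds to a Conv1D layer followed by GlobalMaxPooling1D.

```python
def conv1d_valid(seq, kernel, bias=0.0):
    """Slide one filter over a sequence of d-dim vectors (no padding, stride 1).

    seq:    list of T vectors, each a list of d floats (e.g. token embeddings)
    kernel: list of k vectors, each a list of d floats (the filter weights)
    Returns T - k + 1 scores, one per window of k consecutive tokens.
    """
    k = len(kernel)
    scores = []
    for start in range(len(seq) - k + 1):
        window = seq[start:start + k]
        # Dot product between the filter and this window of k tokens.
        s = bias + sum(
            w * x
            for w_vec, x_vec in zip(kernel, window)
            for w, x in zip(w_vec, x_vec)
        )
        scores.append(s)
    return scores


def relu(xs):
    return [max(0.0, x) for x in xs]


def global_max_pool(xs):
    """Keep only the strongest response, wherever it occurs in the sequence."""
    return max(xs)


# Toy 2-d "embeddings" for a 5-token comment (hypothetical values).
seq = [[0.1, 0.0], [0.0, 0.2], [1.0, 1.0], [1.0, 0.9], [0.0, 0.1]]
# A width-2 filter that fires when two adjacent tokens both have large values.
kernel = [[1.0, 1.0], [1.0, 1.0]]

scores = relu(conv1d_valid(seq, kernel, bias=-1.0))
print(scores)                   # one score per pair of adjacent tokens
print(global_max_pool(scores))  # about 2.9, from tokens 2-3 (0-indexed)
```

The filter only ever looks at k neighbouring tokens at a time, and max pooling makes the detected pattern count regardless of its position in the comment.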
All these results suggest that the model with the CNN-based classification head trained with class weights (the orange one) performs better than the other two. As mentioned above, the CNN's ability to pick up the local patterns in the comments (which contribute to their toxicity or non-toxicity) makes it a good candidate here. Here are some additional hacks you might want to incorporate to improve performance further -