This article is from the Jizhi column

Overfitting is probably the most annoying problem in machine learning. Let’s talk about what overfitting is, how to detect it, and six ways to prevent it.

What is overfitting?

Overfitting is when a model makes its hypothesis overly strict in order to stay consistent with the training data; in other words, the model over-learns the training data. Instead of learning the overall distribution of the data, the model memorizes the expected output for each individual data point.

It’s like doing math problems by memorizing the answers to specific questions without ever learning the formula. This makes the model unable to generalize: like someone who is at ease on familiar ground but at a loss the moment they step outside it.

The tricky part about overfitting is that, at first glance, the model appears to perform well because it makes few errors on the training data. However, once the model is asked to predict on new data, it falls apart.

How do you detect overfitting?

As mentioned above, a typical feature of overfitting is that the model does not generalize. If you want to test the generalization capability of a model, a simple way is to split the data set into two parts: the training set and the test set.

  • The training set contains 80% of the available data for training the model.
  • The test set contains the remaining 20% of the available data to test the model’s accuracy on data it has never seen before.

It actually works better to split it into three parts: 60% training data set, 20% testing data set, and 20% validation data set.

By splitting the data this way, we can evaluate the model’s performance on each subset, gain a deeper understanding of the training process, and pinpoint when overfitting begins. Comparing the errors on the training and validation sets over time makes the difference visible.

Note that for this approach to work, you need to make sure that each dataset is representative of your data. A good practice is to shuffle the order of the data set before splitting it.
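
As a concrete example, here is a minimal sketch of a 60/20/20 split using scikit-learn’s train_test_split (the synthetic data set is just a stand-in for your own data):

```python
# A minimal 60/20/20 split sketch with scikit-learn (the synthetic data set
# is a stand-in for real data).
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

# Carve off 20% as the held-out test set (shuffling is on by default).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42
)

# Split the remaining 80% into 60% training / 20% validation:
# 0.25 of the remaining 80% equals 20% of the original data.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.25, random_state=42
)

print(len(X_train), len(X_val), len(X_test))  # 600 200 200
```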

Overfitting can be tricky, but there are ways to prevent it. Here are six ways to prevent overfitting, approached from a few different perspectives.

How to prevent overfitting (from the model & data perspective)

First, we can try to look at the components of the entire system to find solutions, which means changing the data we use.

1 Obtain more data

Your model can store lots and lots of information, which means the more training data you feed it, the less likely it is to overfit: as you add more data, the model becomes unable to overfit every sample and is forced to generalize to make progress. Collecting more data should be the first step in any data science task, since more data makes the model more accurate and reduces the probability of overfitting.

2 Data augmentation & noisy data

Collecting more data can be time-consuming and labor-intensive. If you don’t have the time or resources for that, try to make your existing data look more diverse. This is what data augmentation does: each time the model processes a sample, it sees a slightly different version than before, which makes it harder for the model to memorize any individual sample.

For data augmentation methods, please refer to this answer by Jizhi.
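
As a rough sketch (assuming an image task and the torchvision library, neither of which the article specifies), random flips, rotations, and color jitter make each pass over a sample look slightly different:

```python
# An image-augmentation sketch with torchvision (framework assumed; the
# article names no library). Each epoch the same image is flipped, rotated,
# and color-jittered differently, so the model cannot simply memorize it.
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),        # mirror half the time
    transforms.RandomRotation(degrees=10),         # small random rotation
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),                         # PIL image -> tensor in [0, 1]
])

# Typically attached to a dataset, e.g.:
# train_set = torchvision.datasets.CIFAR10(root="data", train=True,
#                                          transform=augment, download=True)
```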

Another good method is to inject noise into the data:

  • For inputs: this serves the same purpose as data augmentation, and it also makes the model robust to natural perturbations it may encounter.
  • For outputs: adding noise to the labels likewise makes training more diverse.

Note: in both cases, make sure the magnitude of the noise is not too large. Otherwise the inputs end up dominated by noise, or the outputs become incorrect, and either situation will interfere with training.
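
Here is a minimal sketch of input/output noise injection with NumPy (the arrays and noise level are placeholders; keep the noise small relative to the feature scale, as the note above warns):

```python
# Input/output noise injection with NumPy; X, y and noise_std are placeholders.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))    # stand-in for real training inputs
y = rng.normal(size=100)         # stand-in for real targets

noise_std = 0.05                 # keep this small relative to the feature scale
X_noisy = X + rng.normal(0.0, noise_std, size=X.shape)   # perturb the inputs
y_noisy = y + rng.normal(0.0, noise_std, size=y.shape)   # mild output jitter
```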

3 Simplify the model

Even if you now have all the data you need, if your model still overfits the training data set, it may be because the model is too powerful. So you can try to reduce the complexity of the model.

As mentioned earlier, a model can only overfit data it has the capacity to memorize. By gradually reducing the model’s complexity (for example, the number of estimators in a random forest or the number of parameters in a neural network), you can reach an equilibrium: the model is simple enough not to overfit, but complex enough to learn from the data. A convenient way to find this point is to look at the model’s error on each data set as a function of model complexity.

Another benefit of simplifying models is that they can be lighter, train faster and run faster.
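
One way to run this complexity sweep, sketched here with a random forest’s `max_depth` (my own illustrative choice), is to watch how the gap between training and validation accuracy grows as the model gets more powerful:

```python
# A sketch of sweeping model complexity (here a random forest's max_depth)
# and comparing training vs. validation accuracy to spot overfitting.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

for depth in [2, 4, 8, 16, None]:  # None lets the trees grow fully
    model = RandomForestClassifier(
        n_estimators=100, max_depth=depth, random_state=0
    ).fit(X_tr, y_tr)
    train_acc = model.score(X_tr, y_tr)
    val_acc = model.score(X_val, y_val)
    print(f"max_depth={depth}: train={train_acc:.3f}, val={val_acc:.3f}")
```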

How to prevent overfitting (from the training process perspective)

The second place where overfitting can arise is during the training phase itself; the remedies here involve adjusting the loss function or the way the model is run during training.

Early stopping

In most cases, the model first learns the correct distribution of the data, then at some point begins to overfit it. By identifying the point where this shift begins, you can stop training before overfitting sets in. As before, you can find this point by watching the training and validation errors over time.
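
Most frameworks support this directly; here is a minimal sketch using Keras’s `EarlyStopping` callback (an assumed framework choice; `model`, `X_train`, and the other names are placeholders):

```python
# A minimal early-stopping sketch with Keras; model, X_train, y_train,
# X_val, y_val are placeholders for your own network and data.
import tensorflow as tf

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",           # watch validation loss
    patience=5,                   # tolerate 5 epochs without improvement
    restore_best_weights=True,    # roll back to the best epoch's weights
)

# model.fit(X_train, y_train,
#           validation_data=(X_val, y_val),
#           epochs=200,
#           callbacks=[early_stop])
```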

How to prevent overfitting (from the regularization perspective)

Regularization is the process of constraining model learning to reduce overfitting. It can take many forms, but let’s look at some of them.

L1 and L2 regularization

One of the most powerful and well-known regularization techniques is adding a penalty term to the loss function. The so-called “penalty” places restrictions on the model’s parameters through the loss function. The most common penalties are L1 and L2:

  • The purpose of the L1 penalty term is to minimize the absolute values of the weights
  • The purpose of the L2 penalty term is to minimize the squares of the weights

For linear regression models, a model built with L1 regularization is called Lasso regression, and a model built with L2 regularization is called Ridge regression. Below is the loss function of Lasso regression (as implemented in Python’s scikit-learn); the term after the plus sign, $\alpha\|w\|_1$, is the L1 regularization term:

$$\min_w \; \frac{1}{2\,n_{\text{samples}}} \|Xw - y\|_2^2 + \alpha \|w\|_1$$

Below is the loss function of Ridge regression (again as in scikit-learn); the term after the plus sign, $\alpha\|w\|_2^2$, is the L2 regularization term:

$$\min_w \; \|Xw - y\|_2^2 + \alpha \|w\|_2^2$$

In a regression model, w represents the feature coefficients. As the equations above show, the regularization term constrains (limits) these coefficients.

Combining these equations, we can state the following:

  • L1 regularization is the sum of the absolute values of the elements of the weight vector w, usually written $\|w\|_1$. L1 regularization tends to produce a sparse weight vector (many exact zeros), and such a sparse model can be used for feature selection.
  • L2 regularization is the square root of the sum of squares of the elements of w, usually written $\|w\|_2$ (note that the L2 term in Ridge regression appears squared, $\|w\|_2^2$).

With the penalty term in place, the model is forced to compromise on its weights, because it can no longer make them arbitrarily large. This keeps the model more general, which helps us combat overfitting.

Another advantage of the L1 penalty is that it performs feature selection: it can shrink the coefficients of useless features all the way to 0, helping the model find the most relevant features in the data set. The disadvantage is that it is usually not as computationally efficient as the L2 penalty.

The effect shows up in the learned weights: L1 regularization produces a sparse weight matrix with many zeros, while L2 regularization produces weights that are simply relatively small.
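
As a quick illustration of this difference, here is a minimal sketch using scikit-learn’s Lasso and Ridge on synthetic data (the data set and alpha value are arbitrary choices for demonstration):

```python
# Contrast Lasso (L1) and Ridge (L2) coefficients: L1 drives many weights
# to exactly zero, L2 only shrinks them toward zero.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso, Ridge

X, y = make_regression(n_samples=200, n_features=20, n_informative=5,
                       noise=10.0, random_state=0)

lasso = Lasso(alpha=1.0).fit(X, y)
ridge = Ridge(alpha=1.0).fit(X, y)

print("Lasso zero coefficients:", int(np.sum(lasso.coef_ == 0)))  # typically many
print("Ridge zero coefficients:", int(np.sum(ridge.coef_ == 0)))  # typically none
```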

Besides using L1 and L2 regularization to prevent overfitting, adding noise to the model’s parameters during training can also promote generalization.

How to prevent overfitting (for deep learning models)

For overfitting in deep learning models, we can approach the problem with two related techniques: Dropout and DropConnect. Since a deep network relies on layers of neurons passing information from one to the next, intervening at this level is especially effective. The idea is to randomly deactivate neurons (Dropout) or connections (DropConnect) during training.

This makes the network redundant, because it can no longer rely on specific neurons or connections to extract specific features. After training, all neurons and connections are used again. Experiments show that this method has an effect similar to an ensemble of neural networks: it helps the model generalize and thus reduces overfitting.
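
As a minimal sketch (the article gives no code, and the architecture below is an arbitrary example), dropout in Keras is just an extra layer that randomly zeroes activations during training:

```python
# A small fully connected network with Dropout layers in Keras; the layer
# sizes and rates are arbitrary. Dropout randomly zeroes 50% of activations
# during training and is automatically disabled at inference time.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.5),   # randomly drop half the activations
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```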

Conclusion

Overfitting is a big problem when we train machine learning models, and it can be a real headache if you don’t know how to deal with it. Using the methods mentioned above, you should be able to effectively prevent model overfitting when training your machine learning models.


