- Improving imbalanced datasets in machine learning with synthetic data
- Written by Alexander Watson
- The Nuggets Translation Project
- Permanent link to this article: github.com/xitu/gold-m…
- Translator: PingHGao
- Proofread by: Sun Stu, Lsvih
Using synthetic data to improve extremely imbalanced datasets in machine learning
We will use synthetic data and some concepts from SMOTE to improve model accuracy on fraud, cybersecurity, or any classification task with an extremely small minority class.
Dealing with heavily imbalanced datasets is a daunting challenge in machine learning. It shows up in areas such as payment fraud, cancer or disease diagnosis, and even cybersecurity. What these have in common is that only a tiny fraction of the records (for example, transactions) belong to the class we actually care about, such as fraud. In this article, we will greatly improve classifier accuracy on the Kaggle fraud dataset by training a model that generates additional fraud records. Uniquely, the model will incorporate features both from fraudulent records and from non-fraudulent records that are similar enough to them to be hard to tell apart.
Our unbalanced data set
For this article, we have chosen the widely used “Credit Card Fraud Detection” dataset from Kaggle. It contains labeled transactions made by European cardholders in September 2013. To protect user privacy, the sensitive fields were transformed with dimensionality reduction into 27 floating-point columns (V1-27), plus a Time column (the number of seconds elapsed between this transaction and the first one in the dataset). We will use the first 10,000 records of the dataset. Open the notebook below in Google Colaboratory to follow along and generate the plots.
Classification and visualization of fraud data
The pitfalls of evaluation metrics
Let’s look at the performance a state-of-the-art ML classifier can achieve at detecting fraudulent records. First, we split the dataset into a training set and a test set.
Wow, 99.75% accuracy. Looks great, right? Perhaps not. Overall accuracy reflects the model’s performance across the entire set, not how well we detect fraudulent records. To see how well we’re actually doing, print a confusion matrix and a classification report.
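The notebook's exact split code isn't reproduced here. As a hedged, dependency-free sketch, the helper below (the name `stratified_split` is our own, not from the notebook) shows one reasonable way to split an extremely imbalanced dataset: split each class separately, so that the rare fraud class keeps the same share in both halves. In scikit-learn, `train_test_split(X, y, stratify=y)` does the same job.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_frac=0.2, seed=42):
    """Return (train_idx, test_idx) so each class keeps roughly the same
    share in both splits. With well under 1% fraud, a plain random split
    can leave the test set with almost no fraud at all."""
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)

    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)                      # shuffle within the class
        n_test = round(len(idx) * test_frac)  # same fraction for every class
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Toy labels: 990 legitimate (0) and 10 fraudulent (1) records.
labels = [0] * 990 + [1] * 10
train_idx, test_idx = stratified_split(labels, test_frac=0.2)
print(len(train_idx), len(test_idx), sum(labels[i] for i in test_idx))
```

Stratifying guarantees the test set holds a predictable number of fraud cases. Without stratification, a metric like fraud recall can swing wildly from one random seed to the next.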
As you can see from the above, we misclassified 43% of fraud cases in our test set, despite an overall accuracy rate of 99.75%!
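To make the gap between accuracy and fraud detection concrete, here is a small plain-Python sketch that computes a confusion matrix and per-class numbers by hand. The labels are hypothetical: 1,000 test records, 7 of which are fraud. The article later mentions 7 fraud cases in its test set, 3 of them missed. These are not the notebook's actual predictions.

```python
def confusion_matrix(y_true, y_pred):
    """2x2 matrix [[TN, FP], [FN, TP]] with fraud (1) as the positive class."""
    tn = fp = fn = tp = 0
    for t, p in zip(y_true, y_pred):
        if t == 1 and p == 1:
            tp += 1
        elif t == 1:
            fn += 1  # fraud the model missed
        elif p == 1:
            fp += 1  # legitimate record flagged as fraud
        else:
            tn += 1
    return [[tn, fp], [fn, tp]]


def summarize(y_true, y_pred):
    (tn, fp), (fn, tp) = confusion_matrix(y_true, y_pred)
    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0  # share of fraud we caught
    return accuracy, precision, recall


# Hypothetical test set: 993 legitimate records, then 7 fraud records.
# The classifier catches 4 of the 7 fraud cases and raises no false alarms.
y_true = [0] * 993 + [1] * 7
y_pred = [0] * 993 + [1] * 4 + [0] * 3
acc, prec, rec = summarize(y_true, y_pred)
print(f"accuracy={acc:.2%} precision={prec:.2%} recall={rec:.2%}")
```

Missing 3 of the 7 fraud cases barely dents accuracy (997 of 1,000 records are correct), yet recall on the fraud class is only 4/7, about 57%. That is the same 43% miss rate described above.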
Augmenting the fraud examples with synthetic data
In this section, we will focus on how gretel-synthetics can be used to generate additional fraud records. Our goal is to use those extra samples to improve our classifier’s ability to generalize and to better detect the fraud records in our test set.
Synthetic minority class oversampling technique
One popular technique in the data science world for achieving this is SMOTE (Synthetic Minority Oversampling Technique), proposed by Nitesh Chawla et al. in their 2002 paper. SMOTE works by selecting samples from the minority class, finding their nearest neighbors within the minority class, and interpolating new points between them. However, SMOTE cannot incorporate data from records outside the minority class. In our case, those records may contain useful information, including fraud-like or mislabeled records.
Borrowing from SMOTE with gretel-synthetics
Our training set contains only 31 examples of fraud, which makes generalization particularly challenging. gretel-synthetics uses deep learning to learn from and generate new samples, and such techniques traditionally need large amounts of data to converge. Open the notebook below in Google Colab to generate your own synthetic fraud dataset for free.
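Here is a minimal plain-Python sketch of the interpolation step described above. It is a simplified variant and not Chawla et al.'s reference implementation: a random minority sample is picked as the base each time, and neighbors are found by brute force. For real work, the imbalanced-learn package provides `imblearn.over_sampling.SMOTE`.

```python
import random


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def smote(minority, n_new, k=5, seed=0):
    """Generate n_new synthetic minority-class points, SMOTE-style.

    For a randomly chosen minority sample, pick one of its k nearest
    minority-class neighbors and place a new point at a random position
    on the line segment between the two.
    """
    if len(minority) < 2:
        raise ValueError("need at least two minority samples")
    rng = random.Random(seed)
    k = min(k, len(minority) - 1)
    synthetic = []
    for _ in range(n_new):
        i = rng.randrange(len(minority))
        base = minority[i]
        # k nearest neighbors among the *other* minority samples
        others = [m for j, m in enumerate(minority) if j != i]
        neighbors = sorted(others, key=lambda m: squared_distance(base, m))[:k]
        nb = rng.choice(neighbors)
        gap = rng.random()  # uniform in [0, 1)
        synthetic.append([b + gap * (n - b) for b, n in zip(base, nb)])
    return synthetic


# Toy "fraud" records at the corners of the unit square.
fraud = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
new_points = smote(fraud, n_new=3, k=2)
print(new_points)
```

Every new point lies on a segment between two existing fraud records. So SMOTE never produces a record outside the region spanned by the minority class, and it never borrows information from the majority class.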
gretel-synthetics-generate-fraud-data colab.research.google.com
We borrow SMOTE’s approach of finding the records nearest to the fraud set, and we also bring in a few highly similar records from the majority class. This lets us expand our training set and teach the model about fraud-like records (let’s call them insidious records). The approach requires no changes to gretel-synthetics. We simply select the training data smartly: fraudulent records plus their nearest non-fraudulent (insidious) neighbors. Let’s get started!
#! pip install s3fs smart_open pandas scikit-learn
import pandas as pd
from smart_open import open
from sklearn.neighbors import NearestNeighbors
# set parameters
NEAREST_NEIGHBOR_COUNT = 5
TRAINING_SET = 's3://gretel-public-website/datasets/creditcard_train.csv'
# Separate positive samples (fraudulent records) from negative samples (non-fraudulent records)
df = pd.read_csv(TRAINING_SET, nrows=999999).round(6)
positive = df[df['Class'] == 1]
negative = df[df['Class'] == 0]
# Fit a nearest-neighbor model on the non-fraudulent records
neighbors = NearestNeighbors(n_neighbors=NEAREST_NEIGHBOR_COUNT, algorithm='ball_tree')
neighbors.fit(negative)
# For each fraud record, find its NEAREST_NEIGHBOR_COUNT closest non-fraudulent records
nn = neighbors.kneighbors(positive, NEAREST_NEIGHBOR_COUNT, return_distance=False)
nn_idx = list(set([item for sublist in nn for item in sublist]))
nearest_neighbors = negative.iloc[nn_idx, :]
nearest_neighbors
# Oversample the positive samples, add the similar (insidious, non-fraudulent) samples,
# and randomly shuffle the combined dataset
oversample = pd.concat([positive] * NEAREST_NEIGHBOR_COUNT)
training_set = pd.concat([oversample, nearest_neighbors]).sample(frac=1)
To build the synthetic data model, we will use Gretel’s new DataFrame training mode and override a few default parameters, as shown below, to optimize results.
- epochs: 7. Set the number of epochs as low as possible to balance generating valid records against overfitting our limited training set.
- dp: False. There is no need to run with differential privacy, which would cost accuracy.
- gen_lines: 1000. We will generate 1,000 records to greatly expand our existing 31 positive samples. Not every generated record will be positive, because we also mixed some negative samples into the training set, but we should still get at least a few hundred new positive samples.
- batch_size=32. Put all 30 columns into a single neural network model to preserve every field-to-field correlation, at the cost of more records failing validation.
- Train the model, generate lines of data, and keep only the fraud records produced by the synthetic data model.
#! pip install gretel-synthetics --upgrade
from gretel_synthetics.batch import DataFrameBatch
from pathlib import Path
config_template = {
    "max_lines": 0,
    "max_line_len": 2048,
    "epochs": 7,
    "vocab_size": 20000,
    "gen_lines": 1000,
    "dp": False,
    "field_delimiter": ",",
    "overwrite": True,
    "checkpoint_dir": str(Path.cwd() / "checkpoints"),
}
# Train the data synthesis model
batcher = DataFrameBatch(df=training_set, batch_size=32, config=config_template)
batcher.create_training_data()
batcher.train_all_batches()
# Generate synthetic data
status = batcher.generate_all_batch_lines(max_invalid=5000)
df_synthetic = batcher.batches_to_df()
# Keep only fraud records generated by our model
df_synthetic = df_synthetic[df_synthetic['Class'] == 1]
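Our understanding of `DataFrameBatch` is that `batch_size` sets the maximum number of columns per batch, with one model trained per batch. The hypothetical helper below mimics that grouping (it is not the library's code) to show why `batch_size=32` keeps all 30 columns in a single model.

```python
def column_batches(columns, batch_size):
    """Split column names into consecutive groups of at most batch_size.

    If each group gets its own synthetic data model, columns in different
    groups cannot share learned correlations.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [columns[i:i + batch_size] for i in range(0, len(columns), batch_size)]


# Hypothetical column names standing in for the article's 30 columns.
columns = [f"col{i}" for i in range(30)]
print(len(column_batches(columns, 32)))  # one batch: all columns modeled together
print(len(column_batches(columns, 15)))  # two batches of 15 columns each
```

Smaller batches make each model easier to train, but any correlation between columns that land in different batches is lost. That is the trade-off the `batch_size=32` choice resolves in favor of correlations.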
Verify our synthetic data set
Now, let’s look at our synthetic data and see whether we can visually confirm that the synthetic records resemble the fraudulent records the model was trained on. Our dataset has 30 dimensions, so we will display it in 2D and 3D using a dimensionality reduction technique called principal component analysis (PCA).
Below, we can see the training, synthetic, and test datasets compressed into two dimensions. Visually, the 883 new synthetic fraud records look like they could be quite helpful to the classifier, on top of the 31 original training examples. We also plotted the 7 positive examples from the test set (the default model misclassified 3 of these 7). We hope the augmented synthetic data will help improve their detection.
From the chart, it looks like our synthetic fraud examples might actually work! Note that the synthetic examples lying close to the negative training samples appear to be false positives. If you see many of these, try reducing NEAREST_NEIGHBOR_COUNT from 5 to 3 for better results. Next, let’s use PCA to reduce the data to three dimensions and take another look.
Looking at these projections, it seems possible to use synthetic data to augment our sparse set of fraud records and perhaps greatly improve model performance. Let’s try it!
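This dependency-free PCA sketch shows what the 2D and 3D projections do. It uses power iteration with deflation to find the top components. It is illustrative only: a library implementation such as scikit-learn's `sklearn.decomposition.PCA` is the practical choice, being far faster and numerically sturdier.

```python
import random


def pca(rows, n_components=2, iters=200, seed=0):
    """Project rows onto their top principal components.

    Centers each column, builds the sample covariance matrix, and finds
    the leading eigenvectors by power iteration with deflation.
    """
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    centered = [[r[j] - means[j] for j in range(d)] for r in rows]
    cov = [[sum(c[a] * c[b] for c in centered) / (n - 1) for b in range(d)]
           for a in range(d)]

    rng = random.Random(seed)
    components = []
    for _ in range(n_components):
        # Random unit start vector, (almost surely) not orthogonal to the
        # eigenvector we are looking for.
        v = [rng.uniform(-1.0, 1.0) for _ in range(d)]
        s = sum(x * x for x in v) ** 0.5
        v = [x / s for x in v]
        for _ in range(iters):
            w = [sum(cov[a][b] * v[b] for b in range(d)) for a in range(d)]
            norm = sum(x * x for x in w) ** 0.5
            if norm == 0.0:
                break  # no variance left in any remaining direction
            v = [x / norm for x in w]
        # Rayleigh quotient: variance captured along this unit vector.
        eigval = sum(v[a] * cov[a][b] * v[b] for a in range(d) for b in range(d))
        components.append(v)
        # Deflation: remove this component so the next pass finds the next one.
        cov = [[cov[a][b] - eigval * v[a] * v[b] for b in range(d)] for a in range(d)]

    # Coordinates of each centered row along each component.
    return [[sum(c[j] * comp[j] for j in range(d)) for comp in components]
            for c in centered]


# Toy 2-D points lying exactly on the line y = 2x, projected to 1-D.
points = [[t, 2.0 * t] for t in range(-3, 4)]
projected = pca(points, n_components=1)
print(projected)
```

The toy points lie on a single line, so one component captures all of their variance, and each projected coordinate is ±√5·t. Calling the same function with `n_components=2` or `3` on the 30-column records yields coordinates of the kind used for the 2D and 3D plots.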
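Why does lowering NEAREST_NEIGHBOR_COUNT help? The brute-force sketch below mirrors the `kneighbors` plus set-deduplication step from the first code listing. The function name `insidious_neighbors` and the toy points are hypothetical. A smaller k keeps only the legitimate records sitting closest to the fraud records and drops the more distant ones.

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def insidious_neighbors(fraud, legit, k):
    """Sorted, deduplicated indices into `legit` of the k nearest legitimate
    records to each fraud record (brute force, no library needed)."""
    idx = set()
    for f in fraud:
        ranked = sorted(range(len(legit)), key=lambda i: squared_distance(f, legit[i]))
        idx.update(ranked[:k])
    return sorted(idx)


fraud = [[0.0, 0.0], [0.2, 0.1]]
legit = [[0.1, 0.0], [0.0, 0.3], [0.3, 0.3], [2.0, 2.0], [5.0, 5.0], [9.0, 9.0]]
for k in (5, 3):
    print(k, insidious_neighbors(fraud, legit, k))
```

With k=5, the toy selection pulls in two far-away legitimate records (indices 3 and 4) that look nothing like fraud. With k=3, only the three records hugging the fraud points remain, which is the effect the NEAREST_NEIGHBOR_COUNT advice above is after.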
Use synthetic data to enhance our training data set
Now let’s reload the training and test datasets, but this time augment our existing training data with the newly generated synthetic records.
Train XGBoost on the augmented dataset, run the model against the test set, and examine the confusion matrix.
As we have seen, training machine learning models to accurately detect extreme minority classes is a formidable challenge. However, synthetic data offers a way to improve accuracy and potentially a model’s ability to generalize to new datasets. It can also uniquely incorporate features and correlations from the entire dataset into the synthetic fraud examples.
As a next step, try running the notebook above on your own dataset. Want to learn more about synthetic data? Check out our other data science articles on gretel-synthetics.
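The reload-and-augment code isn't shown here. The sketch below uses plain Python, a hypothetical `augment_training_set` helper, and rows stored as dicts. It captures two checks worth making: the synthetic rows must match the training schema and all be labeled fraud, and only the training split gets augmented. With pandas, the concatenation itself would be `pd.concat([train_df, df_synthetic])`.

```python
def augment_training_set(train_rows, synthetic_rows, label_key="Class"):
    """Return the training rows plus synthetic fraud rows.

    Only the training split is augmented; the test split should stay
    entirely real so evaluation still measures performance on real data.
    """
    expected = set(train_rows[0])
    for row in synthetic_rows:
        if set(row) != expected:
            raise ValueError(f"column mismatch: {sorted(set(row) ^ expected)}")
        if row[label_key] != 1:
            raise ValueError("every synthetic row should be labeled as fraud")
    return train_rows + synthetic_rows


# Hypothetical rows with a single feature column.
train = [{"V1": 0.1, "Class": 0}, {"V1": -1.2, "Class": 1}]
synthetic = [{"V1": -1.1, "Class": 1}, {"V1": -1.3, "Class": 1}]
augmented = augment_training_set(train, synthetic)
print(len(augmented), sum(r["Class"] for r in augmented))
```

Keeping synthetic records out of the test set matters. If they leaked in, the confusion matrix would partly measure how well the classifier recognizes the generator's output rather than real fraud.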
Conclusion
At Gretel.ai, we are excited about using synthetic data to augment training sets, creating ML and AI models that generalize better to unseen data while reducing algorithmic bias. We’d love to hear about your use cases. Feel free to reach out in the comments, on Twitter, or by email for a more in-depth discussion. Follow us to keep up with the latest developments in synthetic data.
Interested in training on your own data? gretel-synthetics is free and open source, and you can start experimenting right away in Colaboratory. If you like gretel-synthetics, please give us a ⭐ on GitHub!
If you find mistakes in this translation or other areas that need improvement, you are welcome to revise it and submit a PR through the Nuggets Translation Project, and you can earn the corresponding reward points. The permanent link at the beginning of this article is the Markdown link to this article on GitHub.
The Nuggets Translation Project is a community that translates high-quality English technical articles from around the Internet and shares them on Nuggets (Juejin). Its content covers Android, iOS, front-end, back-end, blockchain, product, design, artificial intelligence, and other fields. For more high-quality translations, please keep following the Nuggets Translation Project, its official Weibo account, and its Zhihu column.