- Runtime environment: python3
- Author: K
- Featured column: 100 Cases of Deep Learning
- Recommended column: 100 Cases of Natural Language Processing (NLP)
Hello, everyone! I’m K!
The previous article explained what BERT is. In this article, I put BERT to work on a hands-on text-classification task: using the THUCTC dataset, we classify news text into 10 categories (finance, real estate, stocks, education, science and technology, society, politics, sports, games, and entertainment), reaching a classification accuracy of 83.3%.
@toc
1. Import and organize the data
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, TFBertForSequenceClassification

data_path = "./5-data/data.txt"
model_path = "bert-base-chinese"
max_length = 32
batch_size = 128
learning_rate = 2e-5
num_classes = 10  # number of categories

# Prepare the data
df_raw = pd.read_csv(data_path, sep="\t", header=None, names=["text", "label"])
class_names = ["Finance", "Real estate", "Stocks", "Education", "Science and technology",
               "Society", "Politics", "Sports", "Games", "Entertainment"]

# Digitize the labels
df_label = pd.DataFrame({"label": class_names, "y": list(range(10))})
df_raw = pd.merge(df_raw, df_label, on="label", how="left")
df_raw.head(3)
|   | text | label | y |
|---|------|-------|---|
| 0 | China Women's University: only one undergraduate major is open to male students | Education | 3 |
| 1 | There's a lot of confusion about how much it costs to build a website | Science and technology | 4 |
| 2 | East 5th Ring, Haitang Commune: 230-290 m² two-bedroom near-completion homes at 98% of list price | Real estate | 1 |
Look at the proportions of each category in the data
# the source code can be read
plt.show()
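The plotting code is withheld in the original (available via the source-code link at the end); a minimal sketch that produces a category-proportion chart from the df_raw built above might look like this:

# A minimal sketch (not the original source): bar chart of category proportions
category_ratio = df_raw["label"].value_counts(normalize=True)
category_ratio.plot(kind="bar", rot=45, title="Category proportions")
plt.ylabel("proportion")
plt.tight_layout()
plt.show()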
View the data length distribution
# the source code can be read
plt.show()
# the source code can be read
plt.show()
The length of the sentence is 24.
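The length-analysis code is also withheld; a minimal sketch under the same assumptions (character length per headline, stored in a length column that also appears in the split preview below) could be:

# A minimal sketch (not the original source): sentence-length distribution
df_raw["length"] = df_raw["text"].apply(len)  # character length per headline
df_raw["length"].plot(kind="hist", bins=30, title="Sentence length distribution")
plt.xlabel("length")
plt.show()

# One plausible way the value 24 could be obtained: the length covering ~90% of samples
print("The length of the sentence is", int(df_raw["length"].quantile(0.9)))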
2. Split the dataset
train_data, x = train_test_split(df_raw,
                                 stratify=df_raw['label'],  # split in proportion to the categories in df_raw['label']
                                 test_size=0.1,
                                 random_state=42)
val_data, test_data = train_test_split(x,
                                       stratify=x['label'],
                                       test_size=0.5,
                                       random_state=43)
train_data.head(3)
|      | text | label | y | length |
|------|------|-------|---|--------|
| 603  | Online media will be allowed to compete for all Pulitzer Prizes | Politics | 6 | 19 |
| 2373 | '09 postgraduate entrance exam intensive review strategy: seek change within stability | Education | 3 | 20 |
| 1759 | The suspect resisted arrest with a Tibetan mastiff and threatened to harm himself | Society | 5 | 18 |
# tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)
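As a quick, hypothetical sanity check (not in the original post), you can inspect what encode_plus returns for a single sentence before wiring it into a dataset:

# Hypothetical check: fields produced by encode_plus for one sentence
sample = tokenizer.encode_plus("今天天气不错",
                               add_special_tokens=True,
                               max_length=max_length,
                               pad_to_max_length=True,
                               return_attention_mask=True,
                               truncation=True)
print(sample.keys())            # input_ids, token_type_ids, attention_mask
print(sample["input_ids"][:8])  # starts with [CLS] (id 101); [SEP] is id 102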
Adjust the data format
def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_masks,
    }, label

def encode_examples(ds):
    input_ids_list = []
    token_type_ids_list = []
    attention_mask_list = []
    label_list = []
    for index, row in ds.iterrows():
        bert_input = tokenizer.encode_plus(row["text"],
                                           add_special_tokens=True,     # add [CLS], [SEP]
                                           max_length=max_length,       # max length of the text fed to BERT
                                           pad_to_max_length=True,      # pad with [PAD] tokens up to max_length
                                           return_attention_mask=True,  # mask so attention ignores [PAD] tokens
                                           truncation=True)
        input_ids_list.append(bert_input['input_ids'])
        token_type_ids_list.append(bert_input['token_type_ids'])
        attention_mask_list.append(bert_input['attention_mask'])
        label_list.append(row["y"])
    return tf.data.Dataset.from_tensor_slices((input_ids_list, attention_mask_list,
                                               token_type_ids_list, label_list)).map(map_example_to_dict)
3. Build the model
Configure the data set
ds_train_encoded = encode_examples(train_data).shuffle(10000).batch(batch_size)
ds_val_encoded = encode_examples(val_data).batch(batch_size)
ds_test_encoded = encode_examples(test_data).batch(batch_size)
Initialize the model
model = TFBertForSequenceClassification.from_pretrained(model_path, num_labels=num_classes)
Set up the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,epsilon=1e-08, clipnorm=1)
# if the loss function is unclear, see: https://mtyjkh.blog.csdn.net/article/details/122309754
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer,
loss=loss,
metrics=[metric])
All model checkpoint layers were used when initializing TFBertForSequenceClassification.
Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
This warning is expected: the pretrained BERT layers are loaded from the checkpoint, while the classification head ('classifier') is newly initialized and has to be learned during fine-tuning.
4. Train the model
# fit model
bert_history = model.fit(ds_train_encoded, epochs=10, validation_data=ds_val_encoded)
Epoch 1/10
22/22 [==============================] - 20s 277ms/step - loss: 1.9059 - accuracy: 0.4585 - val_loss: 1.2277 - val_accuracy: 0.7933
Epoch 2/10
22/22 [==============================] - 4s 204ms/step - loss: 0.9633 - accuracy: 0.8230 - val_loss: 0.6662 - val_accuracy: 0.8467
Epoch 3/10
22/22 [==============================] - 5s 204ms/step - loss: 0.5279 - accuracy: 0.8900 - val_loss: 0.5360 - val_accuracy: 0.8600
Epoch 4/10
22/22 [==============================] - 5s 205ms/step - loss: 0.3482 - accuracy: 0.9200 - val_loss: 0.4698 - val_accuracy: 0.8667
Epoch 5/10
22/22 [==============================] - 5s 204ms/step - loss: 0.2514 - accuracy: 0.9448 - val_loss: 0.4263 - val_accuracy: 0.8867
Epoch 6/10
22/22 [==============================] - 5s 205ms/step - loss: 0.1654 - accuracy: 0.9689 - val_loss: 0.4706 - val_accuracy: 0.8800
Epoch 7/10
22/22 [==============================] - 5s 205ms/step - loss: 0.1139 - accuracy: 0.9841 - val_loss: 0.4517 - val_accuracy: 0.8867
Epoch 8/10
22/22 [==============================] - 4s 204ms/step - loss: 0.0841 - accuracy: 0.9863 - val_loss: 0.4967 - val_accuracy: 0.8933
Epoch 9/10
22/22 [==============================] - 5s 205ms/step - loss: 0.0684 - accuracy: 0.9878 - val_loss: 0.4540 - val_accuracy: 0.8933
Epoch 10/10
22/22 [==============================] - 5s 204ms/step - loss: 0.0493 - accuracy: 0.9948 - val_loss: 0.5542 - val_accuracy: 0.8867
5. Evaluate the model
# evaluate test_set
test_loss, test_accuracy = model.evaluate(ds_test_encoded)
print("test_set loss:", test_loss)
print("test_set accuracy:", test_accuracy)
2/2 [==============================] - 0s 28ms/step - loss: 0.6915 - accuracy: 0.8333
test_set loss: 0.691510796546936
test_set accuracy: 0.8333333134651184
1. Loss and accuracy curves
# the source code can be read
plt.show()
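The plotting code is again left to the source download; a minimal sketch that draws both curves from bert_history (assuming the standard Keras history keys loss, accuracy, val_loss, val_accuracy) might be:

# A minimal sketch (not the original source): loss and accuracy curves
history = bert_history.history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(history["loss"], label="train loss")
ax1.plot(history["val_loss"], label="val loss")
ax1.set_title("Loss"); ax1.legend()
ax2.plot(history["accuracy"], label="train acc")
ax2.plot(history["val_accuracy"], label="val acc")
ax2.set_title("Accuracy"); ax2.legend()
plt.show()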
2. Other evaluation metrics
# the source code can be read
test_accuracy_report(model)
               precision  recall  f1-score  support
Sports              1.00    0.93      0.97       15
Entertainment       0.89    1.00      0.94       16
Real estate         0.71    0.86      0.77       14
Education           1.00    0.93      0.96       14
Politics            0.87    0.76      0.81       17
Games               1.00    0.73      0.85       15
Society             0.75    0.94      0.83       16
Technology          0.92    0.80      0.86       15
Stocks              0.62    0.57      0.59       14
Finance             0.69    0.79      0.73       14

accuracy                              0.83      150
macro avg           0.84    0.83      0.83      150
weighted avg        0.85    0.83      0.83      150

Loss: 0.691510796546936
Accuracy: 0.8333333134651184
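test_accuracy_report is defined in the withheld source; a plausible re-implementation with scikit-learn (the helper name comes from the call above, but its body is an assumption, and depending on your transformers version predict may return a tuple rather than an output object) is:

# Hypothetical re-implementation of test_accuracy_report using scikit-learn
import numpy as np
from sklearn.metrics import classification_report

def test_accuracy_report(model):
    logits = model.predict(ds_test_encoded).logits                        # raw class scores
    test_pre = np.argmax(logits, axis=1)                                  # predicted class ids
    test_label = np.concatenate([y.numpy() for _, y in ds_test_encoded])  # true class ids
    print(classification_report(test_label, test_pre, target_names=class_names))
    loss, acc = model.evaluate(ds_test_encoded, verbose=0)
    print("Loss:", loss, "Accuracy:", acc)
    return test_label, test_pre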
3. Confusion matrix
# the source code can be read
plot_cm(test_label, test_pre)
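plot_cm also lives in the withheld source; a minimal sketch using scikit-learn and seaborn (the helper name is taken from the call above, its body is an assumption) could be:

# Hypothetical sketch of plot_cm: confusion-matrix heatmap
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_cm(labels, predictions):
    cm = confusion_matrix(labels, predictions)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()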
Source code address: mp.weixin.qq.com/s/6K0ZInHfq…
Fan benefits

- Machine Learning (Zhou Zhihua): Baidu Cloud download link, extraction code: HNzu
- Statistical Learning Methods (Li Hang): Baidu Cloud download link, extraction code: LC5E
- Neural Networks and Deep Learning (Qiu Xipeng): Baidu Cloud download link, extraction code: M3LS
- Deep Reinforcement Learning (Hung-yi Lee): Baidu Cloud download link, extraction code: 9b0A
- Machine Learning in Action (Chinese-English double-page edition): Baidu Cloud download link, extraction code: JY7R
- Stanford University: Fundamentals of Deep Learning
- TensorFlow: Practical Google Deep Learning Framework: Baidu Cloud download link, extraction code: F6wF