Get your cat fix with code! This post is part of the [Cat Essay Campaign].
Whose kitten are you?
Xiao Ming: Er Gouzi's cat came over and stole food again, and my little spotted cat is starving. What should I do?
What can I do? Pet identification is a job for university research teams. How is a kindergartener like me supposed to do it?
Xiao Ming scratched his head in frustration and let out a pained groan...
Then Xiao Ming came across a news report: "monkey face recognition technology" is here! New.qq.com/omn/2021022…
Don't worry, kindergartener Xiao Ming. I'll teach you how to recognize cats with PaddlePaddle, so that cats you don't know are not allowed in.
1. Data collection
All cat footage comes from publicly available videos. The cat face images are obtained by taking screenshots from these videos instead of photographing the cats separately.
```
!unzip -q data/data71411/cat.zip
```
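If no shell is available, the `!unzip -q` step can also be done from Python with the standard-library `zipfile` module. This is a minimal sketch; the helper name `extract_archive` is hypothetical, and extracting into the current directory is an assumption that mirrors what `unzip` does by default:

```python
import zipfile


def extract_archive(zip_path, dest_dir='.'):
    """Extract every member of zip_path into dest_dir, like `unzip -q`.

    Returns the list of member names that were extracted.
    """
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)  # creates dest_dir and subfolders as needed
        return zf.namelist()


# Example, using the archive path from the notebook:
# extract_archive('data/data71411/cat.zip')
```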
1.1 Use OpenCV from Python to capture one frame from each video every second, number the frames, and save them.
```python
import os
import cv2

# Videos are named 1.mp4 ... 4.mp4; the frames of video i are saved to dataset/i/
for i in range(1, 5):
    print(i)
    mp4_file = str(i) + '.mp4'
    # Create an image directory for this cat
    dir_path = os.path.join('dataset', str(i))
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    # Save one image per second of video
    vidcap = cv2.VideoCapture(mp4_file)
    success, image = vidcap.read()
    fps = int(vidcap.get(cv2.CAP_PROP_FPS))
    count = 0
    while success:
        if count % fps == 0:
            cv2.imwrite("{}/{}.jpg".format(dir_path, int(count / fps)), image)
            print('Process %dth seconds: ' % int(count / fps), success)
        success, image = vidcap.read()
        count += 1
    vidcap.release()
```
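The loop above keeps frame `count` whenever `count % fps == 0`, so it saves frames 0, fps, 2·fps, … and names each file after the second it falls on. The same selection logic can be pulled out into a small pure function so it can be checked without a video file. `seconds_to_save` is a hypothetical helper, not part of the original notebook:

```python
def seconds_to_save(total_frames, fps):
    """Return the (frame_index, second) pairs the capture loop would save.

    Mirrors `if count % fps == 0` with `int(count / fps)` as the file name.
    """
    if fps <= 0:
        # count % 0 would raise ZeroDivisionError in the original loop
        raise ValueError("fps must be positive")
    return [(count, count // fps) for count in range(total_frames) if count % fps == 0]


# A 25 fps clip with 60 frames yields frames 0, 25, 50 -> 0.jpg, 1.jpg, 2.jpg
print(seconds_to_save(60, 25))  # [(0, 0), (25, 1), (50, 2)]
```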
1.2 Processing the generated images
Delete unusable images, such as frames from the end credits, by hand…
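```python
import matplotlib.pyplot as plt
%matplotlib inline
import cv2 as cv
import numpy as np

# Display an image inline in a Jupyter notebook
def visualize_images():
    img = cv.imread('dataset/1/1.jpg')
    # OpenCV loads images as BGR; convert to RGB so matplotlib shows the true colors
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.show()

visualize_images()
```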
1.3 Viewing the dataset
Four different kittens
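To check that each of the four kitten folders still contains images after the bad frames have been deleted, a per-folder count is handy. This is a minimal sketch assuming the `dataset/<class>/<n>.jpg` layout produced in step 1.1; `count_images` is a hypothetical helper:

```python
import os


def count_images(root, exts=('.jpg', '.jpeg', '.png')):
    """Return {class_folder: number_of_images} for every subfolder of root."""
    counts = {}
    for name in sorted(os.listdir(root)):
        folder = os.path.join(root, name)
        if os.path.isdir(folder):  # skip plain files such as test.txt
            counts[name] = sum(1 for f in os.listdir(folder) if f.lower().endswith(exts))
    return counts


# Example: count_images('dataset')
```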
1.4 Generating the data lists
For a custom dataset, we first generate image lists: the images are split into a training set and a test set, and each image is given a label. The program below can be run on its own. You only pass in the root folder path, and it iterates over each category subfolder to produce lists in a fixed format. Here, the root directory of the cat categories is uploaded to ./dataset. Finally, three files, readme.json, trainer.txt, and test.txt, are generated in the specified directory.
```python
import os
import json

# Set the path of the dataset root directory
data_root_path = '/home/aistudio/dataset'
# Information about all categories
class_detail = []
# Category folders, e.g. ['1', '2', '3', '4'];
# skip files such as test.txt left over from a previous run, and sort for stable labels
class_dirs = sorted(d for d in os.listdir(data_root_path)
                    if os.path.isdir(os.path.join(data_root_path, d)))
# Category label
class_label = 0
# Get the name of the overall category (last non-empty path component)
father_paths = data_root_path.split('/')  # ['', 'home', 'aistudio', 'dataset']
while True:
    if father_paths[len(father_paths) - 1] == '':
        del father_paths[len(father_paths) - 1]
    else:
        break
father_path = father_paths[len(father_paths) - 1]
# Put the generated data lists in the overall category folder
data_list_path = '/home/aistudio/%s/' % father_path
# Create the folder if it doesn't already exist
isexist = os.path.exists(data_list_path)
if not isexist:
    os.makedirs(data_list_path)
# Empty the original list files
with open(data_list_path + "test.txt", 'w') as f:
    pass
with open(data_list_path + "trainer.txt", 'w') as f:
    pass
# Total number of images
all_class_images = 0
# Read each category
for class_dir in class_dirs:
    # Information for this category
    class_detail_list = {}
    test_sum = 0
    trainer_sum = 0
    # Number of images in this category
    class_sum = 0
    # Get the category path
    path = data_root_path + "/" + class_dir
    # Get all images
    img_paths = os.listdir(path)
    for img_path in img_paths:  # Walk through each image in the folder
        name_path = path + '/' + img_path  # Path of each image
        if class_sum % 10 == 0:  # Take one out of every 10 images as test data
            test_sum += 1  # Number of test images
            with open(data_list_path + "test.txt", 'a') as f:
                f.write(name_path + "\t%d" % class_label + "\n")  # class_label: 0, 1, 2, 3
        else:
            trainer_sum += 1  # Number of training images
            with open(data_list_path + "trainer.txt", 'a') as f:
                f.write(name_path + "\t%d" % class_label + "\n")  # class_label: 0, 1, 2, 3
        class_sum += 1  # Number of images in this category
        all_class_images += 1  # Number of images in all categories
    # class_detail entry for the JSON file
    class_detail_list['class_name'] = class_dir  # Category name, e.g. '1'
    class_detail_list['class_label'] = class_label  # Category label: 0, 1, 2, 3
    class_detail_list['class_test_images'] = test_sum  # Number of test images in this category
    class_detail_list['class_trainer_images'] = trainer_sum  # Number of training images in this category
    class_detail.append(class_detail_list)
    class_label += 1
# Get the number of categories
all_class_sum = len(class_dirs)
# Contents of the JSON file
readjson = {}
readjson['all_class_name'] = father_path  # Parent directory name
readjson['all_class_sum'] = all_class_sum  # Number of categories
readjson['all_class_images'] = all_class_images
readjson['class_detail'] = class_detail
jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
with open(data_list_path + "readme.json", 'w') as f:
    f.write(jsons)
print('Generating data list complete!')
```
```
Generating data list complete!
```
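The split rule in the script is: within each category, image number 0, 10, 20, … goes to the test list and every other image goes to the training list, with each line written as `<path>\t<label>`. So a category with n images contributes ceil(n / 10) test images. Stripped of the file I/O, the rule can be sketched as a pure function; `split_category` is a hypothetical name used only for illustration:

```python
def split_category(image_paths, label, every=10):
    """Split one category's images the way the list script does.

    Images at positions 0, every, 2*every, ... go to the test list,
    all others to the training list. Returns (train_lines, test_lines)
    in the '<path>\\t<label>' list format.
    """
    train, test = [], []
    for i, path in enumerate(image_paths):
        line = "%s\t%d" % (path, label)
        (test if i % every == 0 else train).append(line)
    return train, test


train, test = split_category(['cat/%d.jpg' % i for i in range(12)], label=0)
print(len(train), len(test))  # 10 2
```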
1.5 Building the Dataset
```python
import paddle
import paddle.vision.transforms as T
import numpy as np
from PIL import Image

class MiaoMiaoDataset(paddle.io.Dataset):
    """MiaoMiao (cat) dataset class"""

    def __init__(self, mode='trainer'):
        """Initialization function; mode matches a list file name: 'trainer' or 'test'"""
        super().__init__()
        self.data = []
        with open('dataset/{}.txt'.format(mode)) as f:
            for line in f.readlines():
                info = line.strip().split('\t')
                if len(info) == 2:  # each line is "<image path>\t<label>"
                    self.data.append([info[0].strip(), info[1].strip()])
        if mode == 'trainer':
            self.transforms = T.Compose([
                T.Resize((224, 224)),  # Resize the image
                T.RandomHorizontalFlip(0.5),  # Random horizontal flip
                T.ToTensor(),  # To tensor, scaled to [0, 1], HWC => CHW
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Image normalization
            ])
        else:
            self.transforms = T.Compose([
                T.Resize((224, 224)),  # Resize the image
                # T.RandomCrop(IMAGE_SIZE),  # Random crop
                T.ToTensor(),  # To tensor, scaled to [0, 1], HWC => CHW
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Image normalization
            ])

    def get_origin_data(self):
        return self.data

    def __getitem__(self, index):
        """Get a single sample by index."""
        image_file, label = self.data[index]
        image = Image.open(image_file)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = self.transforms(image)
        return image, np.array(int(label), dtype='int64')

    def __len__(self):
        """Get the total number of samples"""
        return len(self.data)
```
```python
train_dataset = MiaoMiaoDataset(mode='trainer')
test_dataset = MiaoMiaoDataset(mode='test')
print('train_data len: {}, test_data len:{}'.format(len(train_dataset), len(test_dataset)))
```
```
train_data len: 45, test_data len:7
```
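`T.ToTensor()` scales pixel values from [0, 255] to [0, 1] (and reorders HWC to CHW), and `T.Normalize` then computes `(x - mean) / std` per channel, using the ImageNet statistics that the pretrained ResNet50 weights were trained with. For a single RGB pixel the arithmetic looks like this. It is a plain-Python illustration of the formula, not Paddle's implementation, and `normalize_pixel` is a hypothetical helper:

```python
MEAN = [0.485, 0.456, 0.406]  # ImageNet per-channel mean (R, G, B)
STD = [0.229, 0.224, 0.225]   # ImageNet per-channel standard deviation


def normalize_pixel(rgb):
    """Scale one 0-255 RGB pixel to [0, 1], then normalize each channel."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# A pure white pixel ends up roughly 2.2 to 2.6 standard deviations above the mean
print([round(x, 3) for x in normalize_pixel([255, 255, 255])])
```

A pixel equal to the mean color maps to roughly zero in every channel, which is what makes the inputs resemble the data the pretrained weights saw.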
2. Model definition and training
The data has now been split into a training set and a test set, and the number of categories is known.
Next we define the model, using the ResNet50 network with pretrained weights.
```python
import paddle
from paddle import Model

# Define the network: ResNet50 with a 4-class output layer, initialized from pretrained weights
network = paddle.vision.models.resnet50(num_classes=4, pretrained=True)
model = paddle.Model(network)
model.summary((-1, 3, 224, 224))
```
100%|██████████| 151272/151272 [00:02<00:00, 72148.01it/s]
-------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===============================================================================
Conv2D-1 [[1, 3, 224, 224]] [1, 64, 112, 112] 9,408
BatchNorm2D-1 [[1, 64, 112, 112]] [1, 64, 112, 112] 256
ReLU-1 [[1, 64, 112, 112]] [1, 64, 112, 112] 0
MaxPool2D-1 [[1, 64, 112, 112]] [1, 64, 56, 56] 0
Conv2D-3 [[1, 64, 56, 56]] [1, 64, 56, 56] 4,096
BatchNorm2D-3 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-2 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-4 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-4 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-5 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-5 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
Conv2D-2 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-2 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-1 [[1, 64, 56, 56]] [1, 256, 56, 56] 0
Conv2D-6 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,384
BatchNorm2D-6 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-3 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-7 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-7 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-8 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-8 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-2 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-9 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,384
BatchNorm2D-9 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-4 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-10 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-10 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-11 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-11 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-3 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-13 [[1, 256, 56, 56]] [1, 128, 56, 56] 32,768
BatchNorm2D-13 [[1, 128, 56, 56]] [1, 128, 56, 56] 512
ReLU-5 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-14 [[1, 128, 56, 56]] [1, 128, 28, 28] 147,456
BatchNorm2D-14 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-15 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-15 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
Conv2D-12 [[1, 256, 56, 56]] [1, 512, 28, 28] 131,072
BatchNorm2D-12 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-4 [[1, 256, 56, 56]] [1, 512, 28, 28] 0
Conv2D-16 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-16 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-6 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-17 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-17 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-18 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-18 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-5 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-19 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-19 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-7 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-20 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-20 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-21 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-21 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-6 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-22 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-22 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-8 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-23 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-23 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-24 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-24 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-7 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-26 [[1, 512, 28, 28]] [1, 256, 28, 28] 131,072
BatchNorm2D-26 [[1, 256, 28, 28]] [1, 256, 28, 28] 1,024
ReLU-9 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-27 [[1, 256, 28, 28]] [1, 256, 14, 14] 589,824
BatchNorm2D-27 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-28 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-28 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-25 [[1, 512, 28, 28]] [1, 1024, 14, 14] 524,288
BatchNorm2D-25 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-8 [[1, 512, 28, 28]] [1, 1024, 14, 14] 0
Conv2D-29 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-29 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-10 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-31 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-31 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-9 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-32 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-32 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-11 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-33 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-33 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-34 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-34 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-10 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-35 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-35 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-12 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-36 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-36 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-37 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-37 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-11 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-38 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-38 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-13 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-39 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-39 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-40 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-40 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-12 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-41 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-41 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-14 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-42 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-42 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-43 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-43 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-13 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-45 [[1, 1024, 14, 14]] [1, 512, 14, 14] 524,288
BatchNorm2D-45 [[1, 512, 14, 14]] [1, 512, 14, 14] 2,048
ReLU-15 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-46 [[1, 512, 14, 14]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-46 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-47 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-47 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
Conv2D-44 [[1, 1024, 14, 14]] [1, 2048, 7, 7] 2,097,152
BatchNorm2D-44 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-14 [[1, 1024, 14, 14]] [1, 2048, 7, 7] 0
Conv2D-48 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,048,576
BatchNorm2D-48 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
ReLU-16 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-49 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-49 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-50 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-50 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-15 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-51 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,048,576
BatchNorm2D-51 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
ReLU-17 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-52 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-52 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-53 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-53 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-16 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
AdaptiveAvgPool2D-1 [[1, 2048, 7, 7]] [1, 2048, 1, 1] 0
Linear-1 [[1, 2048]] [1, 4] 8,196
===============================================================================
Total params: 23,569,348
Trainable params: 23,463,108
Non-trainable params: 106,240
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 261.48
Params size (MB): 89.91
Estimated Total Size (MB): 351.96
-------------------------------------------------------------------------------
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1263: UserWarning: Skip loading for fc.weight. fc.weight receives a shape [2048, 1000], but the expected shape is [2048, 4].
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1263: UserWarning: Skip loading for fc.bias. fc.bias receives a shape [1000], but the expected shape is [4].
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
{'total_params': 23569348, 'trainable_params': 23463108}
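The parameter counts in the summary can be checked by hand. A convolution without bias has `k_h * k_w * in_channels * out_channels` weights (the summary's numbers imply that ResNet50's convolutions carry no bias), a BatchNorm2D layer is listed with 4 values per channel, and the new classifier is a `Linear` layer with `2048 * 4` weights plus 4 biases. The warnings above show that the pretrained 1000-class `fc` weights were skipped because of the shape mismatch, so only this last layer starts from scratch. The helper names below are hypothetical:

```python
def conv2d_params(k_h, k_w, c_in, c_out):
    """Weights of a bias-free Conv2D layer."""
    return k_h * k_w * c_in * c_out


def batchnorm_params(channels):
    """BatchNorm2D entries as listed in the summary: 4 values per channel."""
    return 4 * channels


def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


print(conv2d_params(7, 7, 3, 64))      # 9408, matches Conv2D-1 (stem)
print(conv2d_params(3, 3, 512, 512))   # 2359296, matches Conv2D-46
print(batchnorm_params(64))            # 256, matches BatchNorm2D-1
print(linear_params(2048, 4))          # 8196, matches Linear-1 (new 4-class head)
print(linear_params(2048, 1000))       # 2049000, the skipped ImageNet head
```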
```python
# Model training configuration
model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.000005, parameters=model.parameters()),  # Optimizer
              loss=paddle.nn.CrossEntropyLoss(),  # Loss function
              metrics=paddle.metric.Accuracy())  # Evaluation metric
# VisualDL callback for visualizing training
visualdl = paddle.callbacks.VisualDL(log_dir='visualdl_log')
```
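`paddle.nn.CrossEntropyLoss`, with its default `use_softmax=True`, applies softmax to the 4 logits produced by the network and takes the negative log of the probability assigned to the true class (averaged over the batch). A plain-Python sketch of that computation for a single sample, for intuition only; `softmax` and `cross_entropy` here are hypothetical helpers, not Paddle functions:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Negative log-probability of the true class after softmax."""
    return -math.log(softmax(logits)[label])


# Equal logits for 4 classes: every class gets probability 0.25, loss = log(4)
print(cross_entropy([0.0, 0.0, 0.0, 0.0], 2))
```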
```python
# Launch end-to-end training of the model
model.fit(train_dataset,             # Training dataset
          # test_dataset,            # Evaluation dataset
          epochs=20,                 # Total number of training epochs
          batch_size=256,            # Batch size
          shuffle=True,              # Whether to shuffle the samples
          verbose=1,                 # Log display format
          save_dir='./chk_points/',  # Path for saving intermediate checkpoints
          callbacks=[visualdl])      # Callbacks to use
```
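With only 45 training images, `batch_size=256` means every epoch is a single batch containing the whole training set, so 20 epochs amount to 20 optimizer steps. The arithmetic, assuming the loader keeps the last partial batch (`drop_last=False`, which is the default for `Model.fit`); `steps_per_epoch` is a hypothetical helper:

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a data loader yields in one epoch."""
    if drop_last:
        return num_samples // batch_size  # incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(45, 256))       # 1 step per epoch
print(20 * steps_per_epoch(45, 256))  # 20 optimizer steps in total
print(steps_per_epoch(45, 16))        # 3 steps with a smaller batch
```

With `drop_last=True` the same dataset and batch size would give 0 steps, i.e. no training at all, which is why keeping the partial batch matters on a dataset this small.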
```python
# Save the trained model
model.save('model_save')
```
3. Model evaluation and testing
```python
# Evaluate the model on the test set
model.evaluate(test_dataset, verbose=1)
```
```
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 7/7 [==============================] - loss: 0.0000e+00 - acc: 0.7143 - 30ms/step
Eval samples: 7
{'loss': [0.0], 'acc': 0.7142857142857143}
```
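The reported `acc` of 0.7142857142857143 corresponds to 5 of the 7 test images being classified correctly (5 / 7). Accuracy here is simply the fraction of samples whose predicted label equals the true label. A minimal sketch; the label lists in the example are hypothetical and are not the notebook's actual predictions:

```python
def accuracy(predicted, actual):
    """Fraction of positions where the predicted label equals the true label."""
    if len(predicted) != len(actual) or not actual:
        raise ValueError("need two non-empty lists of equal length")
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


# Hypothetical labels: 5 of 7 predictions match
print(accuracy([0, 0, 1, 2, 2, 3, 1], [0, 0, 1, 1, 2, 3, 3]))  # 0.7142857142857143
```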
Prediction
```python
# Predict on the test_dataset data
print('Test data set sample size: {}'.format(len(test_dataset)))
```
```
Test data set sample size: 7
```
```python
# Run prediction
result = model.predict(test_dataset)
```
```
Predict begin...
step 7/7 [==============================] - 32ms/step
Predict samples: 7
```
```python
# Print the prediction for every test sample (7 in total)
for idx in range(len(test_dataset)):
    predict_label = str(np.argmax(result[0][idx]))
    real_label = str(test_dataset[idx][1])
    print('Sample ID: {}, true label: {}, predicted value: {}'.format(idx, real_label, predict_label))
```
```
Sample ID: 0, true label: 0, predicted value: 0
Sample ID: 1, true label: 0, predicted value: 0
Sample ID: 2 …
Sample ID: 6, true label: 4, predicted value: 1
```
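`np.argmax(result[0][idx])` picks the index of the largest of the 4 class scores the model outputs for a sample, and that index is the predicted label (0 to 3, in the order the category folders were processed when the lists were generated). Without NumPy the same step can be sketched as follows; `argmax` here is a hypothetical helper that takes a flat list of scores:

```python
def argmax(scores):
    """Index of the largest score; ties resolve to the first maximum, like np.argmax."""
    return max(range(len(scores)), key=scores.__getitem__)


print(argmax([0.1, 2.3, -0.5, 1.9]))  # 1
```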
```python
# Define the drawing method
from PIL import Image
import matplotlib.font_manager as font_manager
import matplotlib.pyplot as plt
%matplotlib inline

# A CJK-capable font file, needed if the titles contain Chinese characters
fontpath = 'MINGHEI_R.TTF'
font = font_manager.FontProperties(fname=fontpath, size=10)

def show_img(img, predict):
    plt.figure()
    plt.title(predict, fontproperties=font)
    plt.imshow(img, cmap=plt.cm.binary)
    plt.show()

# Display the samples
origin_data = test_dataset.get_origin_data()
for i in range(len(origin_data)):
    img_path = origin_data[i][0]
    real_label = str(origin_data[i][1])
    predict_label = str(np.argmax(result[0][i]))
    img = Image.open(img_path)
    title = 'Sample ID: {}, true label: {}, predicted value: {}'.format(i, real_label, predict_label)
    show_img(img, title)
```