Generate training validation test TXT file
- voc_annotation.py
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #
# annotation_mode specifies what is evaluated at run time for this file
# annotation_mode 0 represents the entire label process, including access to VOCdevkit VOC2007 / ImageSets TXT inside and training with 2007 _train. TXT, 2007 _val. TXT
# annotation_mode 1 represents VOCdevkit VOC2007 / ImageSets TXT inside
# annotation_mode = 2 to obtain 2007_train. TXT and 2007_val.txt for training
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #
# this is the type of file we generate. If we specify different values, we will generate different files
annotation_mode = 0
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
TXT, 2007_val.txt, 2007_val.txt, 2007_val.txt
Use the same classes_path used for training and prediction
# if 2007_train.txt is generated without destination information
# Then classes are not set correctly
# only valid if annotation_mode is 0 and 2
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# specify the location of the class name file
classes_path = 'model_data/voc_classes.txt'
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #
# trainval_percent Specifies the ratio of (training set + verification set) to test set. By default (training set + verification set): test set = 9:1
# train_percent Specifies the ratio of training sets to verification sets in (training set + verification set). By default, training set: verification set = 9:1
Annotation_mode is valid only if annotation_mode is 0 and 1
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #
# trainval is
trainval_percent = 0.9
train_percent = 0.9
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- #
# point to the folder where the VOC dataset is located
The default is to point to VOC datasets in the root directory
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- #
Specify the root folder for the data
VOCdevkit_path = 'VOCdevkit'
# use it in the loop below
VOCdevkit_sets = [('2007'.'train'), ('2007'.'val')]
Read the file information from class_path above to get the list of class_name
classes, _ = get_classes(classes_path)
def convert_annotation(year, image_id, list_file) :
Convert data in XML to x1,y1,x2,y2,class_num multi-group format
in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
# XML parses data
tree=ET.parse(in_file)
root = tree.getroot()
for obj in root.iter('object'):
difficult = 0
if obj.find('difficult')! =None:
difficult = obj.find('difficult').text
cls = obj.find('name').text
# Just note that here we remove the ones that are not in our class_list and those marked difficult
if cls not in classes or int(difficult)==1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
list_file.write("" + ",".join([str(a) for a in b]) + ', ' + str(cls_id))
if __name__ == "__main__":
random.seed(0)
# select 0 or 1 in the above mode
if annotation_mode == 0 or annotation_mode == 1:
print("Generate txt in ImageSets.")
Get the file path of the XML file
xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
# Select the path to the file we want to save
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
List all files in the XML folder to form a list
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
if xml.endswith(".xml"):
total_xml.append(xml)
Divide the data according to the ratio we defined above
num = len(total_xml)
list = range(num)
# trainval accounts for 90% of the total data, train accounts for 90% of trainval
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
# We take 90% of the randomly selected data set as trainval's data set, where subscripts are placed in a list
trainval= random.sample(list,tv)
# We randomly selected 90% of trainval data set as train data set
train = random.sample(trainval,tr)
print("train and val size",tv)
print("train size",tr)
Open the four files and place the proportioned data into the files
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
# We loop through the subscript list
for i in list:
Write the.xml file to the file
name=total_xml[i][:-4] +'\n'
# in trainval file
if i in trainval:
Write the trainval file
ftrainval.write(name)
Write the train file in train
if i in train:
ftrain.write(name)
Write to val if not present
else:
fval.write(name)
Write test to 10% of files not in trainval
else:
ftest.write(name)
# close file
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print("Generate txt in ImageSets done.")
# when mode 0 or 2
if annotation_mode == 0 or annotation_mode == 2:
print("Generate 2007_train.txt and 2007_val.txt for train.")
# create two files in a different directory (year: 2007 or 2012)
for year, image_set in VOCdevkit_sets:
Get the id list of the image from the train. TXT and val. TXT we built above
image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
# Open file
list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
# loop over the list of image_id
for image_id in image_ids:
Write file path information
list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
Convert data in XML to x1,y1,x2,y2,class_num multi-group format
convert_annotation(year, image_id, list_file)
Record a single line of data plus a newline character
list_file.write('\n')
# close file
list_file.close()
print("Generate 2007_train.txt and 2007_val.txt for train done.")
Copy the code
Split file will not be shown, only show 2007_train.txt effect
View the network hierarchy and the number of parameters
- summary.py
import torch
from torchsummary import summary
from nets.yolo import YoloBody
if __name__ == "__main__":
Device is required to specify whether the network is running on the GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
m = YoloBody([[6.7.8], [3.4.5], [0.1.2]], 80).to(device)
summary(m, input_size=(3.416.416))
Copy the code
The structure of the network is as follows
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1.32.416.416] 864
BatchNorm2d-2 [-1.32.416.416] 64
LeakyReLU-3 [-1.32.416.416] 0
Conv2d-4 [-1.64.208.208] 18.432
BatchNorm2d-5 [-1.64.208.208] 128
LeakyReLU-6 [-1.64.208.208] 0
Conv2d-7 [-1.32.208.208] 2,048
BatchNorm2d-8 [-1.32.208.208] 64
LeakyReLU-9 [-1.32.208.208] 0
Conv2d-10 [-1.64.208.208] 18.432
BatchNorm2d-11 [-1.64.208.208] 128
LeakyReLU-12 [-1.64.208.208] 0
BasicBlock-13 [-1.64.208.208] 0
Conv2d-14 [-1.128.104.104] 73.728
BatchNorm2d-15 [-1.128.104.104] 256
LeakyReLU-16 [-1.128.104.104] 0
Conv2d-17 [-1.64.104.104] 8.192
BatchNorm2d-18 [-1.64.104.104] 128
LeakyReLU-19 [-1.64.104.104] 0
Conv2d-20 [-1.128.104.104] 73.728
BatchNorm2d-21 [-1.128.104.104] 256
LeakyReLU-22 [-1.128.104.104] 0
BasicBlock-23 [-1.128.104.104] 0
Conv2d-24 [-1.64.104.104] 8.192
BatchNorm2d-25 [-1.64.104.104] 128
LeakyReLU-26 [-1.64.104.104] 0
Conv2d-27 [-1.128.104.104] 73.728
BatchNorm2d-28 [-1.128.104.104] 256
LeakyReLU-29 [-1.128.104.104] 0
BasicBlock-30 [-1.128.104.104] 0
Conv2d-31 [-1.256.52.52] 294.912
BatchNorm2d-32 [-1.256.52.52] 512
LeakyReLU-33 [-1.256.52.52] 0
Conv2d-34 [-1.128.52.52] 32.768
BatchNorm2d-35 [-1.128.52.52] 256
LeakyReLU-36 [-1.128.52.52] 0
Conv2d-37 [-1.256.52.52] 294.912
BatchNorm2d-38 [-1.256.52.52] 512
LeakyReLU-39 [-1.256.52.52] 0
BasicBlock-40 [-1.256.52.52] 0
Conv2d-41 [-1.128.52.52] 32.768
BatchNorm2d-42 [-1.128.52.52] 256
LeakyReLU-43 [-1.128.52.52] 0
Conv2d-44 [-1.256.52.52] 294.912
BatchNorm2d-45 [-1.256.52.52] 512
LeakyReLU-46 [-1.256.52.52] 0
BasicBlock-47 [-1.256.52.52] 0
Conv2d-48 [-1.128.52.52] 32.768
BatchNorm2d-49 [-1.128.52.52] 256
LeakyReLU-50 [-1.128.52.52] 0
Conv2d-51 [-1.256.52.52] 294.912
BatchNorm2d-52 [-1.256.52.52] 512
LeakyReLU-53 [-1.256.52.52] 0
BasicBlock-54 [-1.256.52.52] 0
Conv2d-55 [-1.128.52.52] 32.768
BatchNorm2d-56 [-1.128.52.52] 256
LeakyReLU-57 [-1.128.52.52] 0
Conv2d-58 [-1.256.52.52] 294.912
BatchNorm2d-59 [-1.256.52.52] 512
LeakyReLU-60 [-1.256.52.52] 0
BasicBlock-61 [-1.256.52.52] 0
Conv2d-62 [-1.128.52.52] 32.768
BatchNorm2d-63 [-1.128.52.52] 256
LeakyReLU-64 [-1.128.52.52] 0
Conv2d-65 [-1.256.52.52] 294.912
BatchNorm2d-66 [-1.256.52.52] 512
LeakyReLU-67 [-1.256.52.52] 0
BasicBlock-68 [-1.256.52.52] 0
Conv2d-69 [-1.128.52.52] 32.768
BatchNorm2d-70 [-1.128.52.52] 256
LeakyReLU-71 [-1.128.52.52] 0
Conv2d-72 [-1.256.52.52] 294.912
BatchNorm2d-73 [-1.256.52.52] 512
LeakyReLU-74 [-1.256.52.52] 0
BasicBlock-75 [-1.256.52.52] 0
Conv2d-76 [-1.128.52.52] 32.768
BatchNorm2d-77 [-1.128.52.52] 256
LeakyReLU-78 [-1.128.52.52] 0
Conv2d-79 [-1.256.52.52] 294.912
BatchNorm2d-80 [-1.256.52.52] 512
LeakyReLU-81 [-1.256.52.52] 0
BasicBlock-82 [-1.256.52.52] 0
Conv2d-83 [-1.128.52.52] 32.768
BatchNorm2d-84 [-1.128.52.52] 256
LeakyReLU-85 [-1.128.52.52] 0
Conv2d-86 [-1.256.52.52] 294.912
BatchNorm2d-87 [-1.256.52.52] 512
LeakyReLU-88 [-1.256.52.52] 0
BasicBlock-89 [-1.256.52.52] 0
Conv2d-90 [-1.512.26.26] 1.179.648
BatchNorm2d-91 [-1.512.26.26] 1,024
LeakyReLU-92 [-1.512.26.26] 0
Conv2d-93 [-1.256.26.26] 131,072
BatchNorm2d-94 [-1.256.26.26] 512
LeakyReLU-95 [-1.256.26.26] 0
Conv2d-96 [-1.512.26.26] 1.179.648
BatchNorm2d-97 [-1.512.26.26] 1,024
LeakyReLU-98 [-1.512.26.26] 0
BasicBlock-99 [-1.512.26.26] 0
Conv2d-100 [-1.256.26.26] 131,072
BatchNorm2d-101 [-1.256.26.26] 512
LeakyReLU-102 [-1.256.26.26] 0
Conv2d-103 [-1.512.26.26] 1.179.648
BatchNorm2d-104 [-1.512.26.26] 1,024
LeakyReLU-105 [-1.512.26.26] 0
BasicBlock-106 [-1.512.26.26] 0
Conv2d-107 [-1.256.26.26] 131,072
BatchNorm2d-108 [-1.256.26.26] 512
LeakyReLU-109 [-1.256.26.26] 0
Conv2d-110 [-1.512.26.26] 1.179.648
BatchNorm2d-111 [-1.512.26.26] 1,024
LeakyReLU-112 [-1.512.26.26] 0
BasicBlock-113 [-1.512.26.26] 0
Conv2d-114 [-1.256.26.26] 131,072
BatchNorm2d-115 [-1.256.26.26] 512
LeakyReLU-116 [-1.256.26.26] 0
Conv2d-117 [-1.512.26.26] 1.179.648
BatchNorm2d-118 [-1.512.26.26] 1,024
LeakyReLU-119 [-1.512.26.26] 0
BasicBlock-120 [-1.512.26.26] 0
Conv2d-121 [-1.256.26.26] 131,072
BatchNorm2d-122 [-1.256.26.26] 512
LeakyReLU-123 [-1.256.26.26] 0
Conv2d-124 [-1.512.26.26] 1.179.648
BatchNorm2d-125 [-1.512.26.26] 1,024
LeakyReLU-126 [-1.512.26.26] 0
BasicBlock-127 [-1.512.26.26] 0
Conv2d-128 [-1.256.26.26] 131,072
BatchNorm2d-129 [-1.256.26.26] 512
LeakyReLU-130 [-1.256.26.26] 0
Conv2d-131 [-1.512.26.26] 1.179.648
BatchNorm2d-132 [-1.512.26.26] 1,024
LeakyReLU-133 [-1.512.26.26] 0
BasicBlock-134 [-1.512.26.26] 0
Conv2d-135 [-1.256.26.26] 131,072
BatchNorm2d-136 [-1.256.26.26] 512
LeakyReLU-137 [-1.256.26.26] 0
Conv2d-138 [-1.512.26.26] 1.179.648
BatchNorm2d-139 [-1.512.26.26] 1,024
LeakyReLU-140 [-1.512.26.26] 0
BasicBlock-141 [-1.512.26.26] 0
Conv2d-142 [-1.256.26.26] 131,072
BatchNorm2d-143 [-1.256.26.26] 512
LeakyReLU-144 [-1.256.26.26] 0
Conv2d-145 [-1.512.26.26] 1.179.648
BatchNorm2d-146 [-1.512.26.26] 1,024
LeakyReLU-147 [-1.512.26.26] 0
BasicBlock-148 [-1.512.26.26] 0
Conv2d-149 [-1.1024.13.13] 4.718.592
BatchNorm2d-150 [-1.1024.13.13] 2,048
LeakyReLU-151 [-1.1024.13.13] 0
Conv2d-152 [-1.512.13.13] 524.288
BatchNorm2d-153 [-1.512.13.13] 1,024
LeakyReLU-154 [-1.512.13.13] 0
Conv2d-155 [-1.1024.13.13] 4.718.592
BatchNorm2d-156 [-1.1024.13.13] 2,048
LeakyReLU-157 [-1.1024.13.13] 0
BasicBlock-158 [-1.1024.13.13] 0
Conv2d-159 [-1.512.13.13] 524.288
BatchNorm2d-160 [-1.512.13.13] 1,024
LeakyReLU-161 [-1.512.13.13] 0
Conv2d-162 [-1.1024.13.13] 4.718.592
BatchNorm2d-163 [-1.1024.13.13] 2,048
LeakyReLU-164 [-1.1024.13.13] 0
BasicBlock-165 [-1.1024.13.13] 0
Conv2d-166 [-1.512.13.13] 524.288
BatchNorm2d-167 [-1.512.13.13] 1,024
LeakyReLU-168 [-1.512.13.13] 0
Conv2d-169 [-1.1024.13.13] 4.718.592
BatchNorm2d-170 [-1.1024.13.13] 2,048
LeakyReLU-171 [-1.1024.13.13] 0
BasicBlock-172 [-1.1024.13.13] 0
Conv2d-173 [-1.512.13.13] 524.288
BatchNorm2d-174 [-1.512.13.13] 1,024
LeakyReLU-175 [-1.512.13.13] 0
Conv2d-176 [-1.1024.13.13] 4.718.592
BatchNorm2d-177 [-1.1024.13.13] 2,048
LeakyReLU-178 [-1.1024.13.13] 0
BasicBlock-179 [-1.1024.13.13] 0
DarkNet-180 [[-1.256.52.52], [...1.512.26.26], [...1.1024.13.13]] 0
Conv2d-181 [-1.512.13.13] 524.288
BatchNorm2d-182 [-1.512.13.13] 1,024
LeakyReLU-183 [-1.512.13.13] 0
Conv2d-184 [-1.1024.13.13] 4.718.592
BatchNorm2d-185 [-1.1024.13.13] 2,048
LeakyReLU-186 [-1.1024.13.13] 0
Conv2d-187 [-1.512.13.13] 524.288
BatchNorm2d-188 [-1.512.13.13] 1,024
LeakyReLU-189 [-1.512.13.13] 0
Conv2d-190 [-1.1024.13.13] 4.718.592
BatchNorm2d-191 [-1.1024.13.13] 2,048
LeakyReLU-192 [-1.1024.13.13] 0
Conv2d-193 [-1.512.13.13] 524.288
BatchNorm2d-194 [-1.512.13.13] 1,024
LeakyReLU-195 [-1.512.13.13] 0
Conv2d-196 [-1.1024.13.13] 4.718.592
BatchNorm2d-197 [-1.1024.13.13] 2,048
LeakyReLU-198 [-1.1024.13.13] 0
Conv2d-199 [-1.255.13.13] 261.375
Conv2d-200 [-1.256.13.13] 131,072
BatchNorm2d-201 [-1.256.13.13] 512
LeakyReLU-202 [-1.256.13.13] 0
Upsample-203 [-1.256.26.26] 0
Conv2d-204 [-1.256.26.26] 196.608
BatchNorm2d-205 [-1.256.26.26] 512
LeakyReLU-206 [-1.256.26.26] 0
Conv2d-207 [-1.512.26.26] 1.179.648
BatchNorm2d-208 [-1.512.26.26] 1,024
LeakyReLU-209 [-1.512.26.26] 0
Conv2d-210 [-1.256.26.26] 131,072
BatchNorm2d-211 [-1.256.26.26] 512
LeakyReLU-212 [-1.256.26.26] 0
Conv2d-213 [-1.512.26.26] 1.179.648
BatchNorm2d-214 [-1.512.26.26] 1,024
LeakyReLU-215 [-1.512.26.26] 0
Conv2d-216 [-1.256.26.26] 131,072
BatchNorm2d-217 [-1.256.26.26] 512
LeakyReLU-218 [-1.256.26.26] 0
Conv2d-219 [-1.512.26.26] 1.179.648
BatchNorm2d-220 [-1.512.26.26] 1,024
LeakyReLU-221 [-1.512.26.26] 0
Conv2d-222 [-1.255.26.26] 130.815
Conv2d-223 [-1.128.26.26] 32.768
BatchNorm2d-224 [-1.128.26.26] 256
LeakyReLU-225 [-1.128.26.26] 0
Upsample-226 [-1.128.52.52] 0
Conv2d-227 [-1.128.52.52] 49.152
BatchNorm2d-228 [-1.128.52.52] 256
LeakyReLU-229 [-1.128.52.52] 0
Conv2d-230 [-1.256.52.52] 294.912
BatchNorm2d-231 [-1.256.52.52] 512
LeakyReLU-232 [-1.256.52.52] 0
Conv2d-233 [-1.128.52.52] 32.768
BatchNorm2d-234 [-1.128.52.52] 256
LeakyReLU-235 [-1.128.52.52] 0
Conv2d-236 [-1.256.52.52] 294.912
BatchNorm2d-237 [-1.256.52.52] 512
LeakyReLU-238 [-1.256.52.52] 0
Conv2d-239 [-1.128.52.52] 32.768
BatchNorm2d-240 [-1.128.52.52] 256
LeakyReLU-241 [-1.128.52.52] 0
Conv2d-242 [-1.256.52.52] 294.912
BatchNorm2d-243 [-1.256.52.52] 512
LeakyReLU-244 [-1.256.52.52] 0
Conv2d-245 [-1.255.52.52] 65.535
================================================================
Total params: 61.949.149
Trainable params: 61.949.149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.98
Forward/backward pass size (MB): 998.13
Params size (MB): 236.32
Estimated Total Size (MB): 1236.43
----------------------------------------------------------------
Copy the code
We found that the total number of parameters is about 62 million, and the estimated total memory consumption is 1.23G, which is still a relatively complex network.
Cluster to get the size of the box
- One of the improvements in Yolov2 is that the way the boxes are derived is by using a clustering approach
- kmeans_for_anchors.py
def cas_iou(box, cluster) :
# Calculate iOU, which has been analyzed many times
x = np.minimum(cluster[:, 0], box[0])
y = np.minimum(cluster[:, 1], box[1])
intersection = x * y
area1 = box[0] * box[1]
area2 = cluster[:,0] * cluster[:,1]
iou = intersection / (area1 + area2 - intersection)
return iou
def avg_iou(box, cluster) :
return np.mean([np.max(cas_iou(box[i], cluster)) for i in range(box.shape[0]])def kmeans(box, k) :
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# How many boxes are there
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
row = box.shape[0]
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# Position of each point in each box
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
distance = np.empty((row, k))
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# Final clustering location
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
last_clu = np.zeros((row, ))
np.random.seed()
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# Choose 9 randomly as clustering centers
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
cluster = box[np.random.choice(row, k, replace = False)]
iter = 0
while True:
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Calculate the ratio of width to height between the current box and the prior box
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
for i in range(row):
# In this part, we use 1-iou as the measure of distance
distance[i] = 1 - cas_iou(box[i], cluster)
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Extract the smallest point
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
When convergence is stable, you can jump out of the loop ahead of time
near = np.argmin(distance, axis=1)
if (last_clu == near).all() :break
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Find the middle point of each class
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# We want to get the center of the cluster for each class
for j in range(k):
cluster[j] = np.median(box[near == j],axis=0)
#
last_clu = near
if iter % 5= =0:
print('iter: {:d}. avg_iou:{:.2f}'.format(iter, avg_iou(box, cluster)))
iter+ =1
# return the center of the cluster, and the minimum distance
return cluster, near
def load_data(path) :
data = []
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Look for box for each XML
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Loop through the data to get the width and the coordinates of the upper left and lower right corners
for xml_file in tqdm(glob.glob('{}/*xml'.format(path))):
tree = ET.parse(xml_file)
height = int(tree.findtext('./size/height'))
width = int(tree.findtext('./size/width'))
if height<=0 or width<=0:
continue
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Get the width and height for each target
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
for obj in tree.iter('object') :# Normalized analysis of coordinates
xmin = int(float(obj.findtext('bndbox/xmin'))) / width
ymin = int(float(obj.findtext('bndbox/ymin'))) / height
xmax = int(float(obj.findtext('bndbox/xmax'))) / width
ymax = int(float(obj.findtext('bndbox/ymax'))) / height
xmin = np.float64(xmin)
ymin = np.float64(ymin)
xmax = np.float64(xmax)
ymax = np.float64(ymax)
Get the normalized width and height of the box
data.append([xmax - xmin, ymax - ymin])
return np.array(data)
if __name__ == '__main__':
np.random.seed(0)
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# running the program will calculate '. / VOCdevkit VOC2007 / Annotations in the XML
# will generate yolo_anchors. TXT
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# We entered the size of the picture and the number of boxes information
input_shape = [416.416]
anchors_num = 9
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Load the data set, you can use VOC XML
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# path to our box XML file
path = 'VOCdevkit/VOC2007/Annotations'
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
Load all XML
Save the format to scale width,height
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
print('Load xmls.')
# load data
data = load_data(path)
print('Load xmls done.')
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# Use k clustering algorithm
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
print('K-means boxes.')
# We use Kmeans to cluster the normalized ones above us
cluster, near = kmeans(data, anchors_num)
print('K-means boxes done.')
The normalized width and height information is restored, which is consistent with the specified size of the input image (416,416)
data = data * np.array([input_shape[1], input_shape[0]])
# Normalization of clustering centers into real coordinates
cluster = cluster * np.array([input_shape[1], input_shape[0]])
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# drawing
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
# Draw points of our box in horizontal and vertical coordinates of length and width.
for j in range(anchors_num):
# Draw all boxes grouped with the current box
plt.scatter(data[near == j][:,0], data[near == j][:,1])
# Center point drawing
plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black')
# Save and display pictures
plt.savefig("kmeans_for_anchors.jpg")
plt.show()
print('Save kmeans_for_anchors.jpg in root dir.')
# Sort the cluster centers by area
cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1]]print('avg_ratio:{:.2f}'.format(avg_iou(data, cluster)))
print(cluster)
# Write the sorted cluster center above to the file
f = open("yolo_anchors.txt".'w')
row = np.shape(cluster)[0]
for i in range(row):
if i == 0:
x_y = "%d,%d" % (cluster[i][0], cluster[i][1])
else:
x_y = ", %d,%d" % (cluster[i][0], cluster[i][1])
f.write(x_y)
f.close()
Copy the code
After running, we get the following image