Generate training validation test TXT file

  • voc_annotation.py
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #
# annotation_mode specifies what is evaluated at run time for this file
# annotation_mode 0 represents the entire label process, including access to VOCdevkit VOC2007 / ImageSets TXT inside and training with 2007 _train. TXT, 2007 _val. TXT
# annotation_mode 1 represents VOCdevkit VOC2007 / ImageSets TXT inside
# annotation_mode = 2 to obtain 2007_train. TXT and 2007_val.txt for training
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #

# this is the type of file we generate. If we specify different values, we will generate different files
annotation_mode     = 0
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
TXT, 2007_val.txt, 2007_val.txt, 2007_val.txt
Use the same classes_path used for training and prediction
# if 2007_train.txt is generated without destination information
# Then classes are not set correctly
# only valid if annotation_mode is 0 and 2
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #

# specify the location of the class name file
classes_path        = 'model_data/voc_classes.txt'
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #
# trainval_percent Specifies the ratio of (training set + verification set) to test set. By default (training set + verification set): test set = 9:1
# train_percent Specifies the ratio of training sets to verification sets in (training set + verification set). By default, training set: verification set = 9:1
Annotation_mode is valid only if annotation_mode is 0 and 1
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -- -- -- -- -- -- -- -- - #

# trainval is
trainval_percent    = 0.9
train_percent       = 0.9
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- #
# point to the folder where the VOC dataset is located
The default is to point to VOC datasets in the root directory
# -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- #

Specify the root folder for the data
VOCdevkit_path  = 'VOCdevkit'

# use it in the loop below
VOCdevkit_sets = [('2007'.'train'), ('2007'.'val')]

Read the file information from class_path above to get the list of class_name
classes, _ = get_classes(classes_path)

def convert_annotation(year, image_id, list_file) :
    Convert data in XML to x1,y1,x2,y2,class_num multi-group format
    in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
    
    # XML parses data
    tree=ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0 
        if obj.find('difficult')! =None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        
        # Just note that here we remove the ones that are not in our class_list and those marked difficult
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write("" + ",".join([str(a) for a in b]) + ', ' + str(cls_id))
        
if __name__ == "__main__":
    random.seed(0)
    # select 0 or 1 in the above mode
    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        
        Get the file path of the XML file
        xmlfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        
        # Select the path to the file we want to save
        saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        
        List all files in the XML folder to form a list
        temp_xml        = os.listdir(xmlfilepath)
        total_xml       = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)
		
        Divide the data according to the ratio we defined above
        num     = len(total_xml)  
        list    = range(num)
        # trainval accounts for 90% of the total data, train accounts for 90% of trainval
        tv      = int(num*trainval_percent)  
        tr      = int(tv*train_percent)
        # We take 90% of the randomly selected data set as trainval's data set, where subscripts are placed in a list
        trainval= random.sample(list,tv)
        # We randomly selected 90% of trainval data set as train data set
        train   = random.sample(trainval,tr)  
        
        print("train and val size",tv)
        print("train size",tr)
        Open the four files and place the proportioned data into the files
        ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
        ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  
        ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  
        fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  
        
        # We loop through the subscript list
        for i in list:
            Write the.xml file to the file
            name=total_xml[i][:-4] +'\n'
            
            # in trainval file
            if i in trainval: 
                Write the trainval file
                ftrainval.write(name)
                Write the train file in train
                if i in train:  
                    ftrain.write(name) 
                Write to val if not present
                else:  
                    fval.write(name)  
           Write test to 10% of files not in trainval
            else:  
                ftest.write(name)  
        
        # close file
        ftrainval.close()  
        ftrain.close()  
        fval.close()  
        ftest.close()
        print("Generate txt in ImageSets done.")
	
    # when mode 0 or 2
    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        # create two files in a different directory (year: 2007 or 2012)
        for year, image_set in VOCdevkit_sets:
            Get the id list of the image from the train. TXT and val. TXT we built above
            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            # Open file
            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
            
            # loop over the list of image_id
            for image_id in image_ids:
                Write file path information
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
				Convert data in XML to x1,y1,x2,y2,class_num multi-group format
                convert_annotation(year, image_id, list_file)
                Record a single line of data plus a newline character
                list_file.write('\n')
            # close file
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")
Copy the code

Split file will not be shown, only show 2007_train.txt effect

View the network hierarchy and the number of parameters

  • summary.py
import torch
from torchsummary import summary

from nets.yolo import YoloBody

if __name__ == "__main__":
    Device is required to specify whether the network is running on the GPU or CPU
    device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    m       = YoloBody([[6.7.8], [3.4.5], [0.1.2]], 80).to(device)
    summary(m, input_size=(3.416.416))
Copy the code

The structure of the network is as follows

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1.32.416.416]             864
       BatchNorm2d-2         [-1.32.416.416]              64
         LeakyReLU-3         [-1.32.416.416]               0
            Conv2d-4         [-1.64.208.208]          18.432
       BatchNorm2d-5         [-1.64.208.208]             128
         LeakyReLU-6         [-1.64.208.208]               0
            Conv2d-7         [-1.32.208.208]           2,048
       BatchNorm2d-8         [-1.32.208.208]              64
         LeakyReLU-9         [-1.32.208.208]               0
           Conv2d-10         [-1.64.208.208]          18.432
      BatchNorm2d-11         [-1.64.208.208]             128
        LeakyReLU-12         [-1.64.208.208]               0
       BasicBlock-13         [-1.64.208.208]               0
           Conv2d-14        [-1.128.104.104]          73.728
      BatchNorm2d-15        [-1.128.104.104]             256
        LeakyReLU-16        [-1.128.104.104]               0
           Conv2d-17         [-1.64.104.104]           8.192
      BatchNorm2d-18         [-1.64.104.104]             128
        LeakyReLU-19         [-1.64.104.104]               0
           Conv2d-20        [-1.128.104.104]          73.728
      BatchNorm2d-21        [-1.128.104.104]             256
        LeakyReLU-22        [-1.128.104.104]               0
       BasicBlock-23        [-1.128.104.104]               0
           Conv2d-24         [-1.64.104.104]           8.192
      BatchNorm2d-25         [-1.64.104.104]             128
        LeakyReLU-26         [-1.64.104.104]               0
           Conv2d-27        [-1.128.104.104]          73.728
      BatchNorm2d-28        [-1.128.104.104]             256
        LeakyReLU-29        [-1.128.104.104]               0
       BasicBlock-30        [-1.128.104.104]               0
           Conv2d-31          [-1.256.52.52]         294.912
      BatchNorm2d-32          [-1.256.52.52]             512
        LeakyReLU-33          [-1.256.52.52]               0
           Conv2d-34          [-1.128.52.52]          32.768
      BatchNorm2d-35          [-1.128.52.52]             256
        LeakyReLU-36          [-1.128.52.52]               0
           Conv2d-37          [-1.256.52.52]         294.912
      BatchNorm2d-38          [-1.256.52.52]             512
        LeakyReLU-39          [-1.256.52.52]               0
       BasicBlock-40          [-1.256.52.52]               0
           Conv2d-41          [-1.128.52.52]          32.768
      BatchNorm2d-42          [-1.128.52.52]             256
        LeakyReLU-43          [-1.128.52.52]               0
           Conv2d-44          [-1.256.52.52]         294.912
      BatchNorm2d-45          [-1.256.52.52]             512
        LeakyReLU-46          [-1.256.52.52]               0
       BasicBlock-47          [-1.256.52.52]               0
           Conv2d-48          [-1.128.52.52]          32.768
      BatchNorm2d-49          [-1.128.52.52]             256
        LeakyReLU-50          [-1.128.52.52]               0
           Conv2d-51          [-1.256.52.52]         294.912
      BatchNorm2d-52          [-1.256.52.52]             512
        LeakyReLU-53          [-1.256.52.52]               0
       BasicBlock-54          [-1.256.52.52]               0
           Conv2d-55          [-1.128.52.52]          32.768
      BatchNorm2d-56          [-1.128.52.52]             256
        LeakyReLU-57          [-1.128.52.52]               0
           Conv2d-58          [-1.256.52.52]         294.912
      BatchNorm2d-59          [-1.256.52.52]             512
        LeakyReLU-60          [-1.256.52.52]               0
       BasicBlock-61          [-1.256.52.52]               0
           Conv2d-62          [-1.128.52.52]          32.768
      BatchNorm2d-63          [-1.128.52.52]             256
        LeakyReLU-64          [-1.128.52.52]               0
           Conv2d-65          [-1.256.52.52]         294.912
      BatchNorm2d-66          [-1.256.52.52]             512
        LeakyReLU-67          [-1.256.52.52]               0
       BasicBlock-68          [-1.256.52.52]               0
           Conv2d-69          [-1.128.52.52]          32.768
      BatchNorm2d-70          [-1.128.52.52]             256
        LeakyReLU-71          [-1.128.52.52]               0
           Conv2d-72          [-1.256.52.52]         294.912
      BatchNorm2d-73          [-1.256.52.52]             512
        LeakyReLU-74          [-1.256.52.52]               0
       BasicBlock-75          [-1.256.52.52]               0
           Conv2d-76          [-1.128.52.52]          32.768
      BatchNorm2d-77          [-1.128.52.52]             256
        LeakyReLU-78          [-1.128.52.52]               0
           Conv2d-79          [-1.256.52.52]         294.912
      BatchNorm2d-80          [-1.256.52.52]             512
        LeakyReLU-81          [-1.256.52.52]               0
       BasicBlock-82          [-1.256.52.52]               0
           Conv2d-83          [-1.128.52.52]          32.768
      BatchNorm2d-84          [-1.128.52.52]             256
        LeakyReLU-85          [-1.128.52.52]               0
           Conv2d-86          [-1.256.52.52]         294.912
      BatchNorm2d-87          [-1.256.52.52]             512
        LeakyReLU-88          [-1.256.52.52]               0
       BasicBlock-89          [-1.256.52.52]               0
           Conv2d-90          [-1.512.26.26]       1.179.648
      BatchNorm2d-91          [-1.512.26.26]           1,024
        LeakyReLU-92          [-1.512.26.26]               0
           Conv2d-93          [-1.256.26.26]         131,072
      BatchNorm2d-94          [-1.256.26.26]             512
        LeakyReLU-95          [-1.256.26.26]               0
           Conv2d-96          [-1.512.26.26]       1.179.648
      BatchNorm2d-97          [-1.512.26.26]           1,024
        LeakyReLU-98          [-1.512.26.26]               0
       BasicBlock-99          [-1.512.26.26]               0
          Conv2d-100          [-1.256.26.26]         131,072
     BatchNorm2d-101          [-1.256.26.26]             512
       LeakyReLU-102          [-1.256.26.26]               0
          Conv2d-103          [-1.512.26.26]       1.179.648
     BatchNorm2d-104          [-1.512.26.26]           1,024
       LeakyReLU-105          [-1.512.26.26]               0
      BasicBlock-106          [-1.512.26.26]               0
          Conv2d-107          [-1.256.26.26]         131,072
     BatchNorm2d-108          [-1.256.26.26]             512
       LeakyReLU-109          [-1.256.26.26]               0
          Conv2d-110          [-1.512.26.26]       1.179.648
     BatchNorm2d-111          [-1.512.26.26]           1,024
       LeakyReLU-112          [-1.512.26.26]               0
      BasicBlock-113          [-1.512.26.26]               0
          Conv2d-114          [-1.256.26.26]         131,072
     BatchNorm2d-115          [-1.256.26.26]             512
       LeakyReLU-116          [-1.256.26.26]               0
          Conv2d-117          [-1.512.26.26]       1.179.648
     BatchNorm2d-118          [-1.512.26.26]           1,024
       LeakyReLU-119          [-1.512.26.26]               0
      BasicBlock-120          [-1.512.26.26]               0
          Conv2d-121          [-1.256.26.26]         131,072
     BatchNorm2d-122          [-1.256.26.26]             512
       LeakyReLU-123          [-1.256.26.26]               0
          Conv2d-124          [-1.512.26.26]       1.179.648
     BatchNorm2d-125          [-1.512.26.26]           1,024
       LeakyReLU-126          [-1.512.26.26]               0
      BasicBlock-127          [-1.512.26.26]               0
          Conv2d-128          [-1.256.26.26]         131,072
     BatchNorm2d-129          [-1.256.26.26]             512
       LeakyReLU-130          [-1.256.26.26]               0
          Conv2d-131          [-1.512.26.26]       1.179.648
     BatchNorm2d-132          [-1.512.26.26]           1,024
       LeakyReLU-133          [-1.512.26.26]               0
      BasicBlock-134          [-1.512.26.26]               0
          Conv2d-135          [-1.256.26.26]         131,072
     BatchNorm2d-136          [-1.256.26.26]             512
       LeakyReLU-137          [-1.256.26.26]               0
          Conv2d-138          [-1.512.26.26]       1.179.648
     BatchNorm2d-139          [-1.512.26.26]           1,024
       LeakyReLU-140          [-1.512.26.26]               0
      BasicBlock-141          [-1.512.26.26]               0
          Conv2d-142          [-1.256.26.26]         131,072
     BatchNorm2d-143          [-1.256.26.26]             512
       LeakyReLU-144          [-1.256.26.26]               0
          Conv2d-145          [-1.512.26.26]       1.179.648
     BatchNorm2d-146          [-1.512.26.26]           1,024
       LeakyReLU-147          [-1.512.26.26]               0
      BasicBlock-148          [-1.512.26.26]               0
          Conv2d-149         [-1.1024.13.13]       4.718.592
     BatchNorm2d-150         [-1.1024.13.13]           2,048
       LeakyReLU-151         [-1.1024.13.13]               0
          Conv2d-152          [-1.512.13.13]         524.288
     BatchNorm2d-153          [-1.512.13.13]           1,024
       LeakyReLU-154          [-1.512.13.13]               0
          Conv2d-155         [-1.1024.13.13]       4.718.592
     BatchNorm2d-156         [-1.1024.13.13]           2,048
       LeakyReLU-157         [-1.1024.13.13]               0
      BasicBlock-158         [-1.1024.13.13]               0
          Conv2d-159          [-1.512.13.13]         524.288
     BatchNorm2d-160          [-1.512.13.13]           1,024
       LeakyReLU-161          [-1.512.13.13]               0
          Conv2d-162         [-1.1024.13.13]       4.718.592
     BatchNorm2d-163         [-1.1024.13.13]           2,048
       LeakyReLU-164         [-1.1024.13.13]               0
      BasicBlock-165         [-1.1024.13.13]               0
          Conv2d-166          [-1.512.13.13]         524.288
     BatchNorm2d-167          [-1.512.13.13]           1,024
       LeakyReLU-168          [-1.512.13.13]               0
          Conv2d-169         [-1.1024.13.13]       4.718.592
     BatchNorm2d-170         [-1.1024.13.13]           2,048
       LeakyReLU-171         [-1.1024.13.13]               0
      BasicBlock-172         [-1.1024.13.13]               0
          Conv2d-173          [-1.512.13.13]         524.288
     BatchNorm2d-174          [-1.512.13.13]           1,024
       LeakyReLU-175          [-1.512.13.13]               0
          Conv2d-176         [-1.1024.13.13]       4.718.592
     BatchNorm2d-177         [-1.1024.13.13]           2,048
       LeakyReLU-178         [-1.1024.13.13]               0
      BasicBlock-179         [-1.1024.13.13]               0
         DarkNet-180  [[-1.256.52.52], [...1.512.26.26], [...1.1024.13.13]]               0
          Conv2d-181          [-1.512.13.13]         524.288
     BatchNorm2d-182          [-1.512.13.13]           1,024
       LeakyReLU-183          [-1.512.13.13]               0
          Conv2d-184         [-1.1024.13.13]       4.718.592
     BatchNorm2d-185         [-1.1024.13.13]           2,048
       LeakyReLU-186         [-1.1024.13.13]               0
          Conv2d-187          [-1.512.13.13]         524.288
     BatchNorm2d-188          [-1.512.13.13]           1,024
       LeakyReLU-189          [-1.512.13.13]               0
          Conv2d-190         [-1.1024.13.13]       4.718.592
     BatchNorm2d-191         [-1.1024.13.13]           2,048
       LeakyReLU-192         [-1.1024.13.13]               0
          Conv2d-193          [-1.512.13.13]         524.288
     BatchNorm2d-194          [-1.512.13.13]           1,024
       LeakyReLU-195          [-1.512.13.13]               0
          Conv2d-196         [-1.1024.13.13]       4.718.592
     BatchNorm2d-197         [-1.1024.13.13]           2,048
       LeakyReLU-198         [-1.1024.13.13]               0
          Conv2d-199          [-1.255.13.13]         261.375
          Conv2d-200          [-1.256.13.13]         131,072
     BatchNorm2d-201          [-1.256.13.13]             512
       LeakyReLU-202          [-1.256.13.13]               0
        Upsample-203          [-1.256.26.26]               0
          Conv2d-204          [-1.256.26.26]         196.608
     BatchNorm2d-205          [-1.256.26.26]             512
       LeakyReLU-206          [-1.256.26.26]               0
          Conv2d-207          [-1.512.26.26]       1.179.648
     BatchNorm2d-208          [-1.512.26.26]           1,024
       LeakyReLU-209          [-1.512.26.26]               0
          Conv2d-210          [-1.256.26.26]         131,072
     BatchNorm2d-211          [-1.256.26.26]             512
       LeakyReLU-212          [-1.256.26.26]               0
          Conv2d-213          [-1.512.26.26]       1.179.648
     BatchNorm2d-214          [-1.512.26.26]           1,024
       LeakyReLU-215          [-1.512.26.26]               0
          Conv2d-216          [-1.256.26.26]         131,072
     BatchNorm2d-217          [-1.256.26.26]             512
       LeakyReLU-218          [-1.256.26.26]               0
          Conv2d-219          [-1.512.26.26]       1.179.648
     BatchNorm2d-220          [-1.512.26.26]           1,024
       LeakyReLU-221          [-1.512.26.26]               0
          Conv2d-222          [-1.255.26.26]         130.815
          Conv2d-223          [-1.128.26.26]          32.768
     BatchNorm2d-224          [-1.128.26.26]             256
       LeakyReLU-225          [-1.128.26.26]               0
        Upsample-226          [-1.128.52.52]               0
          Conv2d-227          [-1.128.52.52]          49.152
     BatchNorm2d-228          [-1.128.52.52]             256
       LeakyReLU-229          [-1.128.52.52]               0
          Conv2d-230          [-1.256.52.52]         294.912
     BatchNorm2d-231          [-1.256.52.52]             512
       LeakyReLU-232          [-1.256.52.52]               0
          Conv2d-233          [-1.128.52.52]          32.768
     BatchNorm2d-234          [-1.128.52.52]             256
       LeakyReLU-235          [-1.128.52.52]               0
          Conv2d-236          [-1.256.52.52]         294.912
     BatchNorm2d-237          [-1.256.52.52]             512
       LeakyReLU-238          [-1.256.52.52]               0
          Conv2d-239          [-1.128.52.52]          32.768
     BatchNorm2d-240          [-1.128.52.52]             256
       LeakyReLU-241          [-1.128.52.52]               0
          Conv2d-242          [-1.256.52.52]         294.912
     BatchNorm2d-243          [-1.256.52.52]             512
       LeakyReLU-244          [-1.256.52.52]               0
          Conv2d-245          [-1.255.52.52]          65.535
================================================================
Total params: 61.949.149
Trainable params: 61.949.149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.98
Forward/backward pass size (MB): 998.13
Params size (MB): 236.32
Estimated Total Size (MB): 1236.43
----------------------------------------------------------------
Copy the code

We found that the total number of parameters is about 62 million, and the estimated total memory consumption is 1.23G, which is still a relatively complex network.

Cluster to get the size of the box

  • One of the improvements in Yolov2 is that the way the boxes are derived is by using a clustering approach
  • kmeans_for_anchors.py
def cas_iou(box, cluster) :
    # Calculate iOU, which has been analyzed many times
    x = np.minimum(cluster[:, 0], box[0])
    y = np.minimum(cluster[:, 1], box[1])

    intersection = x * y
    area1 = box[0] * box[1]

    area2 = cluster[:,0] * cluster[:,1]
    iou = intersection / (area1 + area2 - intersection)

    return iou

def avg_iou(box, cluster) :
    return np.mean([np.max(cas_iou(box[i], cluster)) for i in range(box.shape[0]])def kmeans(box, k) :
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # How many boxes are there
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    row = box.shape[0]
    
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # Position of each point in each box
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    distance = np.empty((row, k))
    
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # Final clustering location
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    last_clu = np.zeros((row, ))

    np.random.seed()

    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # Choose 9 randomly as clustering centers
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    cluster = box[np.random.choice(row, k, replace = False)]

    iter = 0
    while True:
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        Calculate the ratio of width to height between the current box and the prior box
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        for i in range(row):
            # In this part, we use 1-iou as the measure of distance
            distance[i] = 1 - cas_iou(box[i], cluster)
        
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        Extract the smallest point
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        
        When convergence is stable, you can jump out of the loop ahead of time
        near = np.argmin(distance, axis=1)
		
        if (last_clu == near).all() :break
        
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        Find the middle point of each class
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        
        # We want to get the center of the cluster for each class
        for j in range(k):
            cluster[j] = np.median(box[near == j],axis=0)
		
        # 
        last_clu = near
        if iter % 5= =0:
            print('iter: {:d}. avg_iou:{:.2f}'.format(iter, avg_iou(box, cluster)))
        iter+ =1
	# return the center of the cluster, and the minimum distance
    return cluster, near

def load_data(path) :
    data = []
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    Look for box for each XML
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    
    Loop through the data to get the width and the coordinates of the upper left and lower right corners
    for xml_file in tqdm(glob.glob('{}/*xml'.format(path))):
        tree    = ET.parse(xml_file)
        height  = int(tree.findtext('./size/height'))
        width   = int(tree.findtext('./size/width'))
        if height<=0 or width<=0:
            continue
        
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        Get the width and height for each target
        # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
        for obj in tree.iter('object') :# Normalized analysis of coordinates
            xmin = int(float(obj.findtext('bndbox/xmin'))) / width
            ymin = int(float(obj.findtext('bndbox/ymin'))) / height
            xmax = int(float(obj.findtext('bndbox/xmax'))) / width
            ymax = int(float(obj.findtext('bndbox/ymax'))) / height

            
            xmin = np.float64(xmin)
            ymin = np.float64(ymin)
            xmax = np.float64(xmax)
            ymax = np.float64(ymax)
            Get the normalized width and height of the box
            data.append([xmax - xmin, ymax - ymin])
    return np.array(data)

if __name__ == '__main__':
    np.random.seed(0)
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # running the program will calculate '. / VOCdevkit VOC2007 / Annotations in the XML
    # will generate yolo_anchors. TXT
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    
    # We entered the size of the picture and the number of boxes information
    input_shape = [416.416]
    anchors_num = 9
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    Load the data set, you can use VOC XML
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    
    # path to our box XML file
    path        = 'VOCdevkit/VOC2007/Annotations'
    
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    Load all XML
    Save the format to scale width,height
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    print('Load xmls.')
    # load data
    data = load_data(path)
    print('Load xmls done.')
    
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # Use k clustering algorithm
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    print('K-means boxes.')
    # We use Kmeans to cluster the normalized ones above us
    cluster, near   = kmeans(data, anchors_num)
    print('K-means boxes done.')
    
    The normalized width and height information is restored, which is consistent with the specified size of the input image (416,416)
    data            = data * np.array([input_shape[1], input_shape[0]])
    # Normalization of clustering centers into real coordinates
    cluster         = cluster * np.array([input_shape[1], input_shape[0]])

    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    # drawing
    # -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - #
    
    # Draw points of our box in horizontal and vertical coordinates of length and width.
    for j in range(anchors_num):
        # Draw all boxes grouped with the current box
        plt.scatter(data[near == j][:,0], data[near == j][:,1])
        # Center point drawing
        plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black')
    # Save and display pictures
    plt.savefig("kmeans_for_anchors.jpg")
    plt.show()
    print('Save kmeans_for_anchors.jpg in root dir.')
	
    # Sort the cluster centers by area
    cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1]]print('avg_ratio:{:.2f}'.format(avg_iou(data, cluster)))
    print(cluster)

    # Write the sorted cluster center above to the file
    f = open("yolo_anchors.txt".'w')
    row = np.shape(cluster)[0]
    for i in range(row):
        if i == 0:
            x_y = "%d,%d" % (cluster[i][0], cluster[i][1])
        else:
            x_y = ", %d,%d" % (cluster[i][0], cluster[i][1])
        f.write(x_y)
    f.close()
Copy the code

After running, we get the following image