I spent a whole day searching online for how to classify an image with VGG, and found only bits and pieces of explanation scattered across different sources. This article summarizes how to classify images with the official VGG implementation in the PyTorch framework (torchvision).

The method in this article needs two additional files. The first is the pretrained VGG weight file, which can be downloaded from one of these URLs:

{
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}
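The hex string in each file name is not random. By torch.hub's naming convention a checkpoint is called `name-<hash>.pth`, where `<hash>` is the first digits of the file's SHA-256, and in recent PyTorch releases `torch.hub.load_state_dict_from_url(url)` downloads such a URL into its own cache automatically. If you download the file by hand, the minimal stdlib-only sketch below (the helper names `VGG_URLS`, `local_weight_path` and `weight_file_ok` are hypothetical) builds the expected local path and checks that the file was not truncated, assuming the file follows that naming convention:

```python
import hashlib
import os

# The download URLs listed above (torchvision's VGG checkpoints).
VGG_URLS = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}


def local_weight_path(arch, cache_dir):
    """Hypothetical helper: where the checkpoint for `arch` would be stored."""
    file_name = os.path.basename(VGG_URLS[arch])  # e.g. 'vgg19-dcbb9e9d.pth'
    return os.path.join(cache_dir, file_name)


def weight_file_ok(path):
    """Hypothetical helper: True if the file's SHA-256 starts with the hex
    prefix embedded in its name ('vgg19-dcbb9e9d.pth' -> 'dcbb9e9d').
    Assumes the torch.hub 'name-<sha256 prefix>.ext' naming convention."""
    stem = os.path.splitext(os.path.basename(path))[0]
    expected_prefix = stem.rsplit('-', 1)[-1]
    sha256 = hashlib.sha256()
    with open(path, 'rb') as f:
        # Read in 1 MiB chunks: the VGG files are several hundred MB.
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(expected_prefix)


if __name__ == '__main__':
    # Same cache folder the author's path below points to.
    path = local_weight_path('vgg19', os.path.expanduser('~/.torch/models'))
    print(path)
    if os.path.exists(path):
        print('hash OK' if weight_file_ok(path) else 'file is corrupt or incomplete')
```

A failed check usually means an interrupted download; deleting the file and downloading it again is the simplest fix.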

The second file is the ImageNet class-label information; this article uses the JSON version. Be careful about where these files are stored and what they are named, because the code below loads them from fixed paths (./imageNet.json for the labels).
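The script looks labels up with `jsonInfo[str(index)]`, which suggests the label file has the layout of the widely used `imagenet_class_index.json`: an object whose keys are the class indices "0" to "999" written as strings, and whose values are `[WordNet ID, readable name]` pairs, matching the sample output further down. The short sketch below (the helper `label_for` is hypothetical) uses a three-entry excerpt in that layout to show why the integer index must be converted with `str()` first:

```python
import json

# Three-entry excerpt in the same layout as the label file (the real file has 1000 entries).
SAMPLE = '''
{
    "0": ["n01440764", "tench"],
    "1": ["n01443537", "goldfish"],
    "2": ["n01484850", "great_white_shark"]
}
'''

labels = json.loads(SAMPLE)


def label_for(index, labels):
    """Hypothetical helper: look up [wnid, name] for an integer class index.

    JSON object keys are always strings, so the integer index
    produced by the model must be converted with str() first.
    """
    return labels[str(index)]


print(label_for(1, labels))  # ['n01443537', 'goldfish']
print(1 in labels)           # False: the integer 1 is not a key
```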

Pay attention to where you store the images you want to classify!

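The images cannot be placed directly in the dataset folder (dataset/your_image.jpg or .png). They must sit inside a subfolder: dataset/any_folder_name/your_image.jpg (or .png). torchvision's ImageFolder treats every subfolder as a class and ignores files placed directly in the root folder.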
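To see why the extra folder is required, here is a simplified, stdlib-only imitation of how ImageFolder discovers files. It is a sketch, not torchvision's code: the helper name is hypothetical and the extension list is only a subset of what ImageFolder accepts. Sorted subfolder names become class indices, and anything lying directly in the root is never visited:

```python
import os

# Subset of the image extensions ImageFolder accepts (assumption: enough for this check).
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def list_images_like_imagefolder(root):
    """Hypothetical helper mimicking ImageFolder's file discovery.

    Every sub-directory of `root` becomes a class (sorted by name);
    images placed directly in `root` are never looked at.
    """
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    found = []
    for class_index, class_name in enumerate(classes):
        for dirpath, _, filenames in os.walk(os.path.join(root, class_name)):
            for name in sorted(filenames):
                if name.lower().endswith(IMG_EXTENSIONS):
                    found.append((os.path.join(dirpath, name), class_index))
    return classes, found


if __name__ == '__main__':
    import tempfile
    root = tempfile.mkdtemp()
    open(os.path.join(root, 'loose.jpg'), 'wb').close()              # ignored
    os.makedirs(os.path.join(root, 'any_folder'))
    open(os.path.join(root, 'any_folder', 'cat.jpg'), 'wb').close()  # picked up
    classes, found = list_images_like_imagefolder(root)
    print(classes)     # ['any_folder']
    print(len(found))  # 1
```

The folder name only serves as a placeholder "class" here: the script below reads the images from each batch and ignores the folder labels.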

Requirements: PyTorch >= 0.4 and torchvision; nothing else is needed.

Example output

Predicted class: ['n02793495', 'barn']
Predicted class: ['n02793495', 'barn']
Predicted class: ['n03776460', 'mobile_home']
Predicted class: ['n03956157', 'planetarium']
Predicted class: ['n02793495', 'barn']
Predicted class: ['n03776460', 'mobile_home']
Predicted class: ['n02793495', 'barn']

Code to read the JSON label file

# readJson.py
import json


def GetInfo():
    # Load the ImageNet label file: class index (as a string) -> [wnid, name]
    with open("./imageNet.json", 'r', encoding='utf-8') as load_f:
        load_dict = json.load(load_f)
    return load_dict

Now for the main script

import torch
import readJson
from torchvision import transforms
import torchvision.models as models
from torchvision import datasets

# Resize to the 224x224 input VGG expects, convert to a tensor in [0, 1],
# then normalize with the ImageNet per-channel mean and std
tran = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Prepare VGG19 (for VGG16, use models.vgg16() and the matching weight file)
print("Loading VGG model...")
modelVGG = models.vgg19()
path = "C://Users/liu/.torch/models/vgg19-dcbb9e9d.pth"
pre = torch.load(path, map_location="cpu")
modelVGG.load_state_dict(pre)
print("VGG model loaded successfully")

# Use the GPU if one is available
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
modelVGG = modelVGG.to(device)
# Switch to inference mode: VGG's classifier contains Dropout layers, which
# would otherwise randomly drop activations and make predictions unreliable
modelVGG.eval()

print("Preparing dataset...")
dataset = datasets.ImageFolder('D:/shujuji/your folders', tran)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
jsonInfo = readJson.GetInfo()
print("Dataset prepared successfully")

with torch.no_grad():  # gradients are not needed for prediction
    for i, imgs in enumerate(data_loader):
        # Each batch is a list [images, folder_labels]; imgs[0] is a tensor
        # of shape [batch, 3, 224, 224] holding up to 8 images
        real = imgs[0].to(device)
        out = modelVGG(real)
        # Highest score and its class index for each image in the batch
        predict_value, predict_idx = torch.max(out, 1)
        print(predict_value, predict_idx)
        for pos, p in enumerate(predict_idx):
            print(p.item(), predict_value[pos].item())
            print("Predicted class:", jsonInfo[str(p.item())])
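Two numeric steps around the network are easy to check by hand. `transforms.Normalize` maps each channel value x (already scaled to [0, 1] by `ToTensor`) to (x - mean) / std. On the output side, VGG returns raw scores (logits), so the value printed by `torch.max(out, 1)` is not a probability. The plain-Python sketch below is illustrative only (the helper names are hypothetical): it reproduces the normalization for one pixel and turns one image's scores into softmax probabilities and a top-k list. On tensors, `torch.nn.functional.softmax(out, dim=1)` followed by `torch.topk(probs, 5, dim=1)` should give the equivalent per-image top-5.

```python
import math

MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means used in the script
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def normalize_pixel(rgb):
    """What ToTensor + Normalize do to one 8-bit RGB pixel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


def softmax(scores):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k):
    """Return (index, probability) pairs for the k most likely classes."""
    probs = softmax(scores)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


if __name__ == '__main__':
    print(normalize_pixel((255, 255, 255)))  # roughly (2.249, 2.429, 2.64)
    logits = [1.0, 3.0, 0.5, 2.0]            # made-up scores for a 4-class toy example
    print(top_k(logits, 2))                  # index 1 first, then index 3
```

Because softmax preserves the order of the scores, the index picked by `torch.max` does not change; the probabilities simply make the confidence of each prediction easier to read.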