I spent most of a day searching the Internet for a way to classify an image with VGG, and found that every source gives a different, piecemeal answer. This article summarizes how to classify images with the official VGG implementation in the PyTorch framework.
The method in this article requires two additional files. The first is the pretrained VGG weights, which can be downloaded from:
{
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}
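As a rough sketch of how one of these files could be fetched ahead of time, the snippet below downloads a weight file into a local folder using only the Python standard library. The helper names `weight_path` and `fetch_weights` and the `models` folder are assumptions made for illustration; they are not part of PyTorch.

```python
import os
import urllib.request

# URLs copied from the list above.
MODEL_URLS = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
}


def weight_path(name, cache_dir="models"):
    """Return the local path where the weights for `name` would be stored."""
    filename = MODEL_URLS[name].rsplit("/", 1)[-1]
    return os.path.join(cache_dir, filename)


def fetch_weights(name, cache_dir="models"):
    """Download the weight file once (hypothetical helper) and return its path."""
    path = weight_path(name, cache_dir)
    if not os.path.exists(path):
        os.makedirs(cache_dir, exist_ok=True)
        urllib.request.urlretrieve(MODEL_URLS[name], path)
    return path
```

Calling `fetch_weights("vgg19")` would then return a path that can be passed to `torch.load`. The files are several hundred megabytes each, so the existence check avoids downloading them twice.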
The second is the ImageNet label information, which must be downloaded separately. This article uses the JSON version. Pay attention to where you save the file and what you name it: the code below opens it as ./imageNet.json.
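Pay attention to where you put the images you want to classify!
The images cannot sit directly in the dataset root, as in dataset/your_image.jpg (or .png). There must be a subfolder in between, with any name you like: dataset/some_folder/your_image.jpg (or .png). This is because ImageFolder treats each subfolder of the root as one class and only looks for images inside those subfolders.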
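The layout rule above can be checked before loading anything. The following is a minimal sketch using only the standard library; `check_layout` and the extension list are hypothetical names chosen for this example, and the demo runs on a throwaway temporary directory rather than a real dataset.

```python
import os
import tempfile

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def check_layout(root):
    """Return (subfolders, loose_images) found directly under `root`."""
    subfolders, loose_images = [], []
    for name in sorted(os.listdir(root)):
        full = os.path.join(root, name)
        if os.path.isdir(full):
            subfolders.append(name)
        elif name.lower().endswith(IMAGE_EXTS):
            loose_images.append(name)
    return subfolders, loose_images


# Demo: one image placed correctly in a subfolder, one left loose in the root.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "some_folder"))
    open(os.path.join(root, "some_folder", "a.jpg"), "w").close()
    open(os.path.join(root, "b.png"), "w").close()
    folders, loose = check_layout(root)

print(folders)  # ['some_folder']
print(loose)    # ['b.png']
```

Any file reported in the second list would need to be moved into a subfolder before the dataset is loaded.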
Requires PyTorch >= 0.4 and torchvision (used for the model, transforms and dataset). There are no other dependencies.
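If you are unsure whether an installation meets the version requirement, a small comparison like the one below may help. It is only a sketch: `version_at_least` is a hypothetical helper, and it assumes the version string starts with `major.minor`, as in `0.4.0` or `1.13.1+cu117`.

```python
def version_at_least(version, minimum=(0, 4)):
    """Compare the major.minor part of a version string against `minimum`."""
    core = version.split("+")[0]  # drop local build tags such as "+cu117"
    major, minor = core.split(".")[:2]
    return (int(major), int(minor)) >= minimum


print(version_at_least("0.3.1"))         # False
print(version_at_least("0.4.0"))         # True
print(version_at_least("1.13.1+cu117"))  # True
```

In practice the argument would be `torch.__version__`.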
Example output
The predicted class is: ['n02793495', 'barn']
The predicted class is: ['n02793495', 'barn']
The predicted class is: ['n03776460', 'mobile_home']
The predicted class is: ['n03956157', 'planetarium']
The predicted class is: ['n02793495', 'barn']
The predicted class is: ['n03776460', 'mobile_home']
The predicted class is: ['n02793495', 'barn']
Code for reading the JSON file
# readJson.py
import json


def GetInfo():
    # Load the ImageNet label file; the path is relative to the working directory.
    with open("./imageNet.json", 'r') as load_f:
        load_dict = json.load(load_f)
    return load_dict
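Judging from the example output and from how the main program indexes the dictionary, the label file appears to map a class index (as a string) to a pair of WordNet ID and class name. The sketch below illustrates that assumed layout with a two-entry sample instead of the real 1000-entry file; `describe` is a hypothetical helper.

```python
import json

# Two-entry sample in the assumed layout: "index": [wordnet_id, class_name].
sample = '{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"]}'
labels = json.loads(sample)


def describe(class_idx, labels):
    """Format the label for an integer class index."""
    # JSON object keys are always strings, so the index must be converted.
    wnid, name = labels[str(class_idx)]
    return "{} ({})".format(name, wnid)


print(describe(1, labels))  # goldfish (n01443537)
```

The `str(...)` conversion is the important detail: looking up the integer `1` directly would raise a `KeyError`.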
The main program
import torch
import readJson
from torchvision import transforms
import torchvision.models as models
from torchvision import datasets

# Resize to the 224x224 input VGG expects and normalize with the ImageNet statistics.
tran = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Use vgg19 or vgg16 as you like; just have the matching weight file ready.
print("Loading VGG model...")
modelVGG = models.vgg19()
path = "C://Users/liu/.torch/models/vgg19-dcbb9e9d.pth"
pre = torch.load(path)
modelVGG.load_state_dict(pre)
print("VGG model loaded successfully")

# Prepare CUDA (falls back to the CPU when no GPU is available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelVGG = modelVGG.to(device)
# Switch to inference mode so the dropout layers do not randomize the predictions.
modelVGG.eval()

# Prepare the dataset.
print("Preparing dataset...")
dataset = datasets.ImageFolder('D:/shujuji/your folder', tran)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
jsonInfo = readJson.GetInfo()
print("Dataset prepared successfully")

for i, imgs in enumerate(data_loader):
    # Each batch is a pair [images, labels]. Even with 8 images in a batch,
    # the images arrive as one tensor in element 0, so take imgs[0].
    real = imgs[0].to(device)
    with torch.no_grad():
        out = modelVGG(real)
    # torch.max returns the maximum value and its index along dimension 1.
    predict_value, predict_idx = torch.max(out, 1)
    print(predict_value, predict_idx)
    for pos, p in enumerate(predict_idx):
        print(p.item(), predict_value[pos].item())
        print("The predicted class is:", jsonInfo[str(p.item())])
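The value printed next to each index in the main program is the raw score from the network's last layer, not a probability. As an illustration of how such scores could be turned into probabilities and ranked, here is a plain-Python sketch; `softmax` and `top_k` are written out by hand for clarity and are not the author's code. With PyTorch, `torch.softmax(out, dim=1)` and `torch.topk` would be the usual way to do the same thing on a whole batch.

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k=3):
    """Return the k most likely (index, probability) pairs, best first."""
    probs = softmax(scores)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Toy scores for three classes; class 1 has the highest score.
for idx, prob in top_k([1.0, 3.0, 2.0], k=2):
    print(idx, round(prob, 3))
```

Showing a few ranked candidates with their probabilities makes it easier to judge how confident the model is, which a single arg-max index does not reveal.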