Preface

I haven’t done anything fun for a while. Browsing GitHub, I remembered the cute cartoon-style avatars I have been seeing recently, so I decided to look for related projects to play with. After all, simply calling Baidu’s API is boring.

Main text

Having written a few blog posts about GANs before, I assume you already have a basic picture of how they work: following the idea of a zero-sum game, the generator tries to produce fake images realistic enough to fool the discriminator. In essence, the generator keeps learning from this adversarial process and gradually picks up the data distribution of the real images. GANs are still hard to train, though; the main difficulties are convergence and mode collapse.

Back to the goal of this post: turning an avatar photo into an anime-style image is essentially image style transfer, so the most basic losses that come to mind are a content loss and a style loss (specific papers of course introduce further losses; a minimal content-loss sketch follows the list below). A quick search on GitHub turns up these names:

  • CartoonGAN
  • AnimeGAN
  • AnimeGAN2
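
To make “content loss” concrete, here is a minimal sketch of the general idea (a perceptual loss over pretrained VGG features). The exact layer choice and loss formulation used by CartoonGAN/AnimeGAN differ, so treat this purely as an illustration:

import torch
import torch.nn as nn
from torchvision.models import vgg19

# Frozen VGG19 feature extractor: compare images in feature space, not pixel space.
# The slice index 26 is my own arbitrary choice of a high-level layer.
vgg = vgg19(pretrained=True).features[:26].eval()
for p in vgg.parameters():
    p.requires_grad_(False)

def content_loss(generated, photo):
    # L1 distance between the VGG features of the generated image and the
    # original photo, so the overall content/structure is preserved.
    return nn.functional.l1_loss(vgg(generated), vgg(photo))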

Reading the related papers and GitHub pages, you will find that the order of that list reflects a process of gradual improvement. I will leave the detailed paper reading to the papers column; today the focus stays on implementation. In the end I chose the SOTA model, AnimeGAN2-Tensorflow.

There is also a PyTorch port, AnimeGAN2-PyTorch.

I went with the PyTorch version because the TensorFlow implementation targets TF 1.x while my environment has fully moved to TF2, and I increasingly dislike TF’s complicated, cluttered APIs; more and more network implementations are built on Keras and sometimes lose the flexibility of TF itself. Anyway, I digress. Let’s actually implement it.

Implementation

Go to AnimeGAN2-PyTorch; if you only want to run inference, you don’t need to download the whole project, just model.py:

import torch
from torch import nn
import torch.nn.functional as F


class ConvNormLReLU(nn.Sequential):
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
        pad_layer = {
            "zero": nn.ZeroPad2d,
            "same": nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        if pad_mode not in pad_layer:
            raise NotImplementedError

        super(ConvNormLReLU, self).__init__(
            pad_layer[pad_mode](padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True))


class InvertedResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super(InvertedResBlock, self).__init__()

        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))
        layers = []
        if expansion_ratio != 1:
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))

        # dw
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pw
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out


class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.block_a = nn.Sequential(
            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64)
        )

        self.block_b = nn.Sequential(
            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128)
        )

        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )

        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128)
        )

        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64, 64),
            ConvNormLReLU(64, 32, kernel_size=7, padding=3)
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
        )

    def forward(self, input, align_corners=True):
        out = self.block_a(input)
        half_size = out.size()[-2:]
        out = self.block_b(out)
        out = self.block_c(out)

        if align_corners:
            out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_d(out)

        if align_corners:
            out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_e(out)

        out = self.out_layer(out)
        return out

This file contains the generator code, which is the key part of producing anime-style images. Next you need to download the models trained for the different styles from the corresponding cloud drive. That link may not be accessible, so I have re-shared the downloaded models on a cloud drive: model link (password: N78V).
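
Before going further, a quick sanity check I find useful (my own snippet, not part of the original project): instantiate the untrained Generator and push a random tensor through it, to confirm the output keeps the input resolution and the final Tanh keeps values in [-1, 1]:

import torch
from model import Generator

net = Generator().eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)    # fake RGB image in NCHW layout
    out = net(dummy)

print(out.shape)                            # torch.Size([1, 3, 256, 256])
print(out.min().item(), out.max().item())   # stays within [-1, 1] because of Tanh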

Then comes my own code for generating images, which you can of course adapt from test_faces.ipynb in the original project. Below is my generation script; the model file must be placed in the same directory beforehand, and a samples folder must be created to hold the generated images.

import os
import cv2
import matplotlib.pyplot as plt
import torch
import random
import numpy as np

from model import Generator


def load_image(path, size=None):
    image = image2tensor(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))

    w, h = image.shape[-2:]
    if w != h:
        crop_size = min(w, h)
        left = (w - crop_size) // 2
        right = left + crop_size
        top = (h - crop_size) // 2
        bottom = top + crop_size
        image = image[:, :, left:right, top:bottom]

    if size is not None and image.shape[-1] != size:
        image = torch.nn.functional.interpolate(image, (size, size), mode="bilinear", align_corners=True)

    return image


def image2tensor(image):
    image = torch.FloatTensor(image).permute(2, 0, 1).unsqueeze(0) / 255.
    return (image - 0.5) / 0.5


def tensor2image(tensor):
    tensor = tensor.clamp_(-1., 1.).detach().squeeze().permute(1, 2, 0).cpu().numpy()
    return tensor * 0.5 + 0.5


def imshow(img, size=5, cmap='jet'):
    plt.figure(figsize=(size, size))
    plt.imshow(img, cmap=cmap)
    plt.axis('off')
    plt.show()


if __name__ == '__main__':
    device = 'cuda'
    torch.set_grad_enabled(False)
    image_size = 300
    img=input("")
    model = Generator().eval().to(device)

    ckpt = torch.load(f"./new.pth", map_location=device)
    model.load_state_dict(ckpt)

    result = []
    image = load_image(f"./face/{img}", image_size)
    output = model(image.to(device))

    result.append(torch.cat([image, output.cpu()], 3))
    result = torch.cat(result, 2)

    imshow(tensor2image(result), 40)
    cv2.imwrite(f'./samples/new+{img}', cv2.cvtColor(255 * tensor2image(result), cv2.COLOR_RGB2BGR))

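If you would rather convert every photo in the ./face folder at once instead of typing one file name per run, a small loop like the one below should do it. It reuses load_image, tensor2image, model, device and image_size from the script above, so treat it as a sketch of how I would extend that script rather than tested code:

for name in os.listdir("./face"):
    # only process common image files
    if not name.lower().endswith((".jpg", ".jpeg", ".png")):
        continue
    image = load_image(f"./face/{name}", image_size)
    output = model(image.to(device))
    # photo and its anime version side by side, same as in the script above
    result = torch.cat([image, output.cpu()], 3)
    cv2.imwrite(f"./samples/new+{name}",
                cv2.cvtColor(255 * tensor2image(result), cv2.COLOR_RGB2BGR))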

Set device to 'cpu' or 'cuda' depending on whether you want to run on the GPU.
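
If you are not sure whether a GPU is available, a common pattern (my habit, not something from the original script) is to pick the device automatically:

import torch

# fall back to the CPU when no CUDA device is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'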

Possible problems

If the PyTorch on your machine is older than 1.6, loading the model file raises an error (the released weights use the zipfile-based serialization format introduced in PyTorch 1.6), and my graphics card did not seem to support installing a newer torch, so I had to find another way.

Find another machine, such as a laptop, where you can usually install a newer PyTorch (anything at or above 1.6 will do). After setting up that environment, run the code below to load the weights and re-save them in the old format; the newly generated file is the new.pth that my script above loads.

import torch
# load with the newer PyTorch, then re-save in the legacy (pre-1.6) format
weight = torch.load("path of the downloaded model")
torch.save(weight, "path of the new model to create", _use_new_zipfile_serialization=False)

Results

The results are quite good and generation is fairly fast, but I feel there are still some problems:

  • Because the code center-crops the input to a square, information near the edges of the original image can be lost (a possible padding workaround is sketched after this list).
  • It may be an issue with the training data set: input photos with poor lighting come out looking bad, with very blurry contour lines.
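
For the first point, one possible workaround (my own idea, not part of AnimeGAN2-PyTorch) is to pad the photo to a square before resizing, instead of center-cropping it, so no pixels are thrown away:

import cv2

def pad_to_square(img):
    # pad the shorter side with replicated edge pixels so the whole photo is kept
    h, w = img.shape[:2]
    size = max(h, w)
    top = (size - h) // 2
    bottom = size - h - top
    left = (size - w) // 2
    right = size - w - left
    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_REPLICATE)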