
Conditional Batch Normalization

Batch Normalization (BN) is a commonly used network training technique in deep learning. It not only accelerates model convergence but, more importantly, alleviates the vanishing-gradient problem in deep networks, which makes training deep models easier and more stable. BN has therefore become a standard component of almost all convolutional neural networks. Let's briefly review the BN equation for a batch of activations x:


BN(x) = \gamma \left( \frac{x - \mu(x)}{\sigma(x)} \right) + \beta

where the mean μ(x) and standard deviation σ(x) are computed over the (N, H, W) dimensions, and each normalization layer has only a single pair of affine transformation parameters, γ and β, which are learned by the network during training.
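As a quick illustration of where these statistics come from, here is a minimal sketch (tensor shapes are made up for the example) that computes per-channel batch statistics over the (N, H, W) axes and applies a single γ/β pair:

import tensorflow as tf

# A (N, H, W, C) batch of activations; the shapes here are arbitrary.
x = tf.random.normal((8, 32, 32, 64))
# Mean and variance are reduced over N, H, W, leaving one value per channel.
mean, var = tf.nn.moments(x, axes=(0, 1, 2), keepdims=True)   # shape (1, 1, 1, 64)
# A single affine pair per layer: one gamma and one beta per channel.
gamma = tf.ones((64,))    # learned during training in a real network
beta = tf.zeros((64,))    # learned during training in a real network
eps = 1e-5                # added for numerical stability (omitted in the formula above)
bn_x = gamma * (x - mean) / tf.sqrt(var + eps) + beta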

However, using BN in Generative Adversarial Networks (GANs) leads to a certain homogeneity in the generated images. For example, the CIFAR10 dataset has 10 image classes: 6 are animals (bird, cat, deer, dog, frog, and horse) and 4 are vehicles (airplane, automobile, ship, and truck). The two groups look very different in appearance: vehicles tend to have hard, straight edges, while animals tend to have curved edges and softer textures.

As we learned in style transfer, the statistics of the activations determine the image style. Mixing batch statistics can therefore create images that look a bit like an animal and a bit like a vehicle (for example, a car-shaped cat). This happens because batch normalization uses only one γ and one β for the whole batch, even though the batch contains images of different categories. The problem is solved if each category has its own γ and β, which is exactly what conditional batch normalization does: each class gets one γ and one β, so the 10 classes of CIFAR10 need 10 γ and 10 β in every normalization layer.
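As a toy sketch of this per-class lookup (the shapes and label values here are made up), the parameters can be stored as a table with one row per class, and the row for each sample is selected by its label:

import tensorflow as tf

# One gamma row per class: shape (10, C); C = 4 is arbitrary for the example.
gamma_table = tf.random.normal((10, 4))
labels = tf.constant([3, 3, 7])            # class labels of three samples
gamma = tf.gather(gamma_table, labels)     # shape (3, 4): one gamma row per sample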

Implementing conditional batch normalization in TensorFlow

Now we can construct the variables needed for conditional batch normalization, as follows:

  1. β and γ of shape (10, C), where C is the number of activation channels.
  2. Moving mean and moving variance of shape (1, 1, 1, C). During training, the mean and variance are computed from the mini-batch; during inference, we use the moving statistics accumulated during training. Their shape lets the arithmetic broadcast over the N, H, and W dimensions.

To achieve conditional batch normalization with a custom layer, first create the required variables:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Layer

class ConditionBatchNorm(Layer):
    def __init__(self, n_class, decay=0.999, eps=1e-5):
        super().__init__()
        # number of classes, decay of the moving statistics, numerical epsilon
        self.n_class, self.decay, self.eps = n_class, decay, eps

    def build(self, input_shape):
        self.input_size = input_shape
        n, h, w, c = input_shape
        self.beta = self.add_weight(shape=[self.n_class, c],
            initializer='zeros', trainable=True, name='beta')
        self.gamma = self.add_weight(shape=[self.n_class, c],
            initializer='ones', trainable=True, name='gamma')
        self.moving_mean = self.add_weight(shape=[1, 1, 1, c],
            initializer='zeros', trainable=False, name='moving_mean')
        self.moving_var = self.add_weight(shape=[1, 1, 1, c],
            initializer='ones', trainable=False, name='moving_var')

When running conditional batch normalization, the correct β and γ are retrieved for each sample's label. This is done with tf.gather(self.beta, labels), which is conceptually equivalent to beta = self.beta[labels]; the result is then reshaped to (N, 1, 1, C) so that it broadcasts over the spatial dimensions, as follows:

    def call(self, x, labels, training=False):
        # look up the beta/gamma row for each sample's label, shape (N, C),
        # then reshape to (N, 1, 1, C) so it broadcasts over H and W
        beta = tf.gather(self.beta, labels)
        beta = tf.reshape(beta, (-1, 1, 1, self.beta.shape[-1]))
        gamma = tf.gather(self.gamma, labels)
        gamma = tf.reshape(gamma, (-1, 1, 1, self.gamma.shape[-1]))
        if training:
            # batch statistics over (N, H, W), one value per channel
            mean, var = tf.nn.moments(x, axes=(0, 1, 2), keepdims=True)
            # update the moving statistics used at inference time
            self.moving_mean.assign(self.decay * self.moving_mean + (1 - self.decay) * mean)
            self.moving_var.assign(self.decay * self.moving_var + (1 - self.decay) * var)
            output = tf.nn.batch_normalization(x, mean, var, beta, gamma, self.eps)
        else:
            output = tf.nn.batch_normalization(x, self.moving_mean, self.moving_var, beta, gamma, self.eps)
        return output
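A quick usage sketch of the layer might look like this (batch size, spatial size, channel count, and class count are all arbitrary):

# Hypothetical usage: 10 classes, a batch of 8 feature maps with 64 channels.
cbn = ConditionBatchNorm(10)
x = tf.random.normal((8, 16, 16, 64))
labels = tf.random.uniform((8,), maxval=10, dtype=tf.int32)
y_train = cbn(x, labels, training=True)    # batch statistics, moving averages updated
y_infer = cbn(x, labels, training=False)   # uses the accumulated moving statistics
print(y_train.shape)                       # (8, 16, 16, 64)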

Applying conditional batch normalization to residual blocks

Conditional batch normalization is used in the same way as batch normalization. As an example, we now add conditional batch normalization to the residual block:

class ResBlock(Layer):
    def __init__(self, filters, n_class):
        super().__init__()
        self.filters = filters    # number of output channels
        self.n_class = n_class    # number of classes for conditional BN

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.conv_1 = Conv2D(self.filters, 3, padding='same', name='conv2d_1')
        self.conv_2 = Conv2D(self.filters, 3, padding='same', name='conv2d_2')
        self.cbn_1 = ConditionBatchNorm(self.n_class)
        self.cbn_2 = ConditionBatchNorm(self.n_class)
        self.learned_skip = False
        # learn a 1x1 convolution on the skip path when the channel count changes
        if self.filters != input_filter:
            self.learned_skip = True
            self.conv_3 = Conv2D(self.filters, 1, padding='same', name='conv2d_3')
            self.cbn_3 = ConditionBatchNorm(self.n_class)

The following is the forward pass of the residual block using conditional batch normalization; the training flag is passed through to the conditional batch normalization layers:

    def call(self, input_tensor, labels, training=False):
        x = self.conv_1(input_tensor)
        x = self.cbn_1(x, labels, training=training)
        x = tf.nn.leaky_relu(x, 0.2)
        x = self.conv_2(x)
        x = self.cbn_2(x, labels, training=training)
        x = tf.nn.leaky_relu(x, 0.2)
        if self.learned_skip:
            # project the input to the new channel count before the addition
            skip = self.conv_3(input_tensor)
            skip = self.cbn_3(skip, labels, training=training)
            skip = tf.nn.leaky_relu(skip, 0.2)
        else:
            skip = input_tensor
        output = skip + x
        return output
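As a usage sketch (again, shapes and class count are arbitrary), the block is called on a feature map together with the labels:

# Hypothetical usage: a residual block with 128 output filters and 10 classes.
res_block = ResBlock(128, 10)
x = tf.random.normal((8, 16, 16, 64))
labels = tf.random.uniform((8,), maxval=10, dtype=tf.int32)
out = res_block(x, labels, training=True)
print(out.shape)    # (8, 16, 16, 128): channel count changed, so the learned skip is used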