- GAN with Keras: Application to Image Deblurring
- By Raphael Meudec
- Translated by: the Juejin Translation Project
- Permanent link to this article: github.com/xitu/gold-m…
- Translator: luochen
- Proofread by: SergeyChang, mingxing47
In 2014, Ian Goodfellow proposed Generative Adversarial Networks (GANs). This article focuses on implementing a GAN-based image deblurring model with Keras. All the Keras code is here.
See also the scientific publication and the PyTorch implementation.
A quick review of generative adversarial networks
In a generative adversarial network, two networks train against each other. The generator misleads the discriminator by creating convincing fake inputs. The discriminator distinguishes real inputs from fake ones.
GAN training process – Source
There are three main training steps:
- Create fake inputs from noise with the generator.
- Train the discriminator on both real and fake inputs.
- Train the whole model: the model is built by chaining the generator to the discriminator.
Note that during the third step, the discriminator's weights are frozen.
The reason for chaining the two networks is that there is no direct feedback on the generator's output: our only measure of its quality is whether the discriminator accepts the generated samples.
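To make these three steps concrete, here is a minimal, runnable sketch of the pattern in Keras. The toy dense models, shapes, and labels are placeholders for illustration, not the deblurring models of this article:
import numpy as np
from keras.layers import Dense, Input
from keras.models import Model, Sequential
from keras.optimizers import Adam

# Toy generator and discriminator (dense layers only, for illustration)
generator = Sequential([Dense(64, activation='relu', input_shape=(16,)),
                        Dense(32, activation='tanh')])
discriminator = Sequential([Dense(64, activation='relu', input_shape=(32,)),
                            Dense(1, activation='sigmoid')])
discriminator.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy')

# Chain the two models; the discriminator is frozen inside the chained model
discriminator.trainable = False
gan_input = Input(shape=(16,))
gan = Model(gan_input, discriminator(generator(gan_input)))
gan.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy')

batch_size = 8
real_samples = np.random.normal(size=(batch_size, 32))  # stand-in for real data
noise = np.random.normal(size=(batch_size, 16))

# Step 1: create fake inputs with the generator
fake_samples = generator.predict(noise)
# Step 2: train the discriminator on real and fake inputs
discriminator.train_on_batch(real_samples, np.ones((batch_size, 1)))
discriminator.train_on_batch(fake_samples, np.zeros((batch_size, 1)))
# Step 3: train the chained model; only the generator's weights are updated
gan.train_on_batch(noise, np.ones((batch_size, 1)))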
This briefly recaps the structure of a GAN. If you find it hard to follow, you can refer to this excellent introduction.
The dataset
Ian Goodfellow first applied GAN models to generating MNIST data. In this tutorial, we use a generative adversarial network for image deblurring. Therefore, the generator's input is not noise but blurred images.
The dataset is the GOPRO dataset. You can download a light version (9GB) or the complete version (35GB). It contains artificially blurred images from multiple street views, organized into subfolders by scene.
We first distribute the images into two folders, A (blurred) and B (sharp). This A/B structure matches the original pix2pix article. I wrote a custom script to perform this task; use it as described in the README. A minimal sketch of what it does follows.
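For illustration, here is a minimal sketch of what such an organization script could look like. It assumes each GOPRO scene subfolder contains blur and sharp directories, as in the dataset's layout; the actual script lives in the repository:
import os
import shutil

def organize_gopro(dataset_dir, output_dir):
    """Copy GOPRO images into A (blurred) and B (sharp) folders."""
    for sub in ('A', 'B'):
        os.makedirs(os.path.join(output_dir, sub), exist_ok=True)

    for scene in sorted(os.listdir(dataset_dir)):
        scene_dir = os.path.join(dataset_dir, scene)
        if not os.path.isdir(scene_dir):
            continue
        for src_name, dst_name in (('blur', 'A'), ('sharp', 'B')):
            src = os.path.join(scene_dir, src_name)
            for fname in sorted(os.listdir(src)):
                # Prefix with the scene name to keep file names unique
                shutil.copy(os.path.join(src, fname),
                            os.path.join(output_dir, dst_name, scene + '_' + fname))

organize_gopro('./GOPRO_Large/train', './images/train')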
The model
The training process remains the same as described above. First, let's look at the neural network architectures!
The generator
The generator aims to reproduce a sharp image. The network is based on residual network (ResNet) blocks, and it keeps track of the evolutions applied to the original blurred image. A UNet-based version also exists, but I haven't implemented it yet. Both architectures are suitable for image deblurring.
Architecture of the DeblurGAN generator – Source
The core is nine ResNet blocks applied to an upsampled version of the original image. Let's look at the Keras implementation!
from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout
from layer_utils import ReflectionPadding2D  # custom layer from the repository (sketched below)

def res_block(input, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
    """
    Instantiate a Keras ResNet block using the functional API.
    :param input: input tensor
    :param filters: number of convolution filters
    :param kernel_size: shape of the convolution kernel
    :param strides: strides of the convolution
    :param use_dropout: boolean, whether to use dropout
    :return: output tensor of the block
    """
    x = ReflectionPadding2D((1, 1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1, 1))(x)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)

    # Connect input and output with a skip connection around the two convolutions
    merged = Add()([input, x])
    return merged
A ResNet block is essentially two convolution layers plus a skip connection that adds the block's input to its output to form the final output.
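Note that ReflectionPadding2D is not a built-in Keras layer: it comes from the repository's layer_utils module. For completeness, here is a minimal sketch of how such a custom layer can be written (my illustrative version, not necessarily the repository's exact code):
import tensorflow as tf
from keras.engine.topology import Layer

class ReflectionPadding2D(Layer):
    """Pad height and width with reflected values instead of zeros."""

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        w_pad, h_pad = self.padding
        return (input_shape[0],
                input_shape[1] + 2 * h_pad,
                input_shape[2] + 2 * w_pad,
                input_shape[3])

    def call(self, x):
        w_pad, h_pad = self.padding
        # Reflect-pad the height and width dimensions of an NHWC tensor
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')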
from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9

def generator_model():
    """Build the generator architecture."""
    # Current version: ResNet blocks
    inputs = Input(shape=input_shape_generator)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Increase the filter number while downsampling
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Apply 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # Upsample and decrease the filter number
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Reduce to 3 filters (RGB)
    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add a direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model
Keras implementation of the generator
As planned, the nine ResNet blocks are applied to an upsampled version of the input. We also add a connection from input to output and divide the sum by 2 to keep the output normalized.
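As a quick sanity check (an illustrative snippet assuming the generator code above is importable), you can verify that the output keeps the input shape and stays within [-1, 1]:
import numpy as np

g = generator_model()
print(g.output_shape)  # (None, 256, 256, 3): same shape as the input

# tanh output and input are both in [-1, 1], so their sum lies in [-2, 2];
# dividing by 2 recenters the final output to [-1, 1].
dummy = np.random.uniform(-1, 1, size=(1, 256, 256, 3)).astype('float32')
out = g.predict(dummy)
print(out.min(), out.max())  # both within [-1, 1]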
That's it for the generator; now let's look at the discriminator.
The discriminator
The goal of the discriminator is to determine whether an input image is artificially created. Its architecture is therefore convolutional, and its output is a single value.
from keras.layers import Input, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)

def discriminator_model():
    """Build the discriminator architecture."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev = 1, 1
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model
Keras implementation of the discriminator
The last step is to build the full model. A particularity of this GAN is that its inputs are real images, not noise, so we get direct feedback on the generator's outputs.
from keras.layers import Input
from keras.models import Model

image_shape = (256, 256, 3)

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    # Expose both the generated images and the discriminator's decision,
    # so that a loss can be applied to each output
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model
Let’s see how we can take full advantage of this particularity by using two loss functions.
Training
Loss functions
We extract loss values at two levels: at the output of the generator and at the output of the full model.
The first is a perceptual loss computed directly on the generator's output. It ensures that the GAN model is oriented toward the deblurring task. It compares the outputs of early VGG16 convolutions (block3_conv3 in the code below) for the target and generated images.
import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)

def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
The second loss is the Wasserstein loss, computed on the output of the full model; as implemented here, it is simply the mean of the element-wise product of the labels and the predictions. It is known to improve the convergence of generative adversarial networks.
import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)
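With the Wasserstein loss, the targets are typically 1 for real images and -1 for generated ones (rather than 1 and 0). The label batches used in the training loop below, output_true_batch and output_false_batch, would be prepared along these lines (a sketch consistent with the repository's conventions):
import numpy as np

batch_size = 16
# Targets for the Wasserstein critic: +1 for real images, -1 for generated ones
output_true_batch = np.ones((batch_size, 1))
output_false_batch = -np.ones((batch_size, 1))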
The training process
The first step is to load the data and initialize the models. We use a custom load_images function to load the dataset and add Adam optimizers for the models. We set the Keras trainable option to prevent the discriminator from being updated while training the full model.
import numpy as np
from keras.optimizers import Adam

# Load the dataset
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# Initialize the models
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# Initialize the optimizers
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile the models; the discriminator's weights are frozen inside the full model
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]  # the perceptual loss dominates to steer deblurring
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True
Then, we launch the epochs and divide the dataset into batches.
for epoch in range(epoch_num):
    print('epoch: {}/{}'.format(epoch, epoch_num))
    print('batches: {}'.format(x_train.shape[0] / batch_size))

    # Randomly shuffle the images into different batches
    permutated_indexes = np.random.permutation(x_train.shape[0])

    for index in range(int(x_train.shape[0] / batch_size)):
        batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
        image_blur_batch = x_train[batch_indexes]
        image_full_batch = y_train[batch_indexes]
Finally, we successively train the discriminator and the whole model based on both losses. We generate fake inputs with the generator, train the discriminator to distinguish fake inputs from real ones, and then train the whole model.
for epoch in range(epoch_num):
    for index in range(batches):
        # [Batch preparation]

        # Generate fake inputs
        generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

        # Train the discriminator multiple times on real and fake inputs
        for _ in range(critic_updates):
            d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
            d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        d.trainable = False

        # Train the generator only on the discriminator's decision and the generated images
        d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

        d.trainable = True
You can refer to the GitHub repository to see the whole training loop!
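Once trained, the generator can be used on its own to deblur new images. Here is an illustrative inference sketch; the weight file name is hypothetical, and images are assumed to be scaled to [-1, 1], consistent with the tanh output above:
import numpy as np

g = generator_model()
g.load_weights('generator.h5')  # hypothetical weight file saved during training

# Stand-in for a real blurred image, scaled from [0, 255] to [-1, 1]
blurred_uint8 = np.random.randint(0, 256, size=(1, 256, 256, 3))
blurred = (blurred_uint8.astype('float32') - 127.5) / 127.5

deblurred = g.predict(blurred)

# Map the tanh-range output back to [0, 255] for display
deblurred_uint8 = ((deblurred + 1) * 127.5).astype('uint8')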
Materials
I used an AWS instance (p2.xlarge) with the Deep Learning AMI (version 3.0). Training on the light version of the GOPRO dataset took about 5 hours (50 epochs).
Image deblurring results
Left to right: raw image, blurred image, GAN output
The images above are results of our Keras DeblurGAN. Even in cases of heavy blur, the network is able to reduce it and produce a more convincing image: the car headlights are sharper and the tree branches are clearer.
Left: GOPRO test image, right: GAN output.
One limitation is the induced pattern visible on the image, which may be caused by using VGG as the basis of the loss.
Left: GOPRO test image, right: GAN output.
I hope you enjoyed this article on image deblurring with generative adversarial networks. Feel free to comment, follow us, or contact me.
If you are interested in computer vision, check out our previous article on content-based image retrieval with Keras.
A list of resources on generative adversarial networks:
- NIPS 2016: Generative Adversarial Networks, by Ian Goodfellow
- ICCV 2017: Tutorials on GANs
- GAN implementations with Keras, by Eric Linder-Noren
- A list of generative adversarial network resources, by Deeplearning4j
- Really-awesome-gan, by Holger Caesar
The Juejin Translation Project is a community that translates high-quality technical articles from around the web, covering Android, iOS, front end, back end, blockchain, product, design, artificial intelligence, and more. For more high-quality translations, please follow the Juejin Translation Project, its official Weibo, and its Zhihu column.