Training a Keras model with model.fit_generator (to save memory)
Preface
For example, if we have 20,000 samples, the input images are 224x224x3, and float32 is used, then loading all of the data into memory at once requires 20000 x 224 x 224 x 3 x 32 bit / 8 ≈ 11.2 GB, so loading the entire dataset in one go demands a great deal of memory.
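As a quick sanity check, this arithmetic can be reproduced in a few lines of Python:

n_samples, h, w, c = 20000, 224, 224, 3
bytes_per_float32 = 4  # 32 bits / 8
total_bytes = n_samples * h * w * c * bytes_per_float32
print(total_bytes / 1024 ** 3)  # ~11.2 (GiB)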
If we use Keras's plain fit function to train the model, we have to pass in all of the training data at once. Fortunately, fit_generator is also provided: it reads the data batch by batch, saving memory. The only thing we need to do is implement a generator.
1. Introduction to the fit_generator function
fit_generator(generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0)
Parameters:
generator: A generator, or an instance of a Sequence (keras.utils.Sequence) object. This is the focus of our implementation, and we will see later how a generator and a Sequence can be implemented.
steps_per_epoch: How many times the generator must be called in each epoch to produce data. fit_generator has no batch_size parameter; the batch size is controlled through steps_per_epoch instead, since each call produces one batch. Its value is therefore usually set to (sample count / batch_size); see the sketch after this list. If our generator is a Sequence, this parameter is optional and defaults to len(generator).
epochs: The number of epochs to train for.
verbose: 0, 1, or 2. Log display mode: 0 = quiet, 1 = progress bar, 2 = one line per epoch.
callbacks: A list of callback functions invoked during training.
validation_data: Similar to generator, except that it is used for validation rather than training.
validation_steps: Similar to steps_per_epoch above, but for validation_data.
class_weight: Optional dictionary mapping class indices (integers) to weights (floats), used to weight the loss function (during training only). This can be used to tell the model to "pay more attention" to samples from under-represented classes. (This parameter seems to be used rather rarely.)
max_queue_size: Integer. The maximum size of the generator queue. Defaults to 10.
workers: Integer. The maximum number of processes to use with process-based multiprocessing. Defaults to 1 if not specified. If 0, the generator is executed on the main thread.
use_multiprocessing: Boolean. If True, process-based multiprocessing is used. Defaults to False.
shuffle: Whether to shuffle the order of the batches at the start of each epoch. Only usable with Sequence (keras.utils.Sequence) instances.
initial_epoch: The epoch at which to start training (useful for resuming a previous training run).
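As a minimal sketch of how these parameters fit together (the names model, train_gen, val_gen, n_train and n_val are hypothetical placeholders for your own objects and dataset sizes):

import math

n_train, n_val, batch_size = 20000, 2000, 32  # hypothetical dataset sizes

model.fit_generator(train_gen,
                    steps_per_epoch=math.ceil(n_train / batch_size),
                    epochs=10,
                    validation_data=val_gen,
                    validation_steps=math.ceil(n_val / batch_size))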
2. The generator
2.1 Implementing the generator
Sample code:
import numpy as np
from PIL import Image
from keras.models import Sequential
from keras.layers import Dense, Flatten

def process_x(path):
    img = Image.open(path)
    img = img.resize((96, 96))
    img = img.convert('RGB')
    # Normalize to [0, 1]
    img = np.asarray(img, np.float32) / 255.0
    # Some data augmentation could also be applied here
    return img

def generate_arrays_from_file(x_y):
    # x_y is our labeled training set: the first column of each row is the
    # image path, and the remaining columns are the image's label
    global count
    batch_size = 8
    while 1:
        batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
        batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]
        batch_x = np.array([process_x(img_path) for img_path in batch_x])
        batch_y = np.array(batch_y).astype(np.float32)
        print("count:" + str(count))
        count = count + 1
        yield batch_x, batch_y

model = Sequential()
# Flatten the (96, 96, 3) images produced by the generator before the Dense layers
model.add(Flatten(input_shape=(96, 96, 3)))
model.add(Dense(units=1000, activation='relu'))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

count = 1
x_y = []  # fill with rows of [image_path, label, ...] as described below
model.fit_generator(generate_arrays_from_file(x_y), steps_per_epoch=10, epochs=2, max_queue_size=1, workers=1)
Before we can understand the code above, we need to understand how yield is used.
The yield keyword:
Let’s look at the use of yield in an example:
def foo(): print("starting..." ) while True: res = yield 4 print("res:", res) g = foo() print(next(g)) print("----------") print(next(g))Copy the code
Running results:
starting...
4
----------
res: None
4
A function that contains yield is a generator function, not an ordinary function. Calling foo() does not actually execute its body: because foo contains the yield keyword, the call returns a generator instance instead. When we call next(g) for the first time, foo runs its print statement and enters the while loop. Reaching yield then behaves like a return: the value 4 is handed back and execution pauses. That is why the first call to next(g) produces the first two lines of output.
When we call next(g) again, execution picks up exactly where it left off, at the assignment to res. Since a plain next() call sends no value into the generator, res is assigned None, and print("res:", res) prints res: None. The loop then comes back around to yield, 4 is returned, and execution pauses again.
So the yield keyword lets a program resume from where it last paused, which means that when we use such a function as a data generator, we never read all of the data at once and never run out of memory.
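The assignment res = yield 4 becomes clearer once a value is actually sent back into the generator. This small follow-up (not part of the original example) uses g.send() to do so:

def foo():
    print("starting...")
    while True:
        res = yield 4
        print("res:", res)

g = foo()
print(next(g))    # runs up to the first yield: prints "starting..." then 4
print(g.send(7))  # resumes at the yield, so res is 7: prints "res: 7" then 4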
Now look at the sample code above:
The generate_arrays_from_file function is our generator: on each loop iteration it takes one batch of data, processes it, and yields it. x_y is our training set with paths and labels combined, where each row looks like the following:
['data/img_4092.jpg' '0' '1' '0' '0' '0']
The process_x function reads an image and normalizes it. You can also define your own operations inside process_x, such as real-time data augmentation of the images.
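For instance, a simple random horizontal flip could be added before the normalization step. This is just a sketch of one possible augmentation, not part of the original code:

import random
import numpy as np
from PIL import Image

def process_x(path):
    img = Image.open(path)
    img = img.resize((96, 96))
    img = img.convert('RGB')
    # Real-time augmentation: flip left-right with 50% probability
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
    img = np.asarray(img, np.float32) / 255.0
    return img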
2.2 Using Sequence to Implement the Generator
Sample code:
import math
import numpy as np
from PIL import Image
from keras.utils import Sequence

class BaseSequence(Sequence):
    """Basic data-flow generator; each iteration returns one batch.

    A BaseSequence can be passed directly to fit_generator as its generator
    argument. Keras wraps it into a multiprocess-capable data generator, and
    it guarantees that the same sample is not drawn twice within one epoch
    even when multiprocessing is used.
    """

    def __init__(self, img_paths, labels, batch_size, img_size):
        # np.hstack stacks horizontally: each row becomes [path, label...]
        self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
        self.batch_size = batch_size
        self.img_size = img_size

    def __len__(self):
        # math.ceil rounds up; this is what len(BaseSequence) returns and
        # what fit_generator uses as steps_per_epoch
        return math.ceil(len(self.x_y) / self.batch_size)

    def preprocess_img(self, img_path):
        img = Image.open(img_path)
        img = img.resize((self.img_size[0], self.img_size[0]))
        img = img.convert('RGB')
        # Data normalization
        img = np.asarray(img, np.float32) / 255.0
        return img

    def __getitem__(self, idx):
        batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
        batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
        batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
        batch_y = np.array(batch_y).astype(np.float32)
        print(batch_x.shape)
        return batch_x, batch_y

    # Override of the parent Sequence's on_epoch_end; called after each epoch
    def on_epoch_end(self):
        # Reshuffle the training data after each epoch
        np.random.shuffle(self.x_y)
In the code above, __len__ and __getitem__ are the magic methods we override. __len__ is called when len(BaseSequence) is evaluated; here we return ceil(total samples / batch_size), which is what fit_generator uses for its steps_per_epoch parameter. __getitem__ makes the object indexable, so once a BaseSequence instance is passed to fit_generator, the generator can read data from it batch by batch, as the sketch below shows.
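A minimal sketch of the Sequence protocol in action (the paths and labels here are hypothetical placeholders):

img_paths = ['data/img_0.jpg', 'data/img_1.jpg', 'data/img_2.jpg']
labels = [[0, 1], [1, 0], [0, 1]]

seq = BaseSequence(img_paths, labels, batch_size=2, img_size=(96, 96))
print(len(seq))            # 2, i.e. math.ceil(3 / 2); used as steps_per_epoch
batch_x, batch_y = seq[0]  # __getitem__ returns the first batch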
Here's an example of what __getitem__ can do:
class Animal:
    def __init__(self, animal_list):
        self.animals_name = animal_list

    def __getitem__(self, index):
        return self.animals_name[index]

animals = Animal(["dog", "cat", "fish"])
for animal in animals:
    print(animal)
Output result:
dog
cat
fish
Moreover, using the Sequence class guarantees that even in the multiprocess case, every sample in an epoch is trained on only once.
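Reusing the seq instance from the earlier sketch, training with multiprocessing enabled might look like this (the worker count is an arbitrary example):

model.fit_generator(seq,
                    steps_per_epoch=len(seq),  # optional for a Sequence
                    epochs=10,
                    workers=4,
                    use_multiprocessing=True,
                    shuffle=True)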