Efficient Training with Multiple GPUs in Keras
Our lab machine is currently equipped with three high-performance GPUs, so we need to modify our code so that Keras runs efficiently and actually takes advantage of all of them.
The environment: Keras with the TensorFlow backend, Python 3, and three GTX 1080 Ti graphics cards.
Originally, our network was built with the Keras Sequential API (only briefly introduced earlier).
The expectation was that each batch would be distributed across the GPUs, with the work spread evenly among them.
In practice, however, the GPUs are not fully utilized that way.
So we slice each batch into as many parts as there are GPUs, i.e., three equal parts in our case.
Each slice runs on its own GPU, and the per-GPU results are then combined into a single output, which speeds up the computation (each GPU, of course, still has its own limit on how large a slice it can handle).
Use the following code to slice x into n_gpus equal parts. Note that at this point x is not an actual numerical value but a symbolic tensor (i.e., its value has not been computed yet).
from keras import backend as K

def slice_batch(x, n_gpus, part):
    # Return the part-th slice of the batch tensor x, split into n_gpus pieces.
    sh = K.shape(x)
    L = sh[0] // n_gpus
    if part == n_gpus - 1:
        # The last slice takes whatever is left, in case the batch size
        # is not divisible by n_gpus.
        return x[part * L:]
    return x[part * L:(part + 1) * L]
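As a quick sanity check of the slicing logic, here is a toy example (the batch size of 32 and the single feature dimension are made up purely for illustration): with three GPUs, the slices come out as 10, 10, and 12 samples, with the last slice absorbing the remainder.

import numpy as np
from keras import backend as K

# A toy "batch" of 32 samples with one feature each.
x = K.constant(np.arange(32).reshape(32, 1))
parts = [slice_batch(x, n_gpus=3, part=g) for g in range(3)]
# Evaluate the symbolic slices to see their actual sizes.
print([K.eval(p).shape for p in parts])  # [(10, 1), (10, 1), (12, 1)]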
Afterwards, pass the model to the following function. It adds n_gpus Lambda layers connected to the input, one per GPU, each feeding its slice of the batch through the model on that GPU, and then merges all of the results back into a single output along the batch axis.
import tensorflow as tf
from keras.layers import Input, Lambda, Concatenate
from keras.models import Model

def to_multi_gpu(model, n_gpus=3):
    # The input and the final merge live on the CPU; each model replica runs on its own GPU.
    with tf.device('/cpu:0'):
        print(model.input_shape)
        x = Input(model.input_shape[1:], name="input")
    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:' + str(g)):
            # Slice out this GPU's share of the batch and feed it through the model.
            slice_g = Lambda(slice_batch, lambda shape: shape,
                             arguments={'n_gpus': n_gpus, 'part': g})(x)
            towers.append(model(slice_g))
    with tf.device('/cpu:0'):
        # Stitch the per-GPU outputs back together along the batch axis.
        merged = Concatenate(axis=0)(towers)
    return Model(inputs=[x], outputs=[merged])
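Below is a minimal sketch of how the wrapped model might be used for training. The small Dense network, the optimizer, and the random data are placeholders for illustration only, not the network or data from our experiment; the point is that each batch of 96 samples gets sliced into three parts of 32, one per GPU.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Stand-in single-GPU model; replace with your own architecture.
base_model = Sequential([
    Dense(64, activation='relu', input_shape=(100,)),
    Dense(10, activation='softmax'),
])

parallel_model = to_multi_gpu(base_model, n_gpus=3)
parallel_model.compile(optimizer='adam', loss='categorical_crossentropy')

# Dummy data: each batch of 96 is sliced into three parts of 32, one per GPU.
X = np.random.rand(960, 100).astype('float32')
y = np.random.rand(960, 10).astype('float32')
parallel_model.fit(X, y, batch_size=96, epochs=2)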
The structure of the resulting parallel model is similar to the following diagram:
Training time without this technique:
160 epochs: 23 minutes
Training time with this technique, with the work distributed across the three graphics cards:
160 epochs: 9 minutes