How to do multi-GPU training with Keras?

I want my model to run on multiple GPUs, sharing parameters but training on different batches of data.

Can I do something like that with model.fit()? Is there any other alternative?

It asked 18/7, 2017 at 12:2

Keras now has (as of v2.0.9) built-in support for data parallelism across multiple GPUs, using keras.utils.multi_gpu_model.

Currently, it only supports the TensorFlow backend.

Good example here (docs): https://keras.io/getting-started/faq/#how-can-i-run-a-keras-model-on-multiple-gpus
Also covered here: https://datascience.stackexchange.com/a/25737
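A minimal sketch of how that looks, assuming two visible GPUs; the toy model and random data below are illustrative placeholders, not taken from the FAQ. Note that multi_gpu_model was later deprecated and removed in TF 2.x in favor of tf.distribute (see the last answer below).

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import multi_gpu_model

# Build the model once; multi_gpu_model replicates it on each GPU,
# with all replicas sharing the same weights.
model = Sequential([Dense(64, activation='relu', input_shape=(100,)),
                    Dense(10, activation='softmax')])
parallel_model = multi_gpu_model(model, gpus=2)  # assumes 2 GPUs are visible
parallel_model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# Each batch of 256 is split into two sub-batches of 128, one per GPU.
x = np.random.random((1024, 100))
y = np.random.random((1024, 10))
parallel_model.fit(x, y, epochs=2, batch_size=256)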

Purlieu answered 31/12, 2017 at 18:52
M
2

Try the make_parallel function in https://github.com/kuza55/keras-extras/blob/master/utils/multi_gpu.py (it works only with the TensorFlow backend).
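A rough usage sketch, assuming you have copied that file into your project as multi_gpu.py and that its make_parallel(model, gpu_count) signature is unchanged; the toy model here is a placeholder:

from keras.models import Sequential
from keras.layers import Dense
from multi_gpu import make_parallel  # local copy of the linked utility

model = Sequential([Dense(64, activation='relu', input_shape=(100,)),
                    Dense(1)])

# Wraps the model so each incoming batch is split across 2 GPUs;
# compile and fit the wrapped model as usual afterwards.
model = make_parallel(model, 2)
model.compile(loss='mse', optimizer='adam')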

Milfordmilhaud answered 26/9, 2017 at 15:29

Multi-GPU model training is now more convenient than ever. Check the following guide: Multi-GPU and distributed training.


In essence, to do single-host, multi-device synchronous training with a model, you would use the tf.distribute.MirroredStrategy API. Here's how it works:

  • Instantiate a MirroredStrategy, optionally configuring which specific devices you want to use (by default the strategy will use all GPUs available).

  • Use the strategy object to open a scope, and within this scope, create all the Keras objects you need that contain variables. Typically, that means creating & compiling the model inside the distribution scope.

  • Train the model via fit() as usual.

Schematically, it looks like this:

import tensorflow as tf

# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
  # Everything that creates variables should be under the strategy scope.
  # In general this is only model construction & `compile()`.
  model = Model(...)
  model.compile(...)

# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)

# Test the model on all available devices.
model.evaluate(test_dataset)
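For concreteness, here is a runnable toy version of that schematic (TF 2.x). The two-layer model, the dummy data, and the explicit devices list are assumptions for illustration; drop the devices argument to use every available GPU, as in the schematic above.

import tensorflow as tf

# Pin the strategy to two specific GPUs (assumes both are visible);
# omit `devices` to let the strategy use all available GPUs.
strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0', '/gpu:1'])
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
  # Variable-creating code (model construction & compile) goes in the scope.
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(64, activation='relu', input_shape=(8,)),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='adam', loss='mse')

# Dummy dataset; each global batch of 32 is split evenly across replicas.
x = tf.random.normal((256, 8))
y = tf.random.normal((256, 1))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

model.fit(dataset, epochs=2)
model.evaluate(dataset)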
Bromeosin answered 8/8, 2021 at 11:0
