I want my model to run on multiple GPUs, sharing parameters but training on different batches of data. Can I do something like that with model.fit()? Is there any other alternative?
Keras now has (as of v2.0.9) built-in support for device parallelism across multiple GPUs, using keras.utils.multi_gpu_model. Currently, it only supports the TensorFlow backend.
There is a good example in the docs: https://keras.io/getting-started/faq/#how-can-i-run-a-keras-model-on-multiple-gpus It is also covered here: https://datascience.stackexchange.com/a/25737
Try the make_parallel function in https://github.com/kuza55/keras-extras/blob/master/utils/multi_gpu.py (it works only with the TensorFlow backend).
In Keras, multi-GPU model training is more convenient than ever. See the following guide: Multi-GPU and distributed training.
In essence, to do single-host, multi-device synchronous training with a Keras model, you would use the tf.distribute.MirroredStrategy API. Here's how it works:

1. Instantiate a MirroredStrategy, optionally configuring which specific devices you want to use (by default the strategy will use all GPUs available).
2. Use the strategy object to open a scope, and within this scope create all the Keras objects that contain variables. Typically, that means creating and compiling the model inside the distribution scope.
3. Train the model via fit() as usual.
Schematically, it looks like this:
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = Model(...)
    model.compile(...)

# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)

# Test the model on all available devices.
model.evaluate(test_dataset)
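To build intuition for what MirroredStrategy does under the hood, here is a plain-NumPy sketch of synchronous data parallelism (all function names here are illustrative, not part of the Keras API): each "replica" receives a shard of the global batch, computes gradients against the same shared weights, and the per-replica gradients are averaged before a single weight update.

```python
import numpy as np

def grad_mse_linear(w, x, y):
    # Gradient of mean squared error for a linear model y_hat = x @ w.
    y_hat = x @ w
    return 2.0 * x.T @ (y_hat - y) / len(x)

def synchronous_step(w, x_batch, y_batch, n_replicas=2, lr=0.1):
    # Split the global batch into one shard per "device" (replica).
    x_shards = np.array_split(x_batch, n_replicas)
    y_shards = np.array_split(y_batch, n_replicas)
    # Each replica computes gradients on its own shard, using the
    # same (mirrored) copy of the weights.
    grads = [grad_mse_linear(w, xs, ys)
             for xs, ys in zip(x_shards, y_shards)]
    # "All-reduce": average the per-replica gradients, then apply
    # one synchronized update to the shared weights.
    mean_grad = np.mean(grads, axis=0)
    return w - lr * mean_grad

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = x @ true_w

w = np.zeros(3)
for _ in range(200):
    w = synchronous_step(w, x, y)
print(w)  # converges toward true_w
```

With equal-sized shards, the averaged gradient equals the gradient over the full batch, which is why synchronous data parallelism produces the same updates as single-device training on the global batch.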