What is going on
Please correct me if I'm wrong anywhere.

The changes you are referring to were introduced in 2018 via this commit and described as:

in multiprocessing mode, only one process will write the checkpoint

Previously, checkpoints were saved without any if block, so each process on each GPU would save a model. That is indeed wasteful and, since every process typically writes to the same path, would most likely overwrite the saved model multiple times on each node.
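For context, the guard around checkpoint saving in the ImageNet example looks roughly like this (a paraphrased sketch, not a verbatim copy; the keys passed to save_checkpoint may differ in detail):

# Paraphrased sketch of the guard in pytorch/examples imagenet/main.py:
# only the process whose global rank is a multiple of ngpus_per_node,
# i.e. the first GPU on each node, writes the checkpoint.
if not args.multiprocessing_distributed or (
    args.multiprocessing_distributed and args.rank % ngpus_per_node == 0
):
    save_checkpoint(
        {
            "epoch": epoch + 1,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        is_best,
    )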
Now, we are talking about multiprocessing distributed training (possibly many workers, each with possibly multiple GPUs). args.rank for each process is thus modified inside the script by this line:

args.rank = args.rank * ngpus_per_node + gpu

which has the following comment:

For multiprocessing distributed training, rank needs to be the global rank among all the processes

Hence args.rank is a unique ID amongst all GPUs across all nodes (or so it seems).
If so, and each node has ngpus_per_node GPUs (from what I've gathered, this training code assumes every node has the same number of GPUs), then the model is saved only for one GPU on each node (the one with local index 0, since that is the rank divisible by ngpus_per_node). In your example with 3 machines and 4 GPUs you would get 3 saved models (hopefully I understand this code correctly, as it's pretty convoluted, to be honest).

If you used rank == 0, only one model per world (where the world would be defined as n_gpus * n_nodes) would be saved.
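To make the rank arithmetic concrete, here is a small standalone sketch for the 3-node, 4-GPU case (plain Python; names like node_rank are just for illustration):

# Which processes save a checkpoint under each condition, for 3 nodes
# with 4 GPUs each (global rank = node_rank * ngpus_per_node + gpu).
n_nodes = 3
ngpus_per_node = 4

for node_rank in range(n_nodes):
    for gpu in range(ngpus_per_node):
        global_rank = node_rank * ngpus_per_node + gpu
        saves_per_node = global_rank % ngpus_per_node == 0  # 3 files, one per node
        saves_single = global_rank == 0                     # 1 file for the whole world
        print(node_rank, gpu, global_rank, saves_per_node, saves_single)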
Questions
First question
So why don't we just save model on rank == 0, but rank % ngpus_per_node == 0?
I will start with your assumption, namely:
To the best of my knowledge, DistributedDataParallel() will automatic
do all reduce to the loss on the backend, without doing any further
job, every process can sync the loss automatically base on that.
Precisely, it has nothing to do with the loss but rather with gradient averaging and the weight updates applied afterwards, as per the documentation (emphasis mine):
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and
each device, and each such replica handles a portion of the input.
During the backwards pass, gradients from each node are averaged.
So, when the model is created with some weights, it is replicated on all devices (each GPU of each node). Now each GPU gets a portion of the input (say, for a total batch size of 1024 and 4 nodes with 4 GPUs each, every GPU would get 64 elements), calculates the forward pass and the loss, and performs backprop via the .backward() tensor method. Then all gradients are averaged via all-reduce, and every process applies the same update to its own replica, so the module's state stays the same across all machines.

Note: I'm not sure how this averaging exactly takes place (and I don't see it explicitly said in the docs), though I assume gradients are first averaged across the GPUs within a node and later across all nodes, as that would be the most efficient.
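As a rough illustration of why every replica ends up identical, here is a minimal, heavily simplified per-process sketch of a DDP training step (the toy Linear model, the random data, and train_process itself are stand-ins; the real script does this inside its per-GPU worker function):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Hypothetical per-process routine; assumes MASTER_ADDR/MASTER_PORT are set
# and that each process was spawned with its global rank and local GPU index.
def train_process(global_rank, world_size, local_gpu):
    dist.init_process_group("nccl", init_method="env://",
                            rank=global_rank, world_size=world_size)
    torch.cuda.set_device(local_gpu)

    # Same architecture on every replica; DDP broadcasts rank 0's initial
    # weights at construction, so all replicas start out identical.
    model = torch.nn.Linear(10, 2).cuda(local_gpu)  # toy stand-in model
    ddp_model = DDP(model, device_ids=[local_gpu])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

    for _ in range(3):  # stand-in for iterating over this rank's data shard
        inputs = torch.randn(64, 10).cuda(local_gpu)        # this rank's 64 elements
        targets = torch.randint(0, 2, (64,)).cuda(local_gpu)
        loss = torch.nn.functional.cross_entropy(ddp_model(inputs), targets)
        loss.backward()    # gradients are averaged (all-reduce) across all ranks here
        optimizer.step()   # identical averaged gradients -> identical local updates
        optimizer.zero_grad()

    dist.destroy_process_group()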
Now, why would you save the model for each node in such a case? In principle you could save only one (as all modules will be exactly the same), but that has some downsides:

- Say the node where your model was saved crashes and the file is lost. You have to redo all the work. Saving the model is not a costly operation (done once per epoch or less), so it can easily be done on each node/worker.
- You have to restart training. This means the model would have to be copied back to each worker (plus some necessary metadata, though I don't think that's the case here); see the loading sketch after this list.
- Nodes have to wait for every forward/backward pass to finish anyway (so the gradients can be averaged); if model saving took a lot of time it would waste GPU/CPU time idling (or some other synchronization scheme would have to be applied; I don't think there is one in PyTorch). This makes saving somewhat "no-cost" if you look at the overall picture.
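When you do resume, each process can load the checkpoint onto its own GPU; a minimal sketch (the file name and checkpoint keys are assumptions modelled on the ImageNet example):

import torch

# Hypothetical resume step run by every process: each rank loads the same
# checkpoint file but maps the tensors onto its own device.
def resume(ddp_model, optimizer, local_gpu, path="checkpoint.pth.tar"):
    checkpoint = torch.load(path, map_location="cuda:{}".format(local_gpu))
    ddp_model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint.get("epoch", 0)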
Question 2 (and 3)
And which model should I used for if I get multiple model?
It doesn't matter: all of them will be exactly the same, since the same optimizer updates are applied to replicas starting from the same initial weights.

You could use something along these lines to load your saved .pth model:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the model so the "module."-prefixed keys in the saved state dict match.
parallel_model = torch.nn.DataParallel(MyModelGoesHere())
parallel_model.load_state_dict(
    torch.load("my_saved_model_state_dict.pth", map_location=str(device))
)

# DataParallel keeps the wrapped model in its .module attribute
usable_model = parallel_model.module
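This works because a state dict saved from a DataParallel/DistributedDataParallel-wrapped model has every key prefixed with "module.", so loading it into another wrapped model makes the keys line up. If you prefer to load into a plain, unwrapped model instead, a small sketch (same assumed file name as above) would be to strip that prefix first:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load("my_saved_model_state_dict.pth", map_location=str(device))

# Drop the "module." prefix added by DataParallel/DistributedDataParallel
# so the keys match a plain, unwrapped model.
stripped = {key.replace("module.", "", 1): value for key, value in state_dict.items()}

plain_model = MyModelGoesHere()
plain_model.load_state_dict(stripped)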