Which PyTorch modules are affected by model.eval() and model.train()?

The model.eval() method switches certain modules (layers) into evaluation mode, namely those that are required to behave differently during training and inference. The docs give a couple of examples:

This has [an] effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Is there an exhaustive list of which modules are affected?
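For concreteness, a minimal sketch of the kind of mode-dependent behaviour I mean, using Dropout (one of the examples from the docs):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()   # training mode: Dropout zeroes random elements and rescales the rest
print(model(x))

model.eval()    # evaluation mode: Dropout acts as the identity, so the output is deterministic
print(model(x))
```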

Mcatee asked 8/3, 2021 at 17:53
I think... that's about it? I don't recall any other standard layer that changes its behaviour, but maybe I'm wrong and I would be promptly corrected if such a list existed :)) I'm counting all layers that inherit from BatchNorm, of course. – Herbart

In addition to info provided by @iacob:

Base class | Module | Criteria
RNNBase | RNN, LSTM, GRU | dropout > 0 (default: 0)
Transformer layers | Transformer, TransformerEncoder, TransformerDecoder | dropout > 0 (Transformer default: 0.1)
Lazy variants | LazyBatchNorm (currently nightly, merged PR) | track_running_stats=True
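
A quick way to check any of these criteria is to run the same input through the module twice in each mode; a minimal sketch with an LSTM (assuming dropout > 0 and num_layers > 1, since the inter-layer dropout has no effect on a single-layer LSTM):

```python
import torch
import torch.nn as nn

# Inter-layer dropout only applies between stacked layers, so num_layers > 1
lstm = nn.LSTM(input_size=8, hidden_size=8, num_layers=2, dropout=0.5)
x = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)

lstm.train()
out_a, _ = lstm(x)
out_b, _ = lstm(x)
print(torch.allclose(out_a, out_b))  # False: dropout is active, outputs are stochastic

lstm.eval()
out_a, _ = lstm(x)
out_b, _ = lstm(x)
print(torch.allclose(out_a, out_b))  # True: dropout is disabled, outputs are deterministic
```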
Regrate answered 13/3, 2021 at 14:38
GroupNorm and LayerNorm do not track running stats and are not affected by model.eval(). – Arnaldo

Searching Google for site:https://pytorch.org/docs/stable/generated/torch.nn. "during evaluation", it appears the following modules are affected:

Base class | Modules | Criteria
_InstanceNorm | InstanceNorm1d, InstanceNorm2d, InstanceNorm3d | track_running_stats=True
_BatchNorm | BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm |
_DropoutNd | Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout |
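
As a rough cross-check of that table, one could walk a model's submodules and flag everything inheriting from these base classes. The underscore-prefixed classes are internal, so the import paths below are an assumption about PyTorch's private module layout and may change between versions; the helper name is made up for illustration:

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm        # internal base class
from torch.nn.modules.dropout import _DropoutNd          # internal base class
from torch.nn.modules.instancenorm import _InstanceNorm  # internal base class

# Hypothetical helper: list submodules from the table above, i.e. those whose
# forward() behaviour depends on self.training via these base classes.
def modules_affected_by_eval(model: nn.Module):
    affected_bases = (_BatchNorm, _InstanceNorm, _DropoutNd)
    return [(name, m) for name, m in model.named_modules()
            if isinstance(m, affected_bases)]

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout2d(p=0.2),
)
for name, module in modules_affected_by_eval(model):
    print(name, type(module).__name__)  # prints "1 BatchNorm2d" and "3 Dropout2d"
```

This only covers the classes in this table; the RNN/Transformer cases from the other answer would have to be added separately.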
Mcatee answered 9/3, 2021 at 9:45
