Cross validation in deep neural networks

How do you perform cross-validation in a deep neural network? I know that to perform cross-validation you train on all folds except one and test on the excluded fold, then repeat this for each of the k folds and average the accuracies. But how does this work with training iterations? Do you update the parameters at each fold? Do you perform k-fold cross-validation at each iteration? Or is each training run on all folds but one considered one iteration?

Disk asked 10/6, 2017 at 16:39
You do k-fold cross-validation the same way as with any other ML model; you just train k models. This has nothing to do with iterations.Louise
What do you mean by this? We update the parameters every iteration, right? So is doing cross-validation considered as one iteration?Disk
No, updating parameters has nothing to do with cross-validation!Louise
Hmm, what is it good for then?Disk
To get a less biased view of model performance. You might train a model on a single train/test split and get good performance just by chance, so how does performance change when you vary the training data? That's the question k-fold CV is supposed to answer.Louise
So cross-validation is done outside of training?Disk

Stratified cross-validation: there are a couple of solutions available for running a deep neural network with k-fold cross-validation. Here is one using Keras with the scikit-learn wrapper and StratifiedKFold:

from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

seed = 7  # fix the random seed so the fold split is reproducible

def create_baseline():
    # create model: one hidden layer of 60 units for 11 input features
    model = Sequential()
    model.add(Dense(60, input_dim=11, kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))
    # compile model for binary classification
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# evaluate the model with 10-fold stratified cross-validation
# (X and encoded_Y are assumed to be your feature matrix and encoded labels)
estimator = KerasClassifier(build_fn=create_baseline, epochs=100, batch_size=5, verbose=0)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(estimator, X, encoded_Y, cv=kfold)
print("Baseline: %.2f%% (%.2f%%)" % (results.mean() * 100, results.std() * 100))
Flatways answered 9/12, 2018 at 7:07

Cross-validation is a general technique in ML for estimating how well a model generalizes, and thus for detecting overfitting. There is no difference between doing it on a deep-learning model and doing it on a linear regression; the idea is the same for all ML models. The basic idea behind CV that you described in your question is correct.

But the question "how do you do this for each iteration" does not make sense. Nothing in the CV algorithm relates to training iterations: you train your model to completion, and only then do you evaluate it.

"Do you update the parameters at each fold?" You train the same model architecture k times from scratch, and each training run will most likely end up with different parameters. The sketch below makes this concrete.
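A minimal sketch of a manual k-fold loop, using scikit-learn's KFold and MLPClassifier with random placeholder data as stand-ins for your own network and dataset:

import numpy as np
from sklearn.model_selection import KFold
from sklearn.neural_network import MLPClassifier

X = np.random.rand(500, 11)        # placeholder features
y = np.random.randint(0, 2, 500)   # placeholder binary labels

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
scores = []
for train_idx, test_idx in kfold.split(X):
    # a fresh, freshly initialized model each fold: parameters do not carry over
    model = MLPClassifier(hidden_layer_sizes=(60,), max_iter=500)
    model.fit(X[train_idx], y[train_idx])   # one complete training run (all its iterations)
    scores.append(model.score(X[test_idx], y[test_idx]))  # evaluate only after training
print("mean accuracy: %.3f (+/- %.3f)" % (np.mean(scores), np.std(scores)))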


The answer that CV is not needed in DL is wrong. The basic idea of CV is to get a better estimate of how your model performs on a limited dataset. So if your dataset is small, the ability to train k models will give you a better estimate (the downside is that you spend roughly k times more time). If you have 100 million examples, a 5% test/validation holdout will most probably already give you a good estimate.
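For comparison, such a holdout is a single split (train_test_split is scikit-learn's helper; the 5% test fraction matches the example above):

from sklearn.model_selection import train_test_split

# hold out 5% of the data as a test set instead of running k full trainings
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.05, random_state=42)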

Angiosperm answered 10/6, 2017 at 22:27
Is 44k samples sufficient? And what I meant was, an iteration means updating the parameters. So if you update the parameters at each fold, wouldn't that be considered an iteration?Disk
@Disk Is 44k samples sufficient? No one except you can answer this question.Angiosperm
What I meant was whether it's sufficient for using holdout instead of cross-validation.Disk

Just wanted to say that k-fold x-validation is not very common in deep machine vision, which is what I know most about. This is due to a few factors:

  • Large computational costs of k-fold x-validation with deep vision. (This may be less of a concern if you have huge cloud compute resources.)

  • Availability of very large datasets (e.g., ImageNet). Splitting such datasets into training/validation/test sets arguably lets you effectively evaluate model generalizability and guard against overfitting. If your dataset is not that large, see the final point.

  • Data augmentation methods in machine vision help prevent overfitting (see the sketch after this list).

  • The ease of transfer learning in machine vision makes people less concerned nowadays. If your model ends up failing on some future corner case, it is really easy to train it on new data and get it up to speed. Models can be fixed. They are code.
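As a concrete illustration of the augmentation point above, here is a minimal sketch using tf.keras preprocessing layers (the specific transforms and their strengths are arbitrary choices for the example):

from tensorflow.keras import Sequential, layers

# random, label-preserving transforms applied on the fly during training
augment = Sequential([
    layers.RandomFlip("horizontal"),  # mirror images left/right
    layers.RandomRotation(0.1),       # rotate by up to +/-10% of a full turn
    layers.RandomZoom(0.1),           # zoom in/out by up to 10%
])
# typically placed at the front of the model: Sequential([augment, <rest of the network>])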

So while the answer is that you can do k-fold cross-validation just as with any other class of model, you should think about whether it is actually useful in your use case.

Carrera answered 13/10 at 14:40
