How to output per-class accuracy in Keras?

Caffe can print not only overall accuracy but also per-class accuracy.

The Keras log shows only overall accuracy, and it's hard for me to calculate the accuracy of each class separately.

Epoch 168/200

0s - loss: 0.0495 - acc: 0.9818 - val_loss: 0.0519 - val_acc: 0.9796

Epoch 169/200

0s - loss: 0.0519 - acc: 0.9796 - val_loss: 0.0496 - val_acc: 0.9815

Epoch 170/200

0s - loss: 0.0496 - acc: 0.9815 - val_loss: 0.0514 - val_acc: 0.9801

Does anybody know how to output per-class accuracy in Keras?

Fortier asked 29/8, 2017 at 4:35

Precision and recall are more useful measures for multi-class classification (see their definitions). Following the Keras MNIST CNN example (10-class classification), you can get the per-class measures using classification_report from sklearn.metrics:

from sklearn.metrics import classification_report
import numpy as np

Y_test = np.argmax(y_test, axis=1) # Convert one-hot to index
# predict_classes() is only available in older Keras versions (see the update at the end of this page)
y_pred = model.predict_classes(x_test)
print(classification_report(Y_test, y_pred))

Here is the result:

         precision    recall  f1-score   support

      0       0.99      1.00      1.00       980
      1       0.99      0.99      0.99      1135
      2       1.00      0.99      0.99      1032
      3       0.99      0.99      0.99      1010
      4       0.98      1.00      0.99       982
      5       0.99      0.99      0.99       892
      6       1.00      0.99      0.99       958
      7       0.97      1.00      0.99      1028
      8       0.99      0.99      0.99       974
      9       0.99      0.98      0.99      1009

avg / total   0.99      0.99      0.99     10000
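If what you want is the fraction of each class's samples that were classified correctly (which is per-class recall, the same numbers as the recall column above), a minimal sketch is to divide the diagonal of sklearn's confusion_matrix by its row sums. The helper name per_class_recall is hypothetical, and the toy arrays only stand in for the integer-index Y_test and y_pred from the snippet above:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def per_class_recall(y_true, y_pred, n_classes):
    # Rows of the confusion matrix are true classes, columns are predicted classes
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
    # Correct predictions per class (diagonal) over samples per class (row sums);
    # a class with no samples gives 0/0 = nan (with a RuntimeWarning)
    return cm.diagonal() / cm.sum(axis=1)

# Toy labels standing in for Y_test and y_pred
Y_test = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
for label, recall in enumerate(per_class_recall(Y_test, y_pred, 3)):
    print("class %d: %.2f" % (label, recall))  # 0.50, 1.00, 0.50
```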
Brigid answered 29/8, 2017 at 13:49 Comment(15)
@desertnaut Thanks a lot, this is very useful for me. This code can output the per-class accuracy at test time, but how can I print the per-class accuracy during training in Keras? - Fortier
For how many classes? Would you really like to print ~20 numbers per training epoch? - Brigid
@desertnaut Two classes, per training epoch. I found some code here: [link]github.com/fchollet/keras/blob/…, but it also prints only the overall precision and recall, not the per-class values. And I'm not sure this code will work correctly when training on batches. - Fortier
For 2 classes (binary classification) the accuracy is the same for both classes! Check the definition... - Brigid
@Brigid I don't understand... in my case the number of objects in each class is extremely unbalanced; even though it is a true/false classification, I don't think the accuracy of both classes is the same. - Melson
@Melson It is; look up the definition of (binary case) accuracy. The only difference in imbalanced cases is that accuracy is no longer very useful, since it can be high by naively classifying everything to the majority class. For more details, please open a new question. - Brigid
The above function returns a single string. Is there a way to get the results as structured data that I can later convert to JSON? - Accoucheur
@Accoucheur you are right, it's a string; I haven't thought much about converting it into JSON... - Brigid
@Brigid It's not the same for both classes: a batch (e.g. of size 32) can have 4 samples = '1' and 28 = '0'; one class's accuracy being 100% or 25% says nothing about the other's. - Needful
@Needful I kindly suggest you revisit the definition of accuracy, because you sound confused. There is no such thing as "100% accuracy for one class": 100% accuracy means correctly predicting the class for all samples. Truth is, in class imbalance settings (such as the one you imply here) accuracy is no longer useful, but this does not mean that its definition changes. - Brigid
@Brigid OP inquired about per-class accuracy, which has only one definition: (# of samples in class predicted correctly) / (total # of samples in class). This metric is useful; for binary classification, an accuracy of ~1 for '0' and ~0 for '1' indicates that the model is blindly guessing the majority class to be 'safe' and failing to learn any real features, to name one example. - Needful
@Needful you are still confusing things: this is the exact definition of recall, or true positive rate (TPR), and not of accuracy. Again, the fact that accuracy is not useful in imbalanced settings does not mean that its definition changes in such cases; it is always (# of correctly classified samples) / (total # of samples). In fact, OP asked for something that does not exist (a fact that was already implied in my answer). - Brigid
@Brigid I quickly realized this after posting; what OP and I call "per-class accuracy" happens to be exactly per-class TPR. In my application, I fed samples from one class at a time, with Keras returning 'accuracy', so there was a need to separate results for logging/averaging purposes, hence 'per class'. So we're both right on the idea, but you are correct on the 'formal' wording. - Needful
@Needful understood, but terminology (and not wording) has its own importance, and it is not just a formality. And again, the per-class precision & recall in a multi-class setting is exactly what I have provided in my answer, so it is unclear what more your own answer offers... - Brigid
@Brigid Agreed on the importance, but this case is rather involved in that both the Keras .train and the theoretical definition of "accuracy" are intended. My solution spares an import and shows a way to store separate class metrics averaged over more than one training update, but, more importantly, how to get per-class metrics during training. - Needful
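On the question of getting structured data instead of a string: classification_report accepts output_dict=True, which returns a nested dict keyed by class label (as strings) and by the average rows. A minimal sketch with toy labels standing in for Y_test and y_pred; default=float in json.dumps is only a precaution in case some values come back as NumPy scalars:

```python
import json
import numpy as np
from sklearn.metrics import classification_report

# Toy labels standing in for Y_test and y_pred from the answer above
Y_test = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

# output_dict=True returns a nested dict instead of a formatted string
report = classification_report(Y_test, y_pred, output_dict=True)
print(report["1"]["recall"])  # recall of class 1

# Serialize to JSON; default=float converts any non-JSON numeric type
print(json.dumps(report, indent=2, default=float))
```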

You are probably looking for a callback, which you can easily add to the model.fit() call.

For example, you can define your own class using the tf.keras.callbacks.Callback interface. I recommend using the on_epoch_end() method, since its output will format nicely inside your training summary if you print with that verbosity setting. Note that this code block is set up for 3 classes, but you can change it to your desired number.

import numpy as np
import tensorflow as tf

# your class labels
classes = ["class_1", "class_2", "class_3"]

class AccuracyCallback(tf.keras.callbacks.Callback):

    def __init__(self, test_data):
        super().__init__()
        # test_data must be an (x, y) tuple with one-hot encoded y
        self.test_data = test_data
        # one list per class, holding that class's accuracy after every epoch
        self.class_history = [[] for _ in classes]

    def on_epoch_end(self, epoch, logs=None):
        x_data, y_data = self.test_data

        correct = 0
        incorrect = 0

        x_result = self.model.predict(x_data, verbose=0)

        class_correct = [0] * len(classes)
        class_incorrect = [0] * len(classes)

        for i in range(len(x_data)):
            actual_label = np.argmax(y_data[i])
            pred_label = np.argmax(x_result[i])

            if pred_label == actual_label:
                class_correct[actual_label] += 1
                correct += 1
            else:
                class_incorrect[actual_label] += 1
                incorrect += 1

        print("\tCorrect: %d" % correct)
        print("\tIncorrect: %d" % incorrect)

        for i in range(len(classes)):
            tot = float(class_correct[i] + class_incorrect[i])
            class_acc = -1  # -1 marks a class with no samples in test_data
            if tot > 0:
                class_acc = class_correct[i] / tot
            self.class_history[i].append(class_acc)
            print("\t%s: %.3f" % (classes[i], class_acc))

        acc = float(correct) / float(correct + incorrect)
        print("\tCurrent Network Accuracy: %.3f" % acc)

Then pass the new callback to model.fit(). Assuming your validation data (val_data) is an (x, y) tuple, you can use the following:

accuracy_callback = AccuracyCallback(val_data)

# you can use the history if desired
history = model.fit(x=_, y=_, verbose=1,
                    epochs=_, shuffle=_, validation_data=val_data,
                    callbacks=[accuracy_callback], batch_size=_)

Note that each _ is a placeholder for a value that depends on your configuration.
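The per-sample loop inside on_epoch_end can also be vectorized with NumPy. A minimal sketch, assuming one-hot labels and one probability row per sample, both of shape (n_samples, n_classes); the helper name per_class_accuracy_from_probs is hypothetical, and it returns NaN (rather than -1) for a class with no samples:

```python
import numpy as np

def per_class_accuracy_from_probs(y_onehot, y_prob):
    """Hypothetical helper: fraction of each class's samples predicted correctly."""
    y_onehot = np.asarray(y_onehot)
    n_classes = y_onehot.shape[1]
    actual = np.argmax(y_onehot, axis=1)
    predicted = np.argmax(np.asarray(y_prob), axis=1)
    # Samples per class, and correctly predicted samples per class
    totals = np.bincount(actual, minlength=n_classes)
    hits = np.bincount(actual[predicted == actual], minlength=n_classes)
    # 0/0 for an empty class becomes nan; silence the warning for that case
    with np.errstate(invalid="ignore", divide="ignore"):
        return hits / totals

# Toy data: 5 samples, 3 classes
y_val = np.eye(3)[[0, 0, 1, 1, 2]]
probs = np.array([[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.1, 0.8, 0.1],
                  [0.3, 0.6, 0.1], [0.1, 0.1, 0.8]])
print(per_class_accuracy_from_probs(y_val, probs))  # values 0.5, 1.0, 1.0
```

Inside the callback, the same call would take y_data and x_result in place of the toy arrays.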

Pattypatulous answered 7/5, 2020 at 23:47 Comment(2)
Please, what is the purpose of the self.class_history.append([]) loop? It throws AttributeError: 'AccuracyCallback' object has no attribute 'class_history', but before I try to fix it I would like to know what that piece of code does. - Bookish
Getting this: ValueError: too many values to unpack (expected 2) - Refugia

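For per-class training accuracy, apply the code below to the training dataset after (and/or before) training on it.

For raw per-class validation accuracy:

import numpy as np

def per_class_accuracy(y_preds, y_true, class_labels):
    # For each class: fraction of that class's samples predicted correctly
    # (y_preds holds predicted class indices)
    return [np.mean([
        (y_true[pred_idx] == np.round(y_pred)) for pred_idx, y_pred in enumerate(y_preds)
        if y_true[pred_idx] == int(class_label)
    ]) for class_label in class_labels]

def update_val_history():
    # Average the per-class results collected in temp_history and store them
    for class_idx, class_label in enumerate(class_labels):
        val_history[class_label].append(float(np.mean(np.asarray(temp_history).T[class_idx])))
    temp_history.clear()  # start collecting afresh for the next update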

Example:

class_labels = ['0', '1', '2', '3']
val_history = {class_label: [] for class_label in class_labels}
temp_history = []  # per-batch results collected before averaging

y_true   = np.asarray([0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3])
y_preds1 = np.asarray([0,3,3,3, 1,1,0,0, 2,2,2,0, 3,3,3,3])
y_preds2 = np.asarray([0,0,3,3, 0,1,0,0, 2,2,2,2, 0,0,0,0])

# With a softmax model, predicted indices would come from e.g.
# y_preds1 = np.argmax(model.predict(x1), axis=-1)
temp_history.append(per_class_accuracy(y_preds1, y_true, class_labels))
update_val_history()
temp_history.append(per_class_accuracy(y_preds2, y_true, class_labels))
update_val_history()

print(val_history)

>> {'0': [0.25, 0.5], '1': [0.5, 0.25], '2': [0.75, 1.0], '3': [1.0, 0.0]}

Needful answered 16/6, 2019 at 16:18 Comment(0)

Update to the solution provided by desertnaut:
In newer versions of Keras, model.predict_classes raises an error:

AttributeError: 'Sequential' object has no attribute 'predict_classes'

To fix the error, use the following code:

from sklearn.metrics import classification_report
import numpy as np

Y_test = np.argmax(y_test, axis=1) # Convert one-hot to index
y_pred = np.argmax(model.predict(x_test), axis=-1) # Index of the highest-probability class
print(classification_report(Y_test, y_pred))
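The argmax line assumes a softmax output with one unit per class. For a binary model with a single sigmoid output unit, argmax over the last axis always returns 0, so a common replacement, assuming a 0.5 decision threshold, is to threshold the predicted probabilities. A minimal sketch, where the made-up probs array stands in for model.predict(x_test):

```python
import numpy as np
from sklearn.metrics import classification_report

# Stand-in for model.predict(x_test) with ONE sigmoid output unit:
# shape (n_samples, 1), each value is the predicted probability of class 1
probs = np.array([[0.1], [0.8], [0.4], [0.9], [0.6]])
y_test = np.array([0, 1, 1, 1, 0])

# argmax over a length-1 axis is always 0, so threshold instead
y_pred = (probs > 0.5).astype("int32").ravel()
print(classification_report(y_test, y_pred))
```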
Wriest answered 24/1, 2022 at 17:0 Comment(0)
