Calculate accuracy for each class using a CNN and PyTorch
I can calculate the accuracy after each epoch using the code below, but I want to calculate the accuracy for each class at the end. How can I do that? I have two folders, train and val, and each folder contains 7 subfolders for 7 different classes. The train folder is used for training and the val folder is used for testing.

def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
    since = time.time()

    best_model = model
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                mode='train'
                optimizer = lr_scheduler(optimizer, epoch)
                model.train()  # Set model to training mode
            else:
                model.eval()
                mode='val'

            running_loss = 0.0
            running_corrects = 0

            counter=0
            # Iterate over data.
            for data in dset_loaders[phase]:
                inputs, labels = data
                print(inputs.size())
                # wrap them in Variable
                if use_gpu:
                    try:
                        inputs = Variable(inputs.float().cuda())
                        labels = Variable(labels.long().cuda())
                    except:
                        print(inputs,labels)
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # Set gradient to zero to delete history of computations in previous epoch. Track operations so that differentiation can be done automatically.
                optimizer.zero_grad()
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                
                loss = criterion(outputs, labels)
                # print('loss done')                
                # Just so that you can keep track that something's happening and don't feel like the program isn't running.
                # if counter%10==0:
                #     print("Reached iteration ",counter)
                counter+=1

                # backward + optimize only if in training phase
                if phase == 'train':
                    # print('loss backward')
                    loss.backward()
                    # print('done loss backward')
                    optimizer.step()
                    # print('done optim')
                # print evaluation statistics
                try:
                    # running_loss += loss.data[0]
                    running_loss += loss.item()
                    # print(labels.data)
                    # print(preds)
                    running_corrects += torch.sum(preds == labels.data)
                    # print('running correct =',running_corrects)
                except:
                    print('unexpected error, could not calculate loss or do a sum.')
            print('trying epoch loss')
            epoch_loss = running_loss / dset_sizes[phase]
            epoch_acc = running_corrects.item() / float(dset_sizes[phase])
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))


            # deep copy the model
            if phase == 'val':
                if USE_TENSORBOARD:
                    foo.add_scalar_value('epoch_loss',epoch_loss,step=epoch)
                    foo.add_scalar_value('epoch_acc',epoch_acc,step=epoch)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model = copy.deepcopy(model)
                    print('new best accuracy = ',best_acc)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    print('returning and looping back')
    return best_model


def exp_lr_scheduler(optimizer, epoch, init_lr=BASE_LR, lr_decay_epoch=EPOCH_DECAY):
    """Decay learning rate by a factor of DECAY_WEIGHT every lr_decay_epoch epochs."""
    lr = init_lr * (DECAY_WEIGHT**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer

 
Sauger answered 17/7, 2020 at 16:35 Comment(0)
7

Calculating the overall accuracy is rather straightforward:

outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)

acc_all = (preds == labels).float().mean()

To calculate it per class requires a few more lines of code:

acc = [0 for c in list_of_classes]
for c in list_of_classes:
    acc[c] = ((preds == labels) * (labels == c)).float().sum() / max((labels == c).sum(), 1)
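
The loop above gives the accuracy for a single batch. To get per-class accuracy over the whole val set, one option is to accumulate correct and total counts per class across batches and divide once at the end. A minimal sketch, assuming integer class labels 0..num_classes-1; the helper name per_class_accuracy and the toy batches are made up for illustration (in a real run the logits would come from model(inputs) under torch.no_grad()):

```python
import torch

def per_class_accuracy(batches, num_classes):
    """Accumulate per-class correct/total counts over (logits, labels) batches."""
    correct = torch.zeros(num_classes, dtype=torch.long)
    total = torch.zeros(num_classes, dtype=torch.long)
    for logits, labels in batches:
        preds = logits.argmax(dim=1)
        # How many samples of each class this batch contains
        total += torch.bincount(labels, minlength=num_classes)
        # How many of those were predicted correctly
        correct += torch.bincount(labels[preds == labels], minlength=num_classes)
    # clamp avoids a division by zero for classes that never appear
    return correct.float() / total.clamp(min=1).float()

# Hypothetical toy data: two batches, three classes
batches = [
    (torch.tensor([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [0.1, 2.0, 0.1]]),
     torch.tensor([0, 1, 2])),
    (torch.tensor([[2.0, 0.1, 0.1], [0.1, 0.1, 2.0]]),
     torch.tensor([0, 2])),
]
acc = per_class_accuracy(batches, num_classes=3)
print(acc)
```

If the model runs on a GPU, the two counters would need to live on the same device as the labels (or the labels and predictions moved to the CPU first).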
Temperature answered 17/7, 2020 at 16:55 Comment(2)
Short, sweet and effective. (South)
Shouldn't it be ((preds == labels) * (labels == c)).float().sum() / max((labels == c).sum(), 1)? (Adjudicate)
2

You can also consider using sklearn's classification_report for a detailed report on the performance of a multi-class classification model. It gives you precision, recall and F1-score for every class, followed by the macro and weighted averages over all classes.

You can use this code snippet to do that.

from sklearn.metrics import classification_report

output = model(test_input.float())
_, predictions = torch.max(output, dim=1)

# sklearn works on CPU arrays, so move the predictions off the GPU first
print(classification_report(true_labels, predictions.cpu().numpy()))
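
If the test data comes in batches, the predictions can be collected first and the report built once at the end. A rough sketch, assuming three classes; class_names and the toy batches are hypothetical stand-ins for a real data loader and model output. Note that the per-class recall in the report is the fraction of each class's samples that were predicted correctly, which is what the question calls per-class accuracy:

```python
import torch
from sklearn.metrics import classification_report

class_names = ["cat", "dog", "bird"]  # hypothetical class names

# Hypothetical (logits, labels) batches standing in for model outputs
batches = [
    (torch.tensor([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [0.1, 2.0, 0.1]]),
     torch.tensor([0, 1, 2])),
    (torch.tensor([[2.0, 0.1, 0.1], [0.1, 0.1, 2.0]]),
     torch.tensor([0, 2])),
]

all_true, all_pred = [], []
for logits, labels in batches:
    all_pred.append(logits.argmax(dim=1).cpu())
    all_true.append(labels.cpu())

y_true = torch.cat(all_true).numpy()
y_pred = torch.cat(all_pred).numpy()

# output_dict=True returns a nested dict instead of a formatted string
report = classification_report(
    y_true, y_pred,
    labels=[0, 1, 2],
    target_names=class_names,
    output_dict=True,
    zero_division=0,
)
for name in class_names:
    print(name, report[name]["recall"])
```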
Loisloise answered 29/3, 2022 at 19:23 Comment(0)
0

I hit upon this while trying to understand the problems with my CNN, and used Victor's solution. I also wanted to check which classes were not being trained properly and which classes were being misclassified as other classes.

Code here: https://github.com/alexcpn/cnn_lenet_pytorch/blob/main/cnn/model_accuracy.py

Snippet below:

with torch.no_grad():
    model.eval() #IMPORTANT set model to eval mode before inference
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        #print("Outputs=",outputs.shape) #Outputs= torch.Size([64, 10])
        _, predicted = torch.max(outputs.data, 1) # get the class with the most probability out
        #print("predicted=",predicted.shape,predicted[10]) # predicted= torch.Size([64])
        #print("labels=",labels.shape,labels[10]) #labels= torch.Size([64]) 
        total += labels.size(0)
        correct += (predicted == labels).float().sum().item()  #this is Torch Tensor semantics
        #print("correct",correct) # say 56 out of 64
        #print("classification_report",classification_report(labels.cpu(), predicted.cpu()))
        #-------- Lets check also which classes are wrongly predicted with other classes (we need to clip at max prob > .5 to do)
        mask=(predicted != labels)
        wrong_predicted =torch.masked_select(predicted,mask)
        wrong_labels =torch.masked_select(labels,mask)
        zipped = zip(wrong_labels,wrong_predicted)

        for _,j in enumerate(zipped):
            wrong_per_class[j[0].item()].append(j[1].item())
            #print(f"wrong_per_class{j[0].item()}={j[1].item()}",)

        for index, element in enumerate(categories):
            cal = ((predicted == labels)*(labels ==index)).sum().item()/ ((labels == index).sum()) #this is Torch Tensor semantics
            wrong_class = (predicted != labels)*(labels == index)
            # >>> import torch
            # >>> some_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))
            # >>> some_integers3 = torch.tensor((12, 3, 5, 7, 11, 13, 17, 19))
            # >>> (some_integers ==some_integers3)*(some_integers == 3)
            # tensor([False,  True, False, False, False, False, False, False])
            # >>> ((some_integers ==some_integers3)*(some_integers >12)).sum().item()
            # 3
            if not math.isnan(cal):
                precision_per_class[element].append(cal.item())
            #print(f"{element}={cal}")
        
    avg_accuracy =[]    
    for key,val in precision_per_class.items():
        avg = np.mean(val)
        precision_per_class[key] = avg
        avg_accuracy.append(avg)
        print(f"Accuracy of Class {key}={avg}")

    # Just to cross-check with the average accuracy results below
    print(f"Average accuracy={np.mean(avg_accuracy)}")

    for key,val in wrong_per_class.items():
        print(f"wrong_per_class {categories[key]}={Counter(val)}")

    print(
        "Accuracy of the network on the {} test/validation images: {} %".format(
            total, 100 * correct / total
        )
    )
    

Output

Accuracy of Class tench=0.8504464285714286
Accuracy of Class English springer=0.6907253691128322
Accuracy of Class cassette player=0.7420465648174286
Accuracy of Class chain saw=0.5169889160564968
Accuracy of Class church=0.6264965534210205
Accuracy of Class French horn=0.5337499976158142
Accuracy of Class garbage truck=0.7543565290314811
Accuracy of Class gas pump=0.5343750034059797
Accuracy of Class golf ball=0.5873511944498334
Accuracy of Class parachute=0.5481353274413517
Average accuracy=0.6384671883923666
wrong_per_class tench=Counter({3: 25, 8: 16, 1: 10, 2: 7, 6: 3, 5: 3, 7: 2, 9: 1})
wrong_per_class English springer=Counter({3: 39, 0: 23, 8: 21, 6: 7, 5: 7, 7: 3, 9: 3, 4: 3, 2: 3})
wrong_per_class cassette player=Counter({7: 36, 6: 14, 3: 13, 8: 11, 0: 8, 5: 4, 1: 4, 4: 2})
wrong_per_class chain saw=Counter({0: 49, 1: 30, 6: 27, 7: 22, 5: 21, 4: 19, 2: 12, 8: 8, 9: 4})
wrong_per_class church=Counter({6: 23, 5: 21, 3: 20, 7: 19, 8: 16, 0: 14, 2: 10, 9: 7, 1: 5})
wrong_per_class French horn=Counter({3: 64, 4: 26, 2: 22, 1: 21, 7: 19, 0: 13, 8: 12, 6: 11})
wrong_per_class garbage truck=Counter({3: 28, 4: 23, 2: 14, 7: 14, 0: 8, 5: 5, 1: 4, 8: 2})
wrong_per_class gas pump=Counter({2: 50, 6: 46, 3: 41, 4: 23, 1: 11, 5: 9, 8: 8, 0: 7, 9: 2})
wrong_per_class golf ball=Counter({1: 38, 0: 37, 3: 27, 4: 17, 9: 11, 5: 10, 2: 9, 6: 7, 7: 6})
wrong_per_class parachute=Counter({8: 56, 3: 46, 4: 19, 6: 13, 7: 12, 0: 10, 2: 6, 1: 6, 5: 2})
Accuracy of the network on the 3925 test/validation images: 64.07643312101911 %

I will update the answer later with a proper plot of this data.
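
One caveat about the snippet above: it computes a ratio per batch and then averages those ratios with np.mean, which weights every batch equally no matter how many samples of the class it contains. That can differ from the accuracy over the whole dataset, where correct and total counts are pooled first. A small illustration with made-up counts for a single class:

```python
# (correct, total) counts for one class in three hypothetical batches
per_batch = [(1, 1), (1, 4), (3, 5)]

# Average of the per-batch ratios: every batch weighs the same
mean_of_ratios = sum(c / t for c, t in per_batch) / len(per_batch)

# Pooled counts: every sample weighs the same
pooled = sum(c for c, _ in per_batch) / sum(t for _, t in per_batch)

print(f"mean of per-batch ratios = {mean_of_ratios:.4f}")
print(f"pooled accuracy          = {pooled:.4f}")
```

The two numbers agree only when every batch holds the same number of samples of the class, which is rarely the case with shuffled data.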

Hoisch answered 19/10, 2022 at 14:16 Comment(0)
0

A more accurate approach than my previous answer is to build a confusion matrix first and derive the per-class accuracy from it. The matrix will help with other analysis of the training as well.

# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    model.eval() #IMPORTANT set model to eval mode before inference
    correct = 0
    total = 0


    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # ------------------------------------------------------------------------------------------
        # Predict for the batch of images
        # ------------------------------------------------------------------------------------------
        outputs = model(images)  #Outputs= torch.Size([64, 10]) Probability of each of the 10 classes
        _, predicted = torch.max(outputs.data, 1) # get the class with the highest Probability out Given 1 per image # predicted= torch.Size([64])
        total += labels.size(0) #labels= torch.Size([64])  This is the truth value per image - the right class
        correct += (predicted == labels).float().sum().item()  # Find which are correctly classified
        
        # ------------------------------------------------------------------------------------------
        #  Lets check also which classes are wrongly predicted with other classes  to create a MultiClass confusion matrix
        # ------------------------------------------------------------------------------------------

        mask=(predicted != labels) # Wrongly predicted
        wrong_predicted =torch.masked_select(predicted,mask)
        wrong_labels =torch.masked_select(labels,mask)
        wrongly_zipped = zip(wrong_labels,wrong_predicted)

        mask=(predicted == labels) # Rightly predicted
        rightly_predicted =torch.masked_select(predicted,mask)
        right_labels =rightly_predicted #same torch.masked_select(labels,mask)
        rightly_zipped = zip(right_labels,rightly_predicted)
        
        # Note that this is for a single batch - add to the list associated with class
        for _,j in enumerate(wrongly_zipped):
            k = j[0].item() # label
            l = j[1].item() # predicted
            wrong_per_class[k].append(l)
            confusion_matrix[k][l] +=1
       
        # Note that this is for a single batch - add to the list associated with class
        for _,j in enumerate(rightly_zipped):
            k = j[0].item() # label
            l = j[1].item() # predicted
            right_per_class[k].append(l)
            confusion_matrix[k][l] +=1
    
    #print("Confusion Matrix1=\n",confusion_matrix)
    # ------------------------------------------------------------------------------------------
    # Print Confusion matrix in Pretty print format
    # ------------------------------------------------------------------------------------------
    print(categories)
    for i in range(len(categories)):
        for j in range(len(categories)):
            print(f"\t{confusion_matrix[i][j]}",end='')
        print(f"\t{categories[i]}\n",end='')
    # ------------------------------------------------------------------------------------------
    # Calculate Accuracy per class
    # ------------------------------------------------------------------------------------------
    print("---------------------------------------")
    total_correct =0
    for i in range(len(categories)):
        print(f"Average accuracy per class {categories[i]} from confusion matrix {confusion_matrix[i][i]/confusion_matrix[i].sum()}")
        total_correct +=confusion_matrix[i][i]

    print(f"Average Accuracy/precision from the confusion matrix is {total_correct/confusion_matrix.sum()}")

    # Overall accuracy as below
    print(
        "Accuracy of the network on the {} test/validation images: {} %".format(
            total, 100 * correct / total
        )
    )
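
As a possible alternative to the per-sample Python loops, the confusion matrix for a batch can also be filled in one vectorized step with torch.bincount. This is only a sketch under the assumption that labels and predictions are integer tensors in 0..n-1; update_confusion_matrix is a made-up helper name, not part of the code above:

```python
import torch

def update_confusion_matrix(cm, labels, preds):
    """Add one batch to cm, where cm[i, j] counts true class i predicted as j."""
    n = cm.shape[0]
    flat = labels * n + preds  # unique index for each (true, predicted) pair
    cm += torch.bincount(flat, minlength=n * n).reshape(n, n)
    return cm

num_classes = 3
cm = torch.zeros(num_classes, num_classes, dtype=torch.long)

# Hypothetical labels and predictions for one batch
labels = torch.tensor([0, 1, 2, 0, 2])
preds = torch.tensor([0, 1, 1, 0, 2])
update_confusion_matrix(cm, labels, preds)

# Diagonal = correctly classified; row sums = samples per true class
per_class_acc = cm.diagonal().float() / cm.sum(dim=1).clamp(min=1).float()
overall_acc = cm.diagonal().sum().item() / cm.sum().item()
print(cm)
print(per_class_acc, overall_acc)
```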

Note that this uses the Torch tensor semantics illustrated below:

# >>> import torch
# >>> some_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))
# >>> some_integers3 = torch.tensor((12, 3, 5, 7, 11, 13, 17, 19))
# >>> (some_integers ==some_integers3)*(some_integers == 3)
# tensor([False,  True, False, False, False, False, False, False])
# >>> ((some_integers ==some_integers3)*(some_integers >12)).sum().item()
# 3

Output

2022-10-20 13:38:01,112 Gpu device NVIDIA GeForce RTX 3060 Laptop GPU
['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
        320.0   10.0    7.0     25.0    0.0     3.0     3.0     2.0     16.0    1.0     tench
        23.0    286.0   3.0     39.0    3.0     7.0     7.0     3.0     21.0    3.0     English springer
        8.0     4.0     265.0   13.0    2.0     4.0     14.0    36.0    11.0    0.0     cassette player
        49.0    30.0    12.0    194.0   19.0    21.0    27.0    22.0    8.0     4.0     chain saw
        14.0    5.0     10.0    20.0    274.0   21.0    23.0    19.0    16.0    7.0     church
        13.0    21.0    22.0    64.0    26.0    206.0   11.0    19.0    12.0    0.0     French horn
        8.0     4.0     14.0    28.0    23.0    5.0     291.0   14.0    2.0     0.0     garbage truck
        7.0     11.0    50.0    41.0    23.0    9.0     46.0    222.0   8.0     2.0     gas pump
        37.0    38.0    9.0     27.0    17.0    10.0    7.0     6.0     237.0   11.0    golf ball
        10.0    6.0     6.0     46.0    19.0    2.0     13.0    12.0    56.0    220.0   parachute
---------------------------------------
Average accuracy per class tench from confusion matrix 0.8268733850129198
Average accuracy per class English springer from confusion matrix 0.7240506329113924
Average accuracy per class cassette player from confusion matrix 0.742296918767507
Average accuracy per class chain saw from confusion matrix 0.5025906735751295
Average accuracy per class church from confusion matrix 0.6699266503667481
Average accuracy per class French horn from confusion matrix 0.5228426395939086
Average accuracy per class garbage truck from confusion matrix 0.7480719794344473
Average accuracy per class gas pump from confusion matrix 0.5298329355608592
Average accuracy per class golf ball from confusion matrix 0.5939849624060151
Average accuracy per class parachute from confusion matrix 0.5641025641025641
Average Accuracy/precision from the confusion matrix is 0.640764331210191
Accuracy of the network on the 3925 test/validation images: 64.07643312101911
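
For reference, the per-class figures printed above divide each diagonal entry by its row sum, which is usually called per-class recall. Dividing by the column sum instead gives per-class precision. A short sketch that recomputes both from the matrix in the output above (rows are true classes, columns are predictions; NumPy is used here only for convenience):

```python
import numpy as np

categories = ['tench', 'English springer', 'cassette player', 'chain saw', 'church',
              'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']

# Confusion matrix copied from the printed output
cm = np.array([
    [320, 10,  7,  25,  0,   3,   3,   2,  16,  1],
    [23, 286,  3,  39,  3,   7,   7,   3,  21,  3],
    [8,    4, 265, 13,  2,   4,  14,  36,  11,  0],
    [49,  30, 12, 194, 19,  21,  27,  22,   8,  4],
    [14,   5, 10,  20, 274, 21,  23,  19,  16,  7],
    [13,  21, 22,  64, 26, 206,  11,  19,  12,  0],
    [8,    4, 14,  28, 23,   5, 291,  14,   2,  0],
    [7,   11, 50,  41, 23,   9,  46, 222,   8,  2],
    [37,  38,  9,  27, 17,  10,   7,   6, 237, 11],
    [10,   6,  6,  46, 19,   2,  13,  12,  56, 220],
], dtype=float)

recall = np.diag(cm) / cm.sum(axis=1)     # row-normalised: the per-class accuracy printed above
precision = np.diag(cm) / cm.sum(axis=0)  # column-normalised
overall = np.trace(cm) / cm.sum()         # overall accuracy

for name, r, p in zip(categories, recall, precision):
    print(f"{name}: recall={r:.4f} precision={p:.4f}")
print(f"overall accuracy={overall:.4f}")
```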
Hoisch answered 20/10, 2022 at 8:24 Comment(0)
