How can I plot a confusion matrix? [duplicate]

3

162

I am using scikit-learn to classify 22,000 text documents into 100 classes. I use scikit-learn's confusion_matrix function to compute the confusion matrix.

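from sklearn import metrics
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt

# matrix/labels are the training data, test_matrix/test_labels the test data
model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm = metrics.confusion_matrix(test_labels, pred)
print(cm)
plt.imshow(cm, cmap='binary')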

This is what my confusion matrix looks like:

[[3962  325    0 ...,    0    0    0]
 [ 250 2765    0 ...,    0    0    0]
 [   2    8   17 ...,    0    0    0]
 ..., 
 [   1    6    0 ...,    5    0    0]
 [   1    1    0 ...,    0    0    0]
 [   9    0    0 ...,    0    0    9]]

However, I do not receive a clear or legible plot. Is there a better way to do this?

Tina answered 23/2, 2016 at 8:6 Comment(1)
Check this answer for pure Matplotlib code. - Adrien
237


You can use plt.matshow() instead of plt.imshow(), or you can use the seaborn module's heatmap (see the documentation) to plot the confusion matrix:

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3], 
        [3,31,0,0,0,0,0,0,0,0,0], 
        [0,4,41,0,0,0,0,0,0,0,1], 
        [0,1,0,30,0,6,0,0,0,0,1], 
        [0,0,0,0,38,10,0,0,0,0,0], 
        [0,0,0,3,1,39,0,0,0,0,4], 
        [0,2,2,0,4,1,31,0,0,0,2],
        [0,1,0,0,0,0,0,36,0,2,0], 
        [0,0,0,0,0,0,1,5,37,5,1], 
        [3,0,0,0,0,0,0,0,0,39,0], 
        [0,0,0,0,0,0,0,0,0,0,38]]
# Label rows (true classes) and columns (predicted classes) A..K
df_cm = pd.DataFrame(array, index=list("ABCDEFGHIJK"),
                     columns=list("ABCDEFGHIJK"))
plt.figure(figsize=(10, 7))
sn.heatmap(df_cm, annot=True)
plt.show()
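
For the plt.matshow() route mentioned above, a minimal sketch might look like the following. It assumes that row-normalizing the counts (so each cell is the fraction of its true class) makes the plot easier to read when class sizes differ a lot, as in the 100-class question. normalize_rows and plot_matshow are hypothetical helper names, and the counts are made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt


def normalize_rows(cm):
    """Divide each row by its sum so a cell is the fraction of that true class."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows with no samples stay at zero instead of triggering a division by zero.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


def plot_matshow(cm):
    """Draw the row-normalized matrix with a colorbar and axis labels."""
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.matshow(normalize_rows(cm), cmap="Blues", vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return fig, ax


# Made-up 3-class counts (illustration only); the last class has no test samples.
cm = [[50, 2, 0],
      [5, 20, 5],
      [0, 0, 0]]
fig, ax = plot_matshow(cm)

if __name__ == "__main__":
    plt.show()
```

With many classes the per-cell numbers are unreadable anyway, so this sketch relies on color alone and leaves out annotations.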
Cercaria answered 23/2, 2016 at 8:19 Comment(3)
mask_bad = X.mask if np.ma.is_masked(X) else np.isnan(X) # Mask nan's. TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe'' - Yonah
If there are numbers with 3 or more digits, this prints them in scientific notation, like 3.4e+02 for 340, due to the default fmt parameter. Setting it explicitly, e.g. sn.heatmap(df_cm, annot=True, fmt='.10g'), fixes this. - Spermophile
The numbers are shown only for the first row of your confusion matrix. - Decidua
128

@bninopaul's answer is not entirely beginner-friendly.

Here is code you can copy and run:

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

array = [[13,1,1,0,2,0],
         [3,9,6,0,1,0],
         [0,0,16,2,0,0],
         [0,0,0,13,0,0],
         [0,0,0,0,15,0],
         [0,0,1,0,0,15]]

df_cm = pd.DataFrame(array, range(6), range(6))
# plt.figure(figsize=(10,7))
sn.set(font_scale=1.4) # for label size
sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}) # font size

plt.show()
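
To go from predictions straight to a labeled heatmap, something like the sketch below may help. It assumes scikit-learn, pandas and seaborn are installed; the label lists are made up for illustration. fmt="d" is used so that integer counts are not shown in scientific notation.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix

# Made-up true and predicted labels (illustration only).
y_true = ["cat", "dog", "cat", "bird", "dog", "cat"]
y_pred = ["cat", "dog", "dog", "bird", "dog", "cat"]
labels = ["bird", "cat", "dog"]

# Rows are true classes, columns are predicted classes, in the order of `labels`.
cm = confusion_matrix(y_true, y_pred, labels=labels)
df_cm = pd.DataFrame(cm, index=labels, columns=labels)

# fmt="d" prints the integer counts as plain integers.
ax = sn.heatmap(df_cm, annot=True, fmt="d", cmap="Blues")
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")

if __name__ == "__main__":
    plt.show()
```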


Elagabalus answered 16/2, 2017 at 5:41 Comment(3)
Just to add, for custom x and y labels, replace the df_cm line with something like this: df_cm = pd.DataFrame(array, index=["stage 1", "stage 2", "stage 3", "stage 4"], columns=["stage 1", "stage 2", "stage 3", "stage 4"]) - Hoofbound
I'm not seeing why this answer is more "for beginners"? It's basically the same as bninopaul's. - Honor
The conf matrix is beginner-sized @DavidSkarbrevik ;) - Buhrstone
81

If you want more data in your confusion matrix, including a totals column, a totals row, and percentages (%) in each cell, like MATLAB's default output, along with a heatmap and other options, you should have fun with the module below, shared on GitHub ; )

https://github.com/wcipriano/pretty-print-confusion-matrix

This module handles the task easily and offers many parameters for customizing your confusion matrix.
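
To illustrate what a totals column, a totals row and per-cell percentages mean, here is a rough pandas-only sketch. It is not the linked module's API; with_totals and as_percent are hypothetical helper names, and the counts are made up.

```python
import pandas as pd


def with_totals(cm, labels):
    """Return the matrix with an extra 'total' column and 'total' row."""
    df = pd.DataFrame(cm, index=labels, columns=labels)
    df["total"] = df.sum(axis=1)      # samples per true class
    df.loc["total"] = df.sum(axis=0)  # samples per predicted class, plus grand total
    return df


def as_percent(cm, labels):
    """Return each cell as a percentage of all samples, rounded to one decimal."""
    df = pd.DataFrame(cm, index=labels, columns=labels)
    return (100.0 * df / df.values.sum()).round(1)


# Made-up 3-class counts (illustration only).
cm = [[13, 1, 1],
      [3, 9, 6],
      [0, 0, 16]]
labels = ["A", "B", "C"]

totals = with_totals(cm, labels)
percent = as_percent(cm, labels)
print(totals)
print(percent)
```

Either table can then be passed to a heatmap call like the ones in the other answers.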

Doorbell answered 3/7, 2018 at 22:32 Comment(4)
Hi, thanks for this! Can you approve this PR? Installing via pip would be so much nicer: github.com/wcipriano/pretty-print-confusion-matrix/pull/11 - Allinclusive
Hello Ian! Okay, I'll check and approve your PR, thanks for the collaboration ; ) - Doorbell
Wasn't my PR, but thank you for approving! :) - Allinclusive
Okay, it was PR 11 (Make the package available to be installed via PyPI). I saw it because of your comment here in this thread, thanks! - Doorbell
