之前写过一篇关于在scikit-learn工具包中,可视化estimator分类模型分类结果的confusion matrix混淆矩阵可视化的方法,具体可以参考看这里,看这里。今天这篇介绍一下如何使用scikit-learn工具中提供的相关方法,可视化其他任意框架(比如深度学习框架)的分类模型预测结果的混淆矩阵。

        下面先说一下几个关键步骤:

1、确定类别列表,类别列表和one-hot的编码顺序一致,这里使用cifar-10的类别列表作为演示的例子。

classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"

2、准备好样本的真实label,这里我手动构造一个1000个样本的label,每一类100个。

# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))

3、准备好样本的预测label,这里我也手动构造这1000个样本的预测label,构造时才用了一点规则,构造出来的预测结果保证从第0类到第9类的预测准确率是逐渐降低的。

# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    # 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值
    # 这样生成的预测准确率从0到9逐渐递减
    pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))

4、计算真是label和预测label的混淆矩阵,直接调用scikit-learn中的confusion_matrix方法

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))

5、混淆矩阵可视化,在scikit-learn工具中有一个plot_confusion_matrix方法可以可视化sklearn训练的模型estimator的混淆矩阵,具体参数如下:

        但是,现在的问题是我们使用的是别的框架训练的模型,也就没有这个estimator参数可以供sklearn使用,怎么办?

        我们看一下plot_confusion_matrix函数的代码可以发现,他其实内部调用了以下方法:

         那么,我们也仿照这个调用方式来写一下试试,代码如下:

# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(
    include_values=True,            # 混淆矩阵每个单元格上显示具体数值
    cmap="viridis",                 # 不清楚啥意思,没研究,使用的sklearn中的默认值
    ax=None,                        # 同上
    xticks_rotation="horizontal",   # 同上
    values_format="d"               # 显示的数值格式
)

 6、将以上代码整合一下,输入数据的真实label和预测label,就可以可视化混淆矩阵了,并且不仅局限于评估scikit-learn的estimator,可以适用于所有框架的输出结果,完整代码如下:

import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as plt

classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))

# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    # 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值
    # 这样生成的预测准确率从0到9逐渐递减
    pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))

# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(
    include_values=True,            # 混淆矩阵每个单元格上显示具体数值
    cmap="viridis",                 # 不清楚啥意思,没研究,使用的sklearn中的默认值
    ax=None,                        # 同上
    xticks_rotation="horizontal",   # 同上
    values_format="d"               # 显示的数值格式
)
plt.show()

7、混淆矩阵的可视化结果

        上图中的可视化结果符合我们在生成预测label标签时使用的规则,就是对于每个类别 i 的预测结果是0-i之间的随机值,这样的话,每个类别的预测误差只会出现在类别编号比它小的部分,也就是上图中展示的下三角矩阵。

        在混淆矩阵中,横轴上的标签标示样本的预测label,纵轴上的标签标示样本的实际label。所以,对角线上的数字表示预测label和真是label一致的数量,也就是预测正确的数量。对于其他位置的数字就表示预测错误的,举个例子,比如第2行、第1列,也就是对应着(airplane, automobile)位置的数字51,表示有51个真实label为automobile的样本被预测为了airplane。

        通过可视化的混淆矩阵,模型的误差,以及效果分类不好的类别,以及为什么不好,以及容易和哪个类之间出现误识别就一目了然了。

参考:https://blog.csdn.net/cxx654/article/details/107296343

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐