sklearn 的混淆矩阵不存储以下信息how矩阵已创建(类排序和标准化):这意味着您必须在创建混淆矩阵后立即使用它否则信息将会丢失。
默认情况下,sklearn.metrics.confusion_matrix(y_true,y_pred) https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html按照类在 y_true 中出现的顺序创建矩阵。
如果您将此数据传递给sklearn.metrics.confusion_matrix:
+--------+--------+
| y_true | y_pred |
+--------+--------+
| A | B |
| C | C |
| D | B |
| B | A |
+--------+--------+
Scikit-learn 将创建这个混淆矩阵(省略零):
+-----------+---+---+---+---+
| true\pred | A | C | D | B |
+-----------+---+---+---+---+
| A | | | | 1 |
| C | | 1 | | |
| D | | | | 1 |
| B | 1 | | | |
+-----------+---+---+---+---+
它会将这个 numpy 矩阵返回给您:
+---+---+---+---+
| 0 | 0 | 0 | 1 |
| 0 | 0 | 1 | 0 |
| 0 | 0 | 0 | 1 |
| 1 | 0 | 0 | 0 |
+---+---+---+---+
如果您想选择类或对它们重新排序,您可以将 'labels' 参数传递给confusion_matrix()
.
对于重新排序:
labels = ['D','C','B','A']
mat = confusion_matrix(true_y,pred_y, labels=labels)
或者,如果您只想关注一些标签(如果您有很多标签,则很有用):
labels = ['A','D']
mat = confusion_matrix(true_y,pred_y, labels=labels)
另外,看看sklearn.metrics.plot_confusion_matrix https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html。它非常适合小班(
如果您有 >100 个类,则需要使用白色来绘制矩阵。