当前位置: 首页 > 图灵资讯 > 行业资讯> 如何实现python绘制混淆矩阵?

如何实现python绘制混淆矩阵?

来源:图灵python
时间: 2024-12-25 17:50:38

大家从python基础到如今的入门,想必都对python有一定基础,今天小编给大家带来一个关于python的高阶内容——绘制混淆矩阵,一起来看下吧~

介绍:

混淆矩阵通过表示正确/不正确标签的计数来表示模型在表格格式中的准确性。

计算/绘制混淆矩阵:

以下是计算混淆矩阵的过程。

您需要一个包含预期结果值的测试数据集或验证数据集。

  • 对测试数据集中的每一行进行预测。

  • 从预期的结果和预测计数:

  • 每个类的正确预测数量。

  • 每个类的错误预测数量,由预测的类组织。

然后将这些数字组织成表格或矩阵,如下所示:

  • Expected down the side:矩阵的每一行都对应一个预测的类。

  • Predicted across the top:矩阵的每一列对应于一个实际的类。

然后将正确和不正确分类的计数填入表格中。

Reading混淆矩阵:

一个类的正确预测的总数进入该类值的预期行,以及该类值的预测列。

以同样的方式,一个类别的不正确预测总数进入该类别值的预期行,以及该类别值的预测列。

对角元素表示预测标签等于真实标签的点的数量,而非对角线元素是分类器错误标记的元素。混淆矩阵的对角线值越高越好,表明许多正确的预测。

Python绘制混淆矩阵 :

importitertools

importnumpyasnp

importmatplotlib.pyplotasplt

fromsklearnimportsvm,datasets

fromsklearn.model_selectionimporttrain_test_split

fromsklearn.metricsimportconfusion_matrix

#importsomedatatoplaywith

iris=datasets.load_iris()

X=iris.data

y=iris.target

class_names=iris.target_names

#Splitthedataintoatrainingsetandatestset

X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)

#Runclassifier,usingamodelthatistooregularized(Ctoolow)tosee

#theimpactontheresults

classifier=svm.SVC(kernel='linear',C=0.01)

y_pred=classifier.fit(X_train,y_train).predict(X_test)

defplot_confusion_matrix(cm,classes,

normalize=False,

title='Confusionmatrix',

cmap=plt.cm.Blues):

"""

Thisfunctionprintsandplotstheconfusionmatrix.

Normalizationcanbeappliedbysetting`normalize=True`.

"""

ifnormalize:

cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]

print("Normalizedconfusionmatrix")

else:

print('Confusionmatrix,withoutnormalization')

print(cm)

plt.imshow(cm,interpolation='nearest',cmap=cmap)

plt.title(title)

plt.colorbar()

tick_marks=np.arange(len(classes))

plt.xticks(tick_marks,classes,rotation=45)

plt.yticks(tick_marks,classes)

fmt='.2f'ifnormalizeelse'd'

thresh=cm.max()/2.

fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])):

plt.text(j,i,format(cm[i,j],fmt),

horizontalalignment="center",

color="white"ifcm[i,j]>threshelse"black")

color="white"ifcm[i,j]>threshelse"black")

plt.tight_layout()

plt.ylabel('Truelabel')

plt.xlabel('Predictedlabel')

#Computeconfusionmatrix

cnf_matrix=confusion_matrix(y_test,y_pred)

np.set_printoptions(precision=2)

#Plotnon-normalizedconfusionmatrix

plt.figure()

plot_confusion_matrix(cnf_matrix,classes=class_names,

title='Confusionmatrix,withoutnormalization')

#Plotnormalizedconfusionmatrix

plt.figure()

plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True,

title='Normalizedconfusionmatrix')

plt.show()

Confusionmatrix,withoutnormalization

[[1300]

[0106]

[009]]

Normalizedconfusionmatrix

[[1.0.0.]

[0.0.620.38]

[0.0.1.]]

好了,大家可以消化学习下哦~如需了解更多python实用知识,点击进入PyThon学习网教学中心