大家从python基础到如今的入门,想必都对python有一定基础,今天小编给大家带来一个关于python的高阶内容——绘制混淆矩阵,一起来看下吧~
介绍:
混淆矩阵通过表示正确/不正确标签的计数来表示模型在表格格式中的准确性。
计算/绘制混淆矩阵:
以下是计算混淆矩阵的过程。
您需要一个包含预期结果值的测试数据集或验证数据集。
对测试数据集中的每一行进行预测。
从预期的结果和预测计数:
每个类的正确预测数量。
每个类的错误预测数量,由预测的类组织。
然后将这些数字组织成表格或矩阵,如下所示:
Expected down the side:矩阵的每一行都对应一个预测的类。
Predicted across the top:矩阵的每一列对应于一个实际的类。
然后将正确和不正确分类的计数填入表格中。
Reading混淆矩阵:
一个类的正确预测的总数进入该类值的预期行,以及该类值的预测列。
以同样的方式,一个类别的不正确预测总数进入该类别值的预期行,以及该类别值的预测列。
对角元素表示预测标签等于真实标签的点的数量,而非对角线元素是分类器错误标记的元素。混淆矩阵的对角线值越高越好,表明许多正确的预测。
用Python绘制混淆矩阵 :
importitertools importnumpyasnp importmatplotlib.pyplotasplt fromsklearnimportsvm,datasets fromsklearn.model_selectionimporttrain_test_split fromsklearn.metricsimportconfusion_matrix #importsomedatatoplaywith iris=datasets.load_iris() X=iris.data y=iris.target class_names=iris.target_names #Splitthedataintoatrainingsetandatestset X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0) #Runclassifier,usingamodelthatistooregularized(Ctoolow)tosee #theimpactontheresults classifier=svm.SVC(kernel='linear',C=0.01) y_pred=classifier.fit(X_train,y_train).predict(X_test) defplot_confusion_matrix(cm,classes, normalize=False, title='Confusionmatrix', cmap=plt.cm.Blues): """ Thisfunctionprintsandplotstheconfusionmatrix. Normalizationcanbeappliedbysetting`normalize=True`. """ ifnormalize: cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis] print("Normalizedconfusionmatrix") else: print('Confusionmatrix,withoutnormalization') print(cm) plt.imshow(cm,interpolation='nearest',cmap=cmap) plt.title(title) plt.colorbar() tick_marks=np.arange(len(classes)) plt.xticks(tick_marks,classes,rotation=45) plt.yticks(tick_marks,classes) fmt='.2f'ifnormalizeelse'd' thresh=cm.max()/2. fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])): plt.text(j,i,format(cm[i,j],fmt), horizontalalignment="center", color="white"ifcm[i,j]>threshelse"black") color="white"ifcm[i,j]>threshelse"black") plt.tight_layout() plt.ylabel('Truelabel') plt.xlabel('Predictedlabel') #Computeconfusionmatrix cnf_matrix=confusion_matrix(y_test,y_pred) np.set_printoptions(precision=2) #Plotnon-normalizedconfusionmatrix plt.figure() plot_confusion_matrix(cnf_matrix,classes=class_names, title='Confusionmatrix,withoutnormalization') #Plotnormalizedconfusionmatrix plt.figure() plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True, title='Normalizedconfusionmatrix') plt.show()
Confusionmatrix,withoutnormalization [[1300] [0106] [009]] Normalizedconfusionmatrix [[1.0.0.] [0.0.620.38] [0.0.1.]]
好了,大家可以消化学习下哦~如需了解更多python实用知识,点击进入PyThon学习网教学中心。