柚子快報(bào)邀請(qǐng)碼778899分享:機(jī)器學(xué)習(xí) 算法 混淆矩陣的生成
柚子快報(bào)邀請(qǐng)碼778899分享:機(jī)器學(xué)習(xí) 算法 混淆矩陣的生成
混淆矩陣簡(jiǎn)介
混淆矩陣(Confusion Matrix)是一個(gè)二維表格,常用于評(píng)價(jià)分類模型的性能。在混淆矩陣中,每一列代表了預(yù)測(cè)值,每一行代表了真實(shí)值。因此,混淆矩陣中的每一個(gè)元素表示了一個(gè)樣本被預(yù)測(cè)為某一類別的次數(shù)?;煜仃嚨臉?gòu)成如下:
預(yù)測(cè)值=正例預(yù)測(cè)值=反例真實(shí)值=正例TPFN真實(shí)值=反例FPTN
其中,TP表示真正例(True Positive),F(xiàn)N表示假反例(False Negative),F(xiàn)P表示假正例(False Positive),TN表示真反例(True Negative)。
解釋如下:
TP:真正例,指的是模型將正例預(yù)測(cè)為正例的次數(shù); FN:假反例,指的是模型將正例預(yù)測(cè)為反例的次數(shù); FP:假正例,指的是模型將反例預(yù)測(cè)為正例的次數(shù); TN:真反例,指的是模型將反例預(yù)測(cè)為反例的次數(shù)。 混淆矩陣的重要性在于,可以通過計(jì)算其中的四個(gè)元素,得到各種評(píng)價(jià)指標(biāo),如精確度(Accuracy)、召回率(Recall)、準(zhǔn)確率(Precision)和 F1 值等。
精確度(Accuracy):表示模型預(yù)測(cè)正確的樣本數(shù)與總樣本數(shù)之比,即
A
c
c
u
r
a
c
y
=
T
P
+
T
N
T
P
+
F
P
+
F
N
+
T
N
Accuracy = \frac{TP+TN}{TP+FP+FN+TN}
Accuracy=TP+FP+FN+TNTP+TN?; 召回率(Recall):表示模型正確預(yù)測(cè)正例樣本的比例,即
R
e
c
a
l
l
=
T
P
T
P
+
F
N
Recall = \frac{TP}{TP+FN}
Recall=TP+FNTP?; 準(zhǔn)確率(Precision):表示模型預(yù)測(cè)為正例的樣本中,真正例的比例,即
P
r
e
c
i
s
i
o
n
=
T
P
T
P
+
F
P
Precision = \frac{TP}{TP+FP}
Precision=TP+FPTP?; F1 值:綜合了準(zhǔn)確率和召回率,即
F
1
=
2
×
P
r
e
c
i
s
i
o
n
×
R
e
c
a
l
l
P
r
e
c
i
s
i
o
n
+
R
e
c
a
l
l
F1 = \frac{2\times Precision\times Recall}{Precision+Recall}
F1=Precision+Recall2×Precision×Recall?。 混淆矩陣也可以可視化,可以使用熱力圖等圖形來展示混淆矩陣中每個(gè)元素的數(shù)值大小,以便更加直觀地理解分類模型的性能。
混淆矩陣的主要作用和意義如下:
評(píng)估分類器的性能:混淆矩陣可以幫助我們計(jì)算分類器的準(zhǔn)確率、召回率、精確率、F1分?jǐn)?shù)等指標(biāo),從而評(píng)估分類器的性能。
比較不同分類器的性能:混淆矩陣可以幫助我們比較不同分類器的性能,找出最優(yōu)的分類器。
識(shí)別分類器的錯(cuò)誤類型:混淆矩陣可以幫助我們了解分類器在哪些情況下容易出錯(cuò),識(shí)別出分類器的錯(cuò)誤類型,從而針對(duì)性地改進(jìn)分類器。
優(yōu)化分類器的閾值:混淆矩陣可以幫助我們優(yōu)化分類器的閾值,從而提高分類器的性能。
可視化分類器的性能:混淆矩陣可以將分類器的性能可視化,從而更直觀地了解分類器的性能。
混淆矩陣可視化代碼:
import os
from matplotlib.font_manager import FontProperties
import itertools
import matplotlib.pyplot as plt
import numpy as np
# 繪制混淆矩陣
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
"""
- cm : 計(jì)算出的混淆矩陣的值
- classes : 混淆矩陣中每一行每一列對(duì)應(yīng)的列
- normalize : True:顯示百分比, False:顯示個(gè)數(shù)
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("顯示百分比:")
np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
print(cm)
else:
print('顯示具體數(shù)字:')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
# matplotlib版本問題,如果不加下面這行代碼,則繪制的混淆矩陣上下只能顯示一半,有的版本的matplotlib不需要下面的代碼,分別試一下即可
plt.ylim(len(classes) - 0.5, -0.5)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
cnf_matrix = np.array([[151, 64, 731, 164, 45],
[821, 653, 79, 0, 28],
[266, 167, 423, 4, 2],
[691, 0, 107, 776, 26],
[30, 0, 111, 17, 42]])
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
# 歸一化
# plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')
# 不歸一化
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')
其中上述有兩種方式可以選擇,即一種是歸一化,一種是不歸一化 歸一化設(shè)置 normalize=True 結(jié)果為: 不歸一化設(shè)置 normalize=False 結(jié)果為:
如果想要配合模型生成混淆矩陣,則需要讓神經(jīng)生成一個(gè)混淆矩陣的矩陣序列代碼為:
import os
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model import MobileNetV2
class ConfusionMatrix(object):
"""
注意,如果顯示的圖像不全,是matplotlib版本問題
本例程使用matplotlib-3.2.1(windows and ubuntu)繪制正常
需要額外安裝prettytable庫(kù)
"""
def __init__(self, num_classes: int, labels: list):
self.matrix = np.zeros((num_classes, num_classes))
self.num_classes = num_classes
self.labels = labels
def update(self, preds, labels):
for p, t in zip(preds, labels):
self.matrix[p, t] += 1
def plot(self, normalize=False):
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("顯示百分比:")
np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
print(cm)
else:
print('顯示具體數(shù)字:')
print(cm)
matrix = self.matrix
plt.imshow(matrix , interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
# matplotlib版本問題,如果不加下面這行代碼,則繪制的混淆矩陣上下只能顯示一半,有的版本的matplotlib不需要下面的代碼,分別試一下即可
plt.ylim(len(classes) - 0.5, -0.5)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
if __name__ == '__main__':
mylabel = {"4": "4", "5": "5", "6": "6"}
num_classes=3 #################################
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas' #################################
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
validate_dataset = datasets.ImageFolder(root=os.path.join(ROOT_DATA, "val"),
transform=data_transform)
batch_size = 16
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=2)
net = MobileNetV2(num_classes=num_classes) ###########################
# load pretrain weights
model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth" #########################
assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
net.load_state_dict(torch.load(model_weight_path, map_location=device))
net.to(device)
labels = [label for _, label in mylabel.items()]
confusion = ConfusionMatrix(num_classes=num_classes, labels=labels)
net.eval()
with torch.no_grad():
for val_data in tqdm(validate_loader):
val_images, val_labels = val_data
outputs = net(val_images.to(device))
outputs = torch.softmax(outputs, dim=1)
outputs = torch.argmax(outputs, dim=1)
# print('outputs++'+str(outputs.to("cpu").numpy())+'val_labels++'+str(val_labels.numpy()))
confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
confusion.plot()
其中*多的地方需要自行修改,例如
ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas' #################################
在這里進(jìn)行數(shù)據(jù)集的修改
mylabel = {"4": "4", "5": "5", "6": "6"}
進(jìn)行標(biāo)簽的修改
net = MobileNetV2(num_classes=3) ###########################
在這里進(jìn)行網(wǎng)絡(luò)修改
model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth" #########################
在這里進(jìn)行本地模型權(quán)重的修改
柚子快報(bào)邀請(qǐng)碼778899分享:機(jī)器學(xué)習(xí) 算法 混淆矩陣的生成
文章鏈接
本文內(nèi)容根據(jù)網(wǎng)絡(luò)資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點(diǎn)和立場(chǎng)。
轉(zhuǎn)載請(qǐng)注明,如有侵權(quán),聯(lián)系刪除。