目录
摘 要:图像识别作为深度学习领域内的一项重要应用,水果蔬菜图像的分类识别在智慧农业以及采摘机器人等方面具有重要应用。针对以往传统图像分类算法存在泛化能力差、准确率不高等问题,提出一种在TensorFlow框架下基于深度学习和迁移学习的水果蔬菜图像分类算法。该算法采用mobileNetV2的部分模型结构对水果蔬菜图像数据集进行特征提取,采用Softmax分类器对图像特征进行分类,并通过迁移学习方式进行训练得到迁移训练模型。测试结果表明,该算法与传统水果蔬菜分类算法对比,具有较高识别准确率。
关键词:迁移学习;深度学习;图像分类;mobileNetV2;
1 引言
随着计算机视觉技术的发展,图像识别在农业、食品工业等领域具有广泛的应用前景。准确识别不同种类的水果蔬菜对于自动化分拣、库存管理等具有重要意义。本文提出了一种基于深度学习的水果蔬菜图像分类方法,通过加载预训练的mobileNetV2卷积神经网络模型,并使用迁移学习,将其实现对多种水果蔬菜图像的自动分类。
2 研究原理
2.1 MobileNetV2模型
MobileNetV2 是一种卷积神经网络架构,它是MobileNet系列中的第二代模型,于 2018 年由 Google 的研究团队提出。专门设计用于在移动设备和嵌入式设备上进行实时图像分类和其他视觉任务。
MoblieNetV2网络主要有两个亮点,使用逆残差结构和在逆残差结构的最后一层使用线性激活函数。当然,它也有许多特点,如:
1、将卷积神经网络中的普通卷积层改为使用深度分离卷积,其基本思想是将标准卷积拆分为两个分卷积:第一层称为深度卷积(depthwise convolution),对每个输入通道应用单通道的轻量级滤波器;第二层称为逐点卷积(pointwise convolution),负责计算输入通道的线性组合构建新的特征。
2、采用倒残差结构:经典的残差块(residual block)的过程是:1x1(降维)-->3x3(卷积)-->1x1(升维), 但深度卷积层(Depthwise convolution layer)提取特征限制于输入特征维度,若采用残差块,先经过1x1的逐点卷积(Pointwise convolution)操作先将输入特征图压缩(一般压缩率为0.25),再经过深度卷积后,提取的特征会更少。所以mobileNetV2是先经过1x1的逐点卷积操作将特征图的通道进行扩张,丰富特征数量,进而提高精度。这一过程刚好和残差块的顺序颠倒,这也就是倒残差的由来:1x1(升维)-->3x3(dw conv+relu)-->1x1(降维+线性变换)。倒残差结构如图1所示:
图1 MoblieNetv2的倒残差结构示意图 |
3、线性瓶颈层:对于深度可分离卷积而言, 宽度乘数压缩后的M维空间后会通过一个非线性变换ReLU,根据ReLU的性质,输入特征若为负数,该通道的特征会被清零,本来特征已经经过压缩,这会进一步损失特征信息;若输入特征是正数,经过激活层输出特征是原始的输入值,则相当于线性变换。
MobileNetV2是一种强大的架构,能够在资源有限的设备上高效运行深度学习模型MoblieNetV2网络的整体网络框架如下表所示:
表1 MobileNet网络模型框架表 |
|||||
Input |
Operator |
t |
c |
n |
s |
2242×3 |
Conv2d |
- |
32 |
1 |
2 |
1122×32 |
bottleneck |
1 |
16 |
1 |
1 |
1122×16 |
bottleneck |
6 |
24 |
2 |
2 |
562×24 |
bottleneck |
6 |
32 |
3 |
2 |
282×32 |
bottleneck |
6 |
64 |
4 |
2 |
142×64 |
bottleneck |
6 |
96 |
3 |
1 |
142×96 |
bottleneck |
6 |
160 |
3 |
2 |
72×160 |
bottleneck |
6 |
320 |
1 |
1 |
72×320 |
Con2d 1×1 |
- |
1280 |
1 |
1 |
72×1280 |
Avgpool 7×7 |
- |
- |
1 |
- |
1×1×1280 |
Conv2d 1×1 |
- |
k |
- |
2.2 迁移学习
迁移学习(Transfer Learning)通俗来讲就是学会举一反三的能力,通过运用已有的知识来学习新的知识,其核心是找到已有知识和新知识之间的相似性,通过这种相似性的迁移达到迁移学习的目的。世间万事万物皆有共性,如何合理地找寻它们之间的相似性,进而利用这个桥梁来帮助学习新知识,是迁移学习的核心问题。
在本研究中,我将预处理模型MoblieNetv2的部分层进行冻结,并在此基础上添加自定义层进行研究。
3 网络架构
3.1网络模型
1、消融实验模型
消融实验本研究采用Moblie Netv2来进行,本研究使用预处理模型不做任何处理来进行数据集训练。
2、最终确定模型
本研究采用了一个简单的迁移学习预训练卷积神经网络模型,模型的输入层接收224x224像素的彩色图像,通过全连接层和softmax激活函数输出分类结果。模型使用了ReLU激活函数和早停技术以防止过拟合。
1、预训练模型层:
模型的基础,使用的是MobileNetV2,它包含了多个卷积层和池化层,并且在此基础上解冻layer1或者冻结layer2和layer1层用于从输入图像中提取特征。
2、扁平化层:
这一层将预训练模型的输出(多维特征图)展平成一维数组。这是因为全连接层需要一维的输入。
3、第一个全连接层:
这个层有256个神经元,使用ReLU激活函数。它接收Flatten层的输出,并进行进一步的特征处理。
4、第二个全连接层:
这个层有128个神经元,同样使用ReLU激活函数。它继续对前一层的输出进行处理,减少特征维度,同时保持重要的信息。
5、输出层:
这个层有36个神经元,对应于数据集中的36个类别。使用softmax激活函数,将输出转换为概率分布,每个神经元的输出值表示模型预测属于对应类别的概率。
我的模型使用Adam优化器进行训练,损失函数为分类交叉熵(categorical cross-entropy)。
在训练过程中,我使用ReduceLROnPlateau回调函数来监控验证集上的损失。如果验证集损失在连续多个epoch(本研究中设置为15)没有改善,学习率将减小到原来的1/2。这可以帮助模型更好地收敛并找到更好的最小值。
3.2 数据预处理
1、数据增强:数据集的图像通过ImageDataGenerator类进行预处理和增强,包括缩放、旋转、平移、缩放、水平翻转、剪切和填充模式等操作。
2、标签可视化:为了更好的展示训练集不同种类的标签个数,我绘制了条形图可以更好的并且更加直观的可视化训练集标签情况,如图2所示:
![]() |
图2 训练集标签分布情况可视化 |
3、批次化:确保数据以32个样本为一批进行处理,并在每个epoch开始时随机打乱数据顺序。
4、图像大小调整:所有图像调整为224x224的统一大小,以适应模型mobileNetv2的输入要求。
5、创建DataFrame:通过遍历每个类别的文件夹,创建了一个包含所有图像路径和对应标签的DataFrame,便于后续的数据加载和处理。
4 测试与分析
4.1 数据集介绍及划分
本研究使用"Fruits and Vegetables Image"的图像识别数据集,专门用于水果和蔬菜的图像分类任务。数据集包含了36种不同的水果和蔬菜类别的图像。数据集中包含的样本种类非常丰富,涵盖了多种水果、蔬菜以及一些草本植物。具体来说,这些样本种类包括苹果、香蕉、甜菜根、甜椒、卷心菜、彩椒、胡萝卜、花椰菜、辣椒、玉米、黄瓜、茄子、大蒜、姜、葡萄、墨西哥辣椒、猕猴桃、柠檬、生菜、芒果、洋葱、橙子、辣椒粉、梨、豌豆、菠萝、石榴、土豆、萝卜、大豆、菠菜、甜玉米、甘薯、番茄、芜菁和西瓜,共计36种不同的植物。
图像被组织成三个主要的文件夹:训练集(Train)、测试集(Test)和验证集(Validation)
数据集划分:训练集:包含3115张图像,每个类别大约有100张图像。
测试集:包含359张图像,每个类别大约有10张图像。
验证集:包含351张图像,每个类别大约有10张图像。
为了更好的展示数据集,将每一类选择一张来展示。如图3所示:
图3 数据集基本展示 |
4.2训练参数设置
模型训练使用了Adam优化器,学习率为0.001,训练了30个epoch,每个epoch包含32个批次。此外,使用了ImageDataGenerator进行数据增强,以提高模型的泛化能力。
表2 训练参数
网络模型 |
参数名称 |
数值 |
Mblie Netv2网络模型 |
batch_size |
30 |
初始学习率 |
0.001 |
|
损失函数 |
分类交叉熵 |
|
输入图片大小 |
224*224 |
|
迭代次数 |
30 |
|
优化器 |
Adm |
4.3 实验结果分析
4.3.1 预训练模型结果分析
使用预训练模型,不做任何处理,模型准确率中等,但过拟合现象严重,准确率曲线如图4所示:
![]() |
图4 预训练模型acc曲线图 |
通过图4曲线图可以看出,在初始阶段(0-5 epochs),模型的训练准确率和验证准确率都相对较低,但验证准确率略高,这意味着模型在开始时对训练数据的学习还不够有效。进入中期阶段(5-10 epochs),训练准确率显著提升,而验证准确率的增长则相对平缓,这表明模型开始逐渐捕捉到数据的特征。在快速提升阶段(10-15 epochs),训练和验证准确率都迅速提高,两者之间的差距开始缩小,显示出模型在训练数据和验证数据上都取得了较好的学习效果。随着训练的进行,模型进入稳定阶段(15-25 epochs),此时训练准确率继续稳步上升,而验证准确率则在达到一定水平后趋于平稳,并在某些点上出现波动,这意味着模型在训练集上继续学习,但在验证集上的性能提升开始放缓。最后,在后期阶段(25-30 epochs),训练准确率继续保持上升趋势,而验证准确率在经历了一些波动之后也显示出上升的趋势,这表明模型在后期阶段仍在继续学习和优化,尽管这种提升可能不如之前阶段那么显著。
其中的每个单元格表示实际类别与预测类别之间的关系。通过这个热图,可以直观地看到模型在不同类别上的预测性能,例如,对角线上的值表示模型正确分类的样本数量,而非对角线的值表示模型错误分类的样本数量。这有助于识别模型在哪些类别上表现较好,哪些类别上表现较差,从而可以针对性地进行模型优化。
混淆矩阵的对角线上的数字表示正确分类的样本数量。数字越大,说明模型在该类别上的性能越好。如图5所示,"grapes"、"banana"、"beetroot" 等类别在对角线上的数字较大,表明模型在这些类别上的预测较为准确。非对角元素较多,某些类别之间的混淆更为显著。例如,"corn" 被错误分类为 "peas" 的数量为 4,而 "peas" 被错误分类为 "corn" 的数量为 2,这表明这两个类别在特征上有一定的重叠。通过观察对角线和非对角线的数字,可以识别模型在哪些类别上表现良好,在哪些类别上需要改进。以此为依据来修改和优化模型。
图5 混淆矩阵1 |
4.3.2冻结layer1结果分析
冻结预训
练模型的layer1,从loss-acc曲线图6中可以看到模型训练结果随epoch增加而改善。模型在15个epoch之前波动较大,但是模型的损失函数(loss)逐渐减小,准确率(accuracy)显著上升,到第30个epoch时,验证集的损失减少到0.0655,准确率增加到0.9791,可视化测试准确
图6 冻结layer1 loss-acc曲线图 |
从矩阵图7中可以看出,高准确率类别:大多数类别的对角线元素(正确预测的数量)都非常高,这表明模型在这些类别上的性能很好。例如,“apple”、“beetroot”、“bell pepper”、“cabbage”、“corn”、“cucumber”、“eggplant”、“garlic”、“ginger”、“kiwi”、“lemon”、“lettuce”、“mango”、“onion”、“orange”、“paprika”、“pear”、“peas”、“pineapple”、“pomegranate”、“radish”、“soy beans”、“spinach”、“sweetpotato”、“tomato”和“turnip”类别的预测准确率都是100%。 低准确率类别:一些类别的准确率较低,例如“banana”有2个样本被错误分类为“grape”,“peas”有2个样本被错误分类为“potato”,“potato”有2个样本被错误分类为“peas”,“sweetcorn”有2个样本被错误分类为“peas”。这些错误分类可能表明这些类别之间存在某些相似性,导致模型难以区分。
图7 混淆矩阵2 |
总的来说,该混淆矩阵表明分类器在大多数水果类别上的表现是令人满意的,但在某些类别之间存在混淆,需要进一步分析和改进。
4.3.3冻结layer1和layer2结果分析
如图8所示,训练损失曲线相对平稳,而验证损失曲线在前20个epoch之前都波动较大,这表明模型在训练集上表现稳定,但在验证集上过拟合或数据不稳定。训练准确率曲线稳步上升,而验证准确率曲线波动较大,这与验证损失的波动相呼应,表明模型在验证集上的表现不稳定。总体来看,验证损失和验证准确率的波动表明了模型存在过拟合的现象,即在训练集上表现良好,但在验证集上表现不稳定。
![]() |
图8 冻结layer1和layer2 loss-acc曲线图 |
混淆矩阵(Confusion Matrix)是评估分类模型性能的一个重要工具,它通过一个矩阵的形式展示了模型预测结果与实际标签之间的关系。混淆矩阵中的每一项代表实际类别与预测类别之间的匹配情况。如图9可以看出,模型可以在大多数样本上正确识别,但在"bell pepper"
"carrot" 、"corn" 、"eggplant"、"ginger" 、"grapes"、"jalepeno" 、"mango" 、"onion" 、"soy beans"、"sweetpotato" 这几个样本中存在识别混淆,且"sweetpotato"识别混淆现象尤为严重,高达5次。
图9 混淆矩阵3 |
4.4实验结果对比分析
通过对比分析冻结不同层和仅使用预处理模型的训练结果,可以清晰地看到迁移学习在提升模型性能方面的显著效果。在训练过程中,通过冻结模型的layer1或者冻结layer1和layer2,利用了预训练模型MoblieNetv2在ImageNet上学习到的通用特征,这为模型提供了一个强大的起点。
在本次研究中,我发现只冻结layer1的模型在拟合度和识别准确率上表现最佳。这是因为layer1捕捉了图像的基础特征,如边缘和纹理,这些特征对于大多数视觉任务都是至关重要的。通过冻结其他层,模型能够进一步学习到更复杂的特征,但是模型无法调整其权重以适应新数据的分布。导致模型拟合程度没有只冻结layer1好。
混淆矩阵进一步证实了这一点,它显示了模型在各个类别上的预测准确性。通过观察混淆矩阵,可以看到模型在大多数类别上都取得了很高的准确率,这说明具有良好的泛化能力。同时,曲线图也展示了模型在训练和验证集上的性能,其中只冻结layer1的模型在验证集上的表现尤为突出,这进一步证明了其优越的泛化能力。
迁移学习不仅提高了模型的准确率,还增强了其泛化能力,而只冻结layer1的策略在本研究中展现出了最佳的性能。
5 总结
本课程设计项目聚焦于利用迁移学习进行水果蔬菜分类,采用MobileNetV2预训练模型以提高分类准确率和泛化能力。通过对"Fruits and Vegetables Image"数据集的细致预处理,包括数据增强和标签可视化,我们为模型训练打下了坚实基础。研究中,通过比较不同冻结层数对模型性能的影响,发现仅冻结layer1的策略在验证集上表现最佳,准确率达到了98.14,实现了高准确率模型。
# A function to avoid tensorflow warnings
from silence_tensorflow import silence_tensorflow
silence_tensorflow()
import os
import cv2
import random
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
from termcolor import colored
import matplotlib.pyplot as plt
from keras.utils import plot_model
from tensorflow.keras import optimizers
from tensorflow.keras import models, layers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix
sns.set_style('darkgrid')
warnings.filterwarnings('ignore')
base_dir = "E:\\桌面\\fruit-and-vegetable-image-recognition"
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
test_dir = os.path.join(base_dir, 'test')
# 检查图片文件是否能正常读取的函数,并返回损坏图片文件路径列表
def check_image_files(folder_path):
"""检查文件夹下图片文件是否能正常读取"""
bad_image_paths = []
for class_name in os.listdir(folder_path):
class_full_path = os.path.join(folder_path, class_name)
for file_name in os.listdir(class_full_path):
file_path = os.path.join(class_full_path, file_name)
try:
img = cv2.imread(file_path)
if img is None:
bad_image_paths.append(file_path)
print(f"无法读取图片 {file_path},可能已损坏,请检查")
except:
bad_image_paths.append(file_path)
print(f"读取图片 {file_path} 出现异常,请检查")
return bad_image_paths
# 检查训练、验证、测试集图片文件,获取损坏图片文件路径列表
bad_train_image_paths = check_image_files(train_dir)
bad_validation_image_paths = check_image_files(validation_dir)
bad_test_image_paths = check_image_files(test_dir)
# 可以在这里添加处理损坏图片文件的逻辑,比如删除它们等,示例如下(谨慎使用删除操作)
# for path in bad_train_image_paths:
# os.remove(path)
# 对验证集和测试集损坏图片路径同理可进行删除等操作
def num_of_classes(folder_dir, folder_name):
classes = [class_name for class_name in os.listdir(train_dir)]
print(colored(f'number of classes in {folder_name} folder : {len(classes)}', 'blue', attrs=['bold']))
num_of_classes(train_dir, 'train')
num_of_classes(validation_dir, 'validation')
num_of_classes(test_dir, 'test')
classes = [class_name for class_name in os.listdir(train_dir)]
count = []
for class_name in classes:
count.append(len(os.listdir(os.path.join(train_dir, class_name))))
plt.figure(figsize=(15, 4))
ax = sns.barplot(x=classes, y=count, color='navy')
plt.xticks(rotation=285)
for i in ax.containers:
ax.bar_label(i, )
plt.title('Number of samples per label', fontsize=25, fontweight='bold')
plt.xlabel('Labels', fontsize=15)
plt.ylabel('Counts', fontsize=15)
plt.yticks(np.arange(0, 105, 10))
plt.show()
def create_df(folder_path):
all_images = []
for class_name in classes:
class_path = os.path.join(folder_path, class_name)
all_images.extend([(os.path.join(class_path, file_name), class_name) for file_name in os.listdir(class_path)])
df = pd.DataFrame(all_images, columns=['file_path', 'label'])
return df
train_df = create_df(train_dir)
validation_df = create_df(validation_dir)
test_df = create_df(test_dir)
print(colored(f'Number of samples in train : {len(train_df)}', 'blue', attrs=['bold']))
print(colored(f'Number of samples in validation : {len(validation_df)}', 'blue', attrs=['bold']))
print(colored(f'Number of samples test : {len(test_df)}', 'blue', attrs=['bold']))
# Create a DataFrame with one Label of each category
df_unique = train_df.copy().drop_duplicates(subset=["label"]).reset_index()
# Display some pictures of the dataset
fig, axes = plt.subplots(nrows=6, ncols=6, figsize=(8, 7),
subplot_kw={'xticks': [], 'yticks': []})
# 先缓存图片数据,跳过损坏的图片文件
cached_images = []
for file_path in df_unique.file_path:
if file_path not in bad_train_image_paths:
img = cv2.imread(file_path)
if img is not None:
cached_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) # 转换为RGB格式方便显示
# 根据实际缓存到的有效图片数量调整显示逻辑
for i in range(len(cached_images)):
ax = axes.flat[i]
ax.imshow(cached_images[i])
ax.set_title(df_unique.label[i], fontsize=12)
plt.tight_layout(pad=0.5)
plt.show()
# 曲线图
# Train generator
random_seed = 42
train_datagen = ImageDataGenerator(
rescale=1. / 255, # Scaled images in range 0 to 1
rotation_range=20, # Rorate images by factor 20 degree
width_shift_range=0.2, # Shift images horizontally by up to 20% of their width
height_shift_range=0.2, # Shift images vertically by up to 20% of their width
zoom_range=0.1, # Zoom in and out images by 10%
horizontal_flip=True, # Allow horizontal flipping
shear_range=0.1, # shear images by 10% their size
fill_mode='nearest', # fill unlocated pixels by nearest pixel
)
train_generator = train_datagen.flow_from_dataframe(
dataframe=train_df, # Target data
x_col='file_path', # X column
y_col='label', # y column
target_size=(224, 224), # Resize images to
color_mode='rgb', # Color mode
class_mode='categorical', # type of model
batch_size=32,
shuffle=True,
seed=random_seed,
drop_remainder=True # 添加这个参数
)
# validation generator
validation_datagen = ImageDataGenerator(rescale=1. / 255, )
validation_generator = validation_datagen.flow_from_dataframe(
dataframe=validation_df,
x_col='file_path',
y_col='label',
target_size=(224, 224),
class_mode='categorical',
batch_size=32,
seed=random_seed,
shuffle=False
)
# Test generator
test_datagen = ImageDataGenerator(rescale=1. / 255, )
test_generator = test_datagen.flow_from_dataframe(
dataframe=test_df,
x_col='file_path',
y_col='label',
target_size=(224, 224),
class_mode='categorical',
batch_size=32,
seed=random_seed,
shuffle=False
)
pre_trained_model = MobileNetV2(
input_shape=(224, 224, 3), # Input image size
include_top=False, # model not include top layer
weights='imagenet', # weights type
pooling='avg' # type of pooling layer
)
# Name of layers in MobileNetV2
for layer in pre_trained_model.layers:
print(layer.name)
# Freeze all layers, except last layer
# The goal is to train just last layer of pre trained model
pre_trained_model.trainable = True
set_trainable = False
for layer in pre_trained_model.layers:
if layer.name == 'block_16_expand':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False
# Add custom layers on top of the base model
model = models.Sequential()
model.add(pre_trained_model)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(36, activation='softmax'))
# Compile
model.compile(optimizer=optimizers.Adam(learning_rate=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
# Model CheckPoint
checkpoint_cb = ModelCheckpoint('MyModel.keras', save_best_only=True)
# Early Stoping
earlystop_cb = EarlyStopping(patience=15, restore_best_weights=True)
# ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)
history = model.fit(
train_generator,
steps_per_epoch=len(train_generator),
epochs=10,
validation_data=validation_generator,
validation_steps=len(validation_generator),
callbacks=[checkpoint_cb, earlystop_cb, reduce_lr]
)
# Convert resutl of training to a DataFrame
result_df = pd.DataFrame(history.history)
result_df.tail()
x = np.arange(len(result_df))
fig, ax = plt.subplots(3, 1, figsize=(15, 12), sharex=True)
# AX0 : Loss
ax[0].plot(x, result_df.loss, label='loss', linewidth=3)
ax[0].plot(x, result_df.val_loss, label='val_loss', linewidth=2, ls='-.', c='r')
ax[0].set_title('Loss', fontsize=20)
ax[0].legend()
# AX1 : Loss
ax[1].plot(x, result_df.accuracy, label='accuracy', linewidth=2)
ax[1].plot(x, result_df.val_accuracy, label='val_accuracy', linewidth=2, ls='-.', c='r')
ax[1].set_title('Accuracy', fontsize=20)
ax[1].legend()
# AX2 : Loss
ax[2].plot(x, result_df.lr, label='learning_rate', linewidth=2)
ax[2].set_title('learning_rate', fontsize=20)
ax[2].set_xlabel('epochs')
ax[2].legend()
plt.show()
# 混淆矩阵
# checkpoint callback, save base model weights in "MyModel.keras".
# So, we should load it
best_model = models.load_model('MyModel.keras')
# 获取测试集真实标签
test_true_labels = []
for _, label in test_df[['label']].itertuples():
test_true_labels.append(label)
# 使用加载的最佳模型进行预测a
test_pred_probs = best_model.predict(test_generator)
test_pred_labels = np.argmax(test_pred_probs, axis=1)
# 输出分类报告
print(classification_report(test_true_labels, test_pred_labels))
# 计算并输出混淆矩阵
conf_matrix = confusion_matrix(test_true_labels, test_pred_labels)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()