# 主函数
def main():
# 假设数据集已经手动下载并解压
train_dir = "chest_xray/train"
test_dir = "chest_xray/test"
val_dir = "chest_xray/val"
# 加载数据
img_size = (150, 150)
batch_size = 32
train_generator, val_generator, test_generator = load_data(train_dir, test_dir, val_dir, img_size, batch_size)
# 处理样本不均衡
X_train, y_train_resampled, y_train_original = handle_imbalance(train_generator)
# 计算类别权重(基于原始分布)
n_normal = np.sum(y_train_original == 0)
n_pneumonia = np.sum(y_train_original == 1)
total = n_normal + n_pneumonia
weight_for_normal = (1 / n_normal) * (total / 2.0)
weight_for_pneumonia = (1 / n_pneumonia) * (total / 2.0)
class_weights = {0: weight_for_normal, 1: weight_for_pneumonia}
print(f"类别权重: 正常={weight_for_normal:.2f}, 肺炎={weight_for_pneumonia:.2f}")
# 构建模型
model = build_model((*img_size, 3))
model.summary()
# 提前停止回调
early_stopping = EarlyStopping(
monitor='val_loss',
patience=5,
restore_best_weights=True,
verbose=1
)
# 训练模型
history = model.fit(
X_train, y_train_resampled,
epochs=30,
batch_size=32,
validation_data=val_generator,
class_weight=class_weights,
callbacks=[early_stopping],
verbose=1
)
# 评估模型 - 使用完整测试集
test_generator.reset()
test_steps = len(test_generator)
test_results = model.evaluate(test_generator, steps=test_steps, verbose=1)
print("\n测试集评估结果:")
print(f"准确率: {test_results[1]:.4f}")
print(f"精确率: {test_results[2]:.4f}")
print(f"召回率: {test_results[3]:.4f}")
print(f"AUC: {test_results[4]:.4f}")
# 获取测试集所有预测结果
test_generator.reset()
y_true = []
y_pred_prob = []
for i in range(test_steps):
batch_x, batch_y = test_generator.next()
y_true.extend(batch_y)
batch_pred = model.predict(batch_x, verbose=0).ravel()
y_pred_prob.extend(batch_pred)
y_true = np.array(y_true)
y_pred_prob = np.array(y_pred_prob)
y_pred = (y_pred_prob > 0.5).astype(int)
# 计算额外指标
f1 = f1_score(y_true, y_pred)
auc = roc_auc_score(y_true, y_pred_prob)
print(f"\nF1-score: {f1:.4f}")
print(f"AUC-ROC: {auc:.4f}")
# 分类报告
print("\n分类报告:")
print(classification_report(y_true, y_pred, target_names=['NORMAL', 'PNEUMONIA']))
# 混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print("混淆矩阵:")
print(cm)
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
plt.figure(figsize=(10, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('接收者操作特征曲线(ROC)')
plt.legend(loc="lower right")
plt.savefig('roc_curve.png', dpi=300)
plt.show()
# 绘制训练历史
plt.figure(figsize=(12, 8))
plt.subplot(2, 2, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('准确率')
plt.legend()
plt.subplot(2, 2, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('损失')
plt.legend()
plt.subplot(2, 2, 3)
plt.plot(history.history['precision'], label='训练精确率')
plt.plot(history.history['val_precision'], label='验证精确率')
plt.title('精确率')
plt.legend()
plt.subplot(2, 2, 4)
plt.plot(history.history['recall'], label='训练召回率')
plt.plot(history.history['val_recall'], label='验证召回率')
plt.title('召回率')
plt.legend()
plt.tight_layout()
plt.savefig('training_history.png', dpi=300)
plt.show()
if __name__ == "__main__":
main()
D:\ProgramData\anaconda3\envs\tf_env\python.exe D:\workspace_py\deeplean\medical_image_classification.py
Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.
原始样本分布: 正常=1341, 肺炎=3875
过采样后分布: 正常=3875, 肺炎=3875
类别权重: 正常=1.94, 肺炎=0.67
2025-07-24 18:24:20.002334: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE SSE2 SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 148, 148, 32) 896
batch_normalization (Batch (None, 148, 148, 32) 128
Normalization)
max_pooling2d (MaxPooling2 (None, 74, 74, 32) 0
D)
dropout (Dropout) (None, 74, 74, 32) 0
conv2d_1 (Conv2D) (None, 72, 72, 64) 18496
batch_normalization_1 (Bat (None, 72, 72, 64) 256
chNormalization)
max_pooling2d_1 (MaxPoolin (None, 36, 36, 64) 0
g2D)
dropout_1 (Dropout) (None, 36, 36, 64) 0
conv2d_2 (Conv2D) (None, 34, 34, 128) 73856
batch_normalization_2 (Bat (None, 34, 34, 128) 512
chNormalization)
max_pooling2d_2 (MaxPoolin (None, 17, 17, 128) 0
g2D)
dropout_2 (Dropout) (None, 17, 17, 128) 0
conv2d_3 (Conv2D) (None, 15, 15, 256) 295168
batch_normalization_3 (Bat (None, 15, 15, 256) 1024
chNormalization)
max_pooling2d_3 (MaxPoolin (None, 7, 7, 256) 0
g2D)
dropout_3 (Dropout) (None, 7, 7, 256) 0
flatten (Flatten) (None, 12544) 0
dense (Dense) (None, 512) 6423040
batch_normalization_4 (Bat (None, 512) 2048
chNormalization)
dropout_4 (Dropout) (None, 512) 0
dense_1 (Dense) (None, 1) 513
=================================================================
Total params: 6815937 (26.00 MB)
Trainable params: 6813953 (25.99 MB)
Non-trainable params: 1984 (7.75 KB)
_________________________________________________________________
Epoch 1/30
243/243 [==============================] - 128s 520ms/step - loss: 0.4417 - accuracy: 0.8459 - precision: 0.8729 - recall: 0.8098 - auc: 0.9264 - val_loss: 2.2874 - val_accuracy: 0.5000 - val_precision: 0.5000 - val_recall: 1.0000 - val_auc: 0.4844
Epoch 2/30
243/243 [==============================] - 127s 521ms/step - loss: 0.2824 - accuracy: 0.8946 - precision: 0.9460 - recall: 0.8369 - auc: 0.9595 - val_loss: 7.6319 - val_accuracy: 0.5000 - val_precision: 0.5000 - val_recall: 1.0000 - val_auc: 0.5000
Epoch 3/30
243/243 [==============================] - 126s 520ms/step - loss: 0.2472 - accuracy: 0.9053 - precision: 0.9610 - recall: 0.8449 - auc: 0.9678 - val_loss: 2.1178 - val_accuracy: 0.6875 - val_precision: 0.7143 - val_recall: 0.6250 - val_auc: 0.6094
Epoch 4/30
243/243 [==============================] - 126s 517ms/step - loss: 0.2297 - accuracy: 0.9106 - precision: 0.9660 - recall: 0.8511 - auc: 0.9720 - val_loss: 23.9240 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5000
Epoch 5/30
243/243 [==============================] - 126s 518ms/step - loss: 0.2216 - accuracy: 0.9103 - precision: 0.9646 - recall: 0.8519 - auc: 0.9724 - val_loss: 8.6471 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5625
Epoch 6/30
243/243 [==============================] - 124s 512ms/step - loss: 0.2031 - accuracy: 0.9192 - precision: 0.9707 - recall: 0.8645 - auc: 0.9762 - val_loss: 8.0884 - val_accuracy: 0.5625 - val_precision: 1.0000 - val_recall: 0.1250 - val_auc: 0.5156
Epoch 7/30
243/243 [==============================] - 125s 512ms/step - loss: 0.1913 - accuracy: 0.9259 - precision: 0.9731 - recall: 0.8761 - auc: 0.9787 - val_loss: 20.7119 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5000
Epoch 8/30
243/243 [==============================] - ETA: 0s - loss: 0.1742 - accuracy: 0.9295 - precision: 0.9768 - recall: 0.8800 - auc: 0.9813Restoring model weights from the end of the best epoch.
243/243 [==============================] - 125s 513ms/step - loss: 0.1742 - accuracy: 0.9295 - precision: 0.9768 - recall: 0.8800 - auc: 0.9813 - val_loss: 9.5482 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5078
Epoch 00008: early stopping
20/20 [==============================] - 5s 239ms/step - loss: 1.2141 - accuracy: 0.8077 - precision: 0.9193 - recall: 0.7590 - auc: 0.8866
测试集评估结果:
准确率: 0.8077
精确率: 0.9193
召回率: 0.7590
AUC: 0.8866
F1-score: 0.8315
AUC-ROC: 0.8954
分类报告:
precision recall f1-score support
NORMAL 0.69 0.89 0.78 234
PNEUMONIA 0.92 0.76 0.83 390
accuracy 0.81 624
macro avg 0.80 0.82 0.80 624
weighted avg 0.83 0.81 0.81 624
混淆矩阵:
[[208 26]
[ 94 296]]
Process finished with exit code 0
胸片识别模型工作流程
这段代码是整个程序的 “主流程”,就像做一道菜的 “步骤清单”,从准备食材到最后端出菜的全过程。
1. 先找好 “食材”(数据存放的地方)
train_dir = "chest_xray/train" # 训练用的X光片放在这里 test_dir = "chest_xray/test" # 测试用的X光片放在这里 val_dir = "chest_xray/val" # 验证用的X光片放在这里 |
就像做菜前先确认:“蔬菜在冰箱,肉在冷冻室”。
2. 把 “食材” 处理成能直接用的样子
# 设定图片大小为150x150,每次拿32张图片来训练 train_generator, val_generator, test_generator = load_data(...) |
相当于 “把蔬菜洗干净、切成块”—— 把图片统一大小,分成 “训练组”“验证组”“测试组”,方便后面使用。
3. 解决 “食材数量不平衡” 的问题
# 处理样本不均衡,计算类别权重 ... class_weights = {0: weight_for_normal, 1: weight_for_pneumonia} |
比如:如果训练数据里 “正常胸片” 只有 100 张,“肺炎胸片” 有 900 张(差 9 倍),网络会偏向多学 “肺炎” 的特征,导致判断不准。
这里的操作就像:“给少数的‘正常胸片’增加‘权重’”,告诉网络:“虽然它少,但你要认真学,别偏心”。
4. 搭好 “炒菜的锅”(创建模型)
model = build_model((*img_size, 3)) # 用之前定义的函数建模型 model.summary() # 打印模型的“说明书”(有多少层、多少神经元) |
相当于 “拿出锅碗瓢盆”,准备好做菜的工具。
5. 设定 “炒菜的规则”(防止炒糊)
early_stopping = EarlyStopping(...) |
这是一个 “智能关火” 功能:如果炒了 5 次(patience=5),菜的味道(模型性能)没变好甚至变差,就自动关火(停止训练),避免 “炒糊”(模型学废了)。
6. 开始 “炒菜”(训练模型)
history = model.fit( X_train, y_train_resampled, epochs=30, # 最多炒30次 ... ) |
让模型对着 “训练用的胸片” 反复学习:每次看 32 张(batch_size=32),最多学 30 轮(epochs=30),学的时候参考前面的 “防偏心权重”,并随时用 “验证组” 数据检查学得怎么样。
7. 尝尝 “菜的味道”(测试模型)
test_results = model.evaluate(test_generator, ...) print("测试集评估结果: 准确率xxx...") |
用之前没学过的 “测试组” 胸片来检验模型:看看它判断对了多少(准确率)、漏诊了多少(召回率)、误诊了多少(精确率)。
8. 详细分析 “味道哪里好哪里差”
# 计算F1分数、画ROC曲线、混淆矩阵等 ... |
相当于 “写品尝报告”:
- 混淆矩阵:具体统计 “把正常判成肺炎”“把肺炎判成正常” 的数量;
- ROC 曲线:直观展示模型的整体判断能力;
- 训练历史图:看训练过程中 “准确率”“损失值” 的变化,判断模型有没有学好。
这段代码就是一个完整的 “胸片识别模型” 工作流程:
- 准备数据 → 2. 处理数据不平衡 → 3. 搭建模型 → 4. 设定训练规则 → 5. 训练模型 → 6. 测试模型效果 → 7. 生成详细报告。
就像从 “买菜、洗菜、炒菜到最后品尝并记录味道” 的全过程,一步不落,最终得到一个能判断 “胸片是否有肺炎” 的 AI 模型。
ROC 曲线与 AUC
ROC 曲线(Receiver Operating Characteristic Curve,受试者工作特征曲线)是一种用于评估二分类模型性能的可视化工具,通过展示模型在不同阈值下的真阳性率(TPR) 和假阳性率(FPR) 之间的关系,帮助判断模型的区分能力。
ROC 曲线的核心概念
- 基本定义
ROC 曲线以假阳性率(FPR) 为横轴,真阳性率(TPR) 为纵轴,每个点对应模型在一个特定分类阈值下的性能。
-
- 真阳性率(TPR):也称为灵敏度(Sensitivity)或召回率(Recall),表示实际为阳性的样本中被模型正确预测为阳性的比例,公式为:
TPR = \frac{TP}{TP + FN}
(TP:真阳性;FN:假阴性)
-
- 假阳性率(FPR):表示实际为阴性的样本中被模型错误预测为阳性的比例,公式为:
FPR = \frac{FP}{FP + TN}
(FP:假阳性;TN:真阴性)
- 曲线的绘制逻辑
二分类模型通常会输出样本属于阳性类别的概率(如 0-1 之间的分数),通过调整分类阈值(如 “概率≥0.5 则为阳性”),可以得到不同的 TPR 和 FPR 组合,将这些组合连成线即为 ROC 曲线。
ROC 曲线的意义
- 模型区分能力:曲线越靠近左上角(TPR 高、FPR 低),说明模型在正确识别阳性样本的同时,能有效减少对阴性样本的误判,性能越好。
- 理想情况:曲线经过左上角(TPR=1,FPR=0),表示模型完美区分正负样本。
- 随机猜测:曲线为一条从原点到(1,1)的对角线(TPR=FPR),说明模型无区分能力。
- 阈值选择参考:曲线上的每个点对应一个阈值,可根据实际需求(如更重视减少漏诊还是误诊)选择合适的阈值。例如,医疗场景中可能需要高 TPR(避免漏诊),即使 FPR 稍高。
AUC:ROC 曲线下面积
为了量化 ROC 曲线的性能,引入AUC(Area Under the ROC Curve,曲线下面积):
- 取值范围:0~1 之间。
- AUC=1:模型完美分类,ROC 曲线经过左上角。
- AUC=0.5:模型性能与随机猜测一致(对角线)。
- AUC<0.5:模型性能差于随机猜测,通常可通过反转预测结果改善。
- 意义:AUC 越大,模型区分正负样本的能力越强。例如,AUC=0.8 表示模型有 80% 的概率将阳性样本的预测分数高于阴性样本。
ROC 曲线的适用场景
- 二分类问题(如疾病诊断、垃圾邮件识别)。
- 对不平衡数据集(正负样本比例悬殊)有一定稳健性,因为 FPR 和 TPR 均基于实际类别计算,不受样本比例影响。
示例:ROC 曲线与 AUC
假设两个模型的 ROC 曲线如下:
- 模型 A 的曲线更靠近左上角,AUC=0.9。
- 模型 B 的曲线接近对角线,AUC=0.6。
则模型 A 的分类性能显著优于模型 B。
总之,ROC 曲线和 AUC 是评估二分类模型性能的重要工具,尤其适用于需要权衡 “正确识别阳性” 和 “错误识别阴性” 的场景。
评价与改进
D:\ProgramData\anaconda3\envs\tf_env\python.exe D:\workspace_py\deeplean\medical_image_classification.py
Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.
原始样本分布: 正常=1341, 肺炎=3875
过采样后分布: 正常=3875, 肺炎=3875
类别权重: 正常=1.94, 肺炎=0.67
2025-07-24 18:24:20.002334: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE SSE2 SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 148, 148, 32) 896
batch_normalization (Batch (None, 148, 148, 32) 128
Normalization)
max_pooling2d (MaxPooling2 (None, 74, 74, 32) 0
D)
dropout (Dropout) (None, 74, 74, 32) 0
conv2d_1 (Conv2D) (None, 72, 72, 64) 18496
batch_normalization_1 (Bat (None, 72, 72, 64) 256
chNormalization)
max_pooling2d_1 (MaxPoolin (None, 36, 36, 64) 0
g2D)
dropout_1 (Dropout) (None, 36, 36, 64) 0
conv2d_2 (Conv2D) (None, 34, 34, 128) 73856
batch_normalization_2 (Bat (None, 34, 34, 128) 512
chNormalization)
max_pooling2d_2 (MaxPoolin (None, 17, 17, 128) 0
g2D)
dropout_2 (Dropout) (None, 17, 17, 128) 0
conv2d_3 (Conv2D) (None, 15, 15, 256) 295168
batch_normalization_3 (Bat (None, 15, 15, 256) 1024
chNormalization)
max_pooling2d_3 (MaxPoolin (None, 7, 7, 256) 0
g2D)
dropout_3 (Dropout) (None, 7, 7, 256) 0
flatten (Flatten) (None, 12544) 0
dense (Dense) (None, 512) 6423040
batch_normalization_4 (Bat (None, 512) 2048
chNormalization)
dropout_4 (Dropout) (None, 512) 0
dense_1 (Dense) (None, 1) 513
=================================================================
Total params: 6815937 (26.00 MB)
Trainable params: 6813953 (25.99 MB)
Non-trainable params: 1984 (7.75 KB)
_________________________________________________________________
Epoch 1/30
243/243 [==============================] - 128s 520ms/step - loss: 0.4417 - accuracy: 0.8459 - precision: 0.8729 - recall: 0.8098 - auc: 0.9264 - val_loss: 2.2874 - val_accuracy: 0.5000 - val_precision: 0.5000 - val_recall: 1.0000 - val_auc: 0.4844
Epoch 2/30
243/243 [==============================] - 127s 521ms/step - loss: 0.2824 - accuracy: 0.8946 - precision: 0.9460 - recall: 0.8369 - auc: 0.9595 - val_loss: 7.6319 - val_accuracy: 0.5000 - val_precision: 0.5000 - val_recall: 1.0000 - val_auc: 0.5000
Epoch 3/30
243/243 [==============================] - 126s 520ms/step - loss: 0.2472 - accuracy: 0.9053 - precision: 0.9610 - recall: 0.8449 - auc: 0.9678 - val_loss: 2.1178 - val_accuracy: 0.6875 - val_precision: 0.7143 - val_recall: 0.6250 - val_auc: 0.6094
Epoch 4/30
243/243 [==============================] - 126s 517ms/step - loss: 0.2297 - accuracy: 0.9106 - precision: 0.9660 - recall: 0.8511 - auc: 0.9720 - val_loss: 23.9240 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5000
Epoch 5/30
243/243 [==============================] - 126s 518ms/step - loss: 0.2216 - accuracy: 0.9103 - precision: 0.9646 - recall: 0.8519 - auc: 0.9724 - val_loss: 8.6471 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5625
Epoch 6/30
243/243 [==============================] - 124s 512ms/step - loss: 0.2031 - accuracy: 0.9192 - precision: 0.9707 - recall: 0.8645 - auc: 0.9762 - val_loss: 8.0884 - val_accuracy: 0.5625 - val_precision: 1.0000 - val_recall: 0.1250 - val_auc: 0.5156
Epoch 7/30
243/243 [==============================] - 125s 512ms/step - loss: 0.1913 - accuracy: 0.9259 - precision: 0.9731 - recall: 0.8761 - auc: 0.9787 - val_loss: 20.7119 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5000
Epoch 8/30
243/243 [==============================] - ETA: 0s - loss: 0.1742 - accuracy: 0.9295 - precision: 0.9768 - recall: 0.8800 - auc: 0.9813Restoring model weights from the end of the best epoch.
243/243 [==============================] - 125s 513ms/step - loss: 0.1742 - accuracy: 0.9295 - precision: 0.9768 - recall: 0.8800 - auc: 0.9813 - val_loss: 9.5482 - val_accuracy: 0.5000 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.5078
Epoch 00008: early stopping
20/20 [==============================] - 5s 239ms/step - loss: 1.2141 - accuracy: 0.8077 - precision: 0.9193 - recall: 0.7590 - auc: 0.8866
测试集评估结果:
准确率: 0.8077
精确率: 0.9193
召回率: 0.7590
AUC: 0.8866
F1-score: 0.8315
AUC-ROC: 0.8954
分类报告:
precision recall f1-score support
NORMAL 0.69 0.89 0.78 234
PNEUMONIA 0.92 0.76 0.83 390
accuracy 0.81 624
macro avg 0.80 0.82 0.80 624
weighted avg 0.83 0.81 0.81 624
混淆矩阵:
[[208 26]
[ 94 296]]
Process finished with exit code 0
胸部X光肺炎分类模型优化
我会从模型性能、实际应用考量、当前存在问题、改进建议等方面进行分析,以判断该模型是否达到实用水平。
模型性能分析
- ROC 曲线与 AUC 值:
- 从 ROC 曲线来看,橙色的 ROC 曲线明显高于对角线(随机猜测线),表明模型具有一定的区分能力。
- AUC 值为 0.8954,说明模型在区分正常胸片和肺炎胸片时,有 89.54% 的概率将阳性样本(肺炎)的预测分数高于阴性样本(正常),这是一个相对较好的结果,表明模型在整体上有较好的分类性能。
- 准确率、精确率、召回率和 F1 分数:
- 测试集准确率为 0.8077,即模型在测试集上的分类正确率约为 80.77%。
- 精确率(precision)为 0.9193,意味着当模型预测为肺炎时,有 91.93% 的概率是正确的,这在一定程度上保证了预测结果的可靠性。
- 召回率(recall)为 0.7590,表示在所有实际患有肺炎的样本中,模型能够正确识别出 75.90% 的肺炎病例。
- F1 分数为 0.8315,综合考虑了精确率和召回率,是一个较为均衡的指标,表明模型在精确率和召回率之间取得了较好的平衡。
- 混淆矩阵:
- 混淆矩阵显示,模型将 208 个正常样本正确分类为正常,26 个正常样本错误分类为肺炎;将 94 个肺炎样本错误分类为正常,296 个肺炎样本正确分类为肺炎。
- 假阴性(FN=94)表示有 94 个肺炎患者被错误地诊断为正常,这在医疗场景中是比较严重的问题,可能会导致患者延误治疗。
- 假阳性(FP=26)表示有 26 个正常患者被错误地诊断为肺炎,可能会给患者带来不必要的担忧和进一步的检查。
实际应用考量
- 优势:
- 该模型在一定程度上能够辅助医生进行肺炎的诊断,特别是在处理大量胸部 X 光图像时,可以提高诊断效率。
- 较高的精确率可以减少不必要的进一步检查,减轻患者的负担。
- 问题:
- 召回率相对较低(75.90%),意味着仍有较多的肺炎病例被漏诊,这在医疗领域是不可接受的,因为漏诊可能会导致严重的后果。
- 假阴性和假阳性的存在都可能对患者的诊断和治疗产生不利影响。
当前存在的问题
- 数据不平衡:
- 原始样本中正常样本和肺炎样本的数量不平衡(正常 = 1341,肺炎 = 3875),尽管进行了过采样处理,但可能仍然对模型的性能产生一定的影响。
- 模型泛化能力:
- 验证集的性能波动较大(如验证准确率在不同 epoch 中变化较大),可能表明模型的泛化能力不够强,在面对新的、未见过的数据时可能会出现性能下降的情况。
- 损失函数不稳定:
- 验证损失在训练过程中波动较大,说明模型的训练过程可能不够稳定,需要进一步优化训练策略。
改进方向
- 数据增强:
- 进一步增加数据增强的方式,如旋转、翻转、缩放等,以增加训练数据的多样性,提高模型的泛化能力。
- 调整模型结构:
- 尝试调整模型的网络结构,如增加或减少卷积层、全连接层的数量,调整卷积核的大小等,以找到更适合该任务的模型结构。
- 优化训练策略:
- 调整学习率、批次大小等训练参数,使用学习率衰减、早停等策略,以提高模型的训练稳定性和性能。
- 处理类别不平衡:
- 除了过采样,还可以考虑使用类别权重、欠采样等方法来处理类别不平衡问题,以提高模型对少数类(正常样本)的识别能力。
- 集成学习:
- 考虑使用集成学习方法,如随机森林、梯度提升树等,将多个模型的预测结果进行组合,以提高模型的整体性能。
该模型在胸部 X 光图像肺炎分类任务上取得了一定的成果,AUC 值和精确率较高,但召回率相对较低,假阴性和假阳性的数量也需要进一步减少。虽然模型在某些指标上表现较好,但在医疗领域,漏诊和误诊的代价较高,因此该模型目前尚未达到实用水平。需要进一步优化模型,提高召回率和稳定性,减少假阴性和假阳性的数量,才能在实际临床应用中发挥更大的作用。
后面我抽时间搞个改进版本出来~~~~ 各位~~~~