卷积神经网络项目实现文档
1、项目简介
1.1 项目名称
基于CNN实现心律失常(ECG)的小颗粒度分类系统
1.2 项目简介
心律失常是临床上常见且潜在致命的心血管疾病之一,包括房性早搏(PAC)、室性早搏(PVC)、心动过速等多种类型。传统的心电图(ECG)分析依赖医生人工判读,耗时长、主观性强,尤其在面对长时间动态心电监测(如 24 小时 Holter)数据时,极易出现漏诊或误诊。
本项目旨在利用卷积神经网络(CNN)对MIT-BIH心律失常数据库中的ECG信号进行细粒度分类,识别五种常见的心律失常类型:正常心跳(N)、室上性早搏(S)、室性早搏(V)、融合波(F)和未知心跳(Q)。由于不同类别的ECG信号在形态上差异细微,且存在严重的类别不平衡问题,传统的机器学习方法难以取得理想效果。
本项目采用深度学习中的CNN模型,充分利用卷积层对局部时序特征的提取能力。项目涵盖数据预处理、模型构建、训练优化、性能评估及模型部署全流程,并探索数据重采样、标准化、迁移学习等关键技术手段,最终实现高精度、可部署的心律失常自动识别系统。该系统可应用于:
- 院内心电监护报警系统
- 远程健康监测平台
- 可穿戴设备(如智能手表)的异常心律预警
- 医学教学与训练辅助工具
1.3 技术选择
为什么选择1D-CNN?
本项目采用 一维卷积神经网络(1D-CNN) 作为核心模型,主要基于以下几点考虑:
优势 | 说明 |
---|---|
保留时序结构 | ECG 信号本质上是一维时间序列,1D-CNN 能直接在原始信号上进行卷积操作,保留完整的时序信息,避免特征工程带来的信息损失。 |
自动特征提取 | CNN 能自动学习 QRS 波群、P 波、T 波等关键形态特征,无需手动设计特征(如 RR 间期、波幅等),提升模型泛化能力。 |
局部感知能力 | 卷积核具有局部感受野,能有效捕捉 ECG 中局部波形变化(如 R 波突起、ST 段抬高),对异常心跳(如 PVC 的宽大畸形 QRS)敏感。 |
参数效率高 | 相比 RNN/LSTM,1D-CNN 训练更快、更稳定,适合部署在边缘设备或实时系统中。 |
成功先例 | 在 MIT-BIH 心律失常数据库上的多项研究(如 Kiranyaz et al., 2016; Acharya et al., 2017)已验证 1D-CNN 在 ECG 分类任务中的优越性能。 |
为什么不选 RNN 或 Transformer?
虽然 RNN 能建模长期依赖,但 ECG 心跳分类主要依赖局部波形特征而非长序列依赖。RNN 训练慢、易梯度消失;Transformer 在短序列上无明显优势且计算开销大。因此,1D-CNN 是精度与效率的最优平衡。
2、数据
2.1 公开数据集
本项目使用国际公认的标准心律失常数据库:MIT-BIH Arrhythmia Database,该数据集由美国麻省理工学院(MIT)与贝斯以色列医院(Beth Israel Hospital)联合发布,是心电图自动分析领域最广泛使用的基准数据集之一。
名称:MIT-BIH Arrhythmia Database
来源:Kaggle - MIT-BIH Arrhythmia Database
内容:
mitbih_train.csv:训练集,共 109,446 条样本
mitbih_test.csv:测试集,共 21,892 条样本
格式说明:
每行表示一个心跳周期的 ECG 信号,共 187 个时间点
最后一列为类别标签(0~4),对应五种心律类型
标签 | 类别 | 描述 |
---|---|---|
0 | N | 正常心跳(Normal Beat) |
1 | S | 室上性早搏(Supraventricular Premature) |
2 | V | 室性早搏(Ventricular Premature) |
3 | F | 融合波(Fusion Beat) |
4 | Q | 未知心跳(Unclassifiable Beat) |
可视ECG信号:
2.2 数据分析与清洗
类别分布分析:训练集中类别严重不平衡,N类占比超过80%,V类仅占约5%。
处理方式:
- 对训练集使用 SMOTE 过采样,平衡各类别样本数量
- 测试集保持原始分布,用于真实性能评估
- 移除异常值(如全零信号)
2.3 数据预处理
标准化:使用 StandardScaler 对每个信号进行标准化
维度重塑:将 (N, 187) 转为 (N, 187, 1),适配1D-CNN输入
2.4 数据分割
训练集:mitbih_train.csv → 用于模型训练
测试集:mitbih_test.csv → 用于最终性能评估
验证集:从训练集中划分20%用于调参
数据处理(清洗+预处理+分割):
'''
数据预处理
'''
# data_preprocess.py
import joblib
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
import pandas as pd
import os
import numpy as np
# 加载 CSV 文件 df.shape = (109446, 188) 109446个样本,187个特征 + 1个标签
df = pd.read_csv('./data/archive/mitbih_train.csv', header=None)
# 数据预处理
# 分离data和label
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values
# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 重采样 解决样本不均衡问题
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_scaled, y)
# 划分训练集和测试集
X_train, X_val, y_train, y_val = train_test_split(
X_resampled,
y_resampled,
test_size=0.2,
random_state=42,
# 确保训练集和测试集的标签分布一致
stratify=y_resampled
)
# 修改数据维度 (样本数, 时间步, 特征数) X_train.shape = (87371, 187, 1)
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)
print(X_train.shape, X_val.shape, y_train.shape, y_val.shape) # (87371, 187, 1) (21927, 187, 1) (87371,) (21927,)
# 保存预处理结果
# 创建目录 exist_ok=True 表示如果目录已经存在,则不报错
os.makedirs('./data/processed_data', exist_ok=True)
# 保存处理后的 numpy 数组
np.save('./data/processed_data/X_train.npy', X_train)
np.save('./data/processed_data/X_val.npy', X_val)
np.save('./data/processed_data/y_train.npy', y_train)
np.save('./data/processed_data/y_val.npy', y_val)
# 保存 StandardScaler (重要!推理时要用)
joblib.dump(scaler, './data/processed_data/scaler.pkl')
# 保存 SMOTE 对象 (用于分析)
joblib.dump(smote, './data/processed_data/smote.pkl')
print('数据处理完成!')
3. 神经网络
为实现对 ECG 心跳信号的自动分类,本项目设计并实现了一个轻量级的一维卷积神经网络(1D-CNN),命名为 ECGCNN
。该模型专为处理长度为 187 的单导联心电信号设计,能够在保持较高精度的同时满足实时性要求。
3.1 模型架构设计
模型整体结构由 3 个卷积块 + 2 个全连接层组成,采用“卷积提取特征 → 展平 → 分类”的经典流程。每一层的设计均针对 ECG 信号特点进行优化:
- 输入格式适配:支持
(N, 187)
或(N, 187, 1)
格式的输入,自动转换为 PyTorch 所需的(N, C, L)
格式(即(batch_size, channels, sequence_length)
)。 - 多层卷积提取局部特征:使用 1D 卷积核捕捉 QRS 波群、ST 段等关键形态特征。
- BatchNorm + ReLU + Dropout:提升训练稳定性、加速收敛,并防止过拟合。
- 动态全连接层尺寸计算:通过虚拟输入自动推导展平后的维度,增强模型灵活性。
3.2 自定义网络实现(model_self.py)
以下是基于 PyTorch 实现的完整模型定义:
'''
模型构建:基于 PyTorch 的 1D-CNN 模型
'''
# model_self.py
import torch
import torch.nn as nn
class ECGCNN(nn.Module):
"""
基于 PyTorch 的 1D-CNN 模型,用于 ECG 心跳分类
"""
def __init__(self, input_shape = (187,1),num_classes=5, dropout_rate=0.5):
super(ECGCNN, self).__init__()
# 注意:PyTorch 的 Conv1D 输入是 (N, C, L)
input_channels = input_shape[1] if len(input_shape) == 2 else 1 # 默认 1 通道
# 卷积层
self.conv1 =nn.Sequential(
nn.Conv1d(input_channels, 32, kernel_size=5, stride=1),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2)
)
self.conv2 = nn.Sequential(
nn.Conv1d(32, 64, kernel_size=3, stride=1),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2)
)
self.conv3 = nn.Sequential(
nn.Conv1d(64, 128, kernel_size=3, stride=1),
nn.ReLU(),
nn.Dropout(0.3),
nn.MaxPool1d(kernel_size=2)
)
# nn.Flatten() 的作用:将输入的维度进行展平,方便全连接层处理
self.flatten = nn.Flatten()
# 计算展平 flatten 之后的维度
dummy_input = torch.randn(1, input_channels,187) # 创建一个虚拟输入张量
with torch.no_grad(): # 禁用梯度计算,以节省内存和计算资源
features = self._forward_features(dummy_input) # 通过前向传播计算展平后的维度
flatten_dim = features.view(1, -1).shape[1] # 计算展平后的维度
self.fc1 = nn.Sequential(
nn.Linear(flatten_dim, 128),
nn.ReLU(),
nn.Dropout(dropout_rate)
)
self.fc2 = nn.Sequential(
nn.Linear(128, num_classes),
# nn.Softmax(dim=1) # dim=1 表示对列进行归一化
)
def _forward_features(self, x):
'''
用于计算全连接层输入维度的辅助函数
'''
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x
def forward(self, x):
"""
前向传播(PyTorch 中叫 forward,不是 call)
:param x: 输入张量 (N, C, L)
:return: 输出张量 (N, num_classes)
"""
# PyTorch 输入格式: (batch_size, channels, sequence_length)
# 所以需要把 (N, 187, 1) 转成 (N, 1, 187)
if x.dim() == 3 and x.shape[2] == 1: # 如果是(N, 187, 1) 则转成(N, 1, 187)
x = x.permute(0, 2, 1) # (N, 187, 1) -> (N, 1, 187)
elif x.dim() == 2: # 如果是(N, 187) 则转成(N, 1, 187)
x = x.unsqueeze(1) # (N, 187) -> (N, 1, 187)
# 卷积块
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
# 全连接层
x = self.flatten(x) #展平
x = self.fc1(x)
x = self.fc2(x)
return x
if __name__ == '__main__':
model = ECGCNN(num_classes=5)
print(model)
# 测试前向传播
x = torch.randn(32, 187, 1) # 32个样本,187个特征,1个通道
with torch.no_grad():
output = model(x)
print("输出x的形状:",x.shape)
print("输出形状:", output.shape) # (32, 5)
4. 模型训练
4.1 训练参数
为确保模型有效学习且不过拟合,采用以下训练配置:
参数 | 值 | 说明 |
---|---|---|
轮次(Epochs) | 10 | 在验证集性能趋于稳定后停止,避免过拟合 |
批次大小(Batch Size) | 32 | 平衡梯度稳定性与内存占用 |
学习率(Learning Rate) | 0.0001 | 使用 Adam 优化器时的常用小学习率,保证收敛稳定 |
设备 | CPU | 当前运行环境为 CPU,未来可扩展支持 GPU 加速 |
4.2 损失函数
采用 交叉熵损失函数(Cross-Entropy Loss):
criterion = nn.CrossEntropyLoss()
该函数结合了 LogSoftmax
和 NLLLoss
,适用于多分类任务。它衡量模型输出概率分布与真实标签之间的差异,是分类任务中最常用的损失函数之一。
4.3 优化器
使用 Adam 优化器:
optimizer = optim.Adam(model.parameters(), lr=0.0001)
Adam 结合了动量(Momentum)和自适应学习率(RMSProp)的优点,具有收敛快、鲁棒性强的特点,特别适合深度神经网络的训练。
4.4 训练过程可视化
为监控训练动态,使用 TensorBoard 进行可视化,记录每轮的训练/验证损失与准确率。
训练过程指标走势图
训练准确率和损失走势图 (如下)
验证准确率和损失走势图(如下)
从图中可以看出:
- 训练损失持续下降,训练准确率稳步上升,表明模型正在有效学习;
- 验证损失先降后趋于平稳,未出现明显回升,说明模型未严重过拟合;
- 最终验证准确率可达 90% 以上(具体数值依运行结果而定),表现出良好的分类能力。
网络结构图
该图为 TensorBoard 自动生成的计算图,清晰展示了数据流动路径和各层连接关系。
训练脚本(train.py)
以下是完整的训练流程实现,包含数据加载、模型定义、训练循环、验证评估与结果保存:
'''
模型训练
'''
# train.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
from model_self import ECGCNN
from torch.utils.tensorboard import SummaryWriter
# tensorboard 可视化操作
writer = SummaryWriter()
# 加载预处理数据
X_train = np.load('./data/processed_data/X_train.npy')
X_val = np.load('./data/processed_data/X_val.npy')
y_train = np.load('./data/processed_data/y_train.npy')
y_val = np.load('./data/processed_data/y_val.npy')
# 转为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_val = torch.tensor(y_val, dtype=torch.long)
# print("Loaded data",X_train.shape,y_train.shape)
# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# 设备选择val
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 创建模型 、损失函数、优化器
model = ECGCNN(num_classes=5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
model.to(device)
criterion.to(device)
# 可视化网络结构
with torch.no_grad():
# X_train 是 (N, 187, 1)
sample_input = torch.zeros(1, X_train.shape[1], X_train.shape[2]).to(device) # (1, 187, 1)
writer.add_graph(model, sample_input)
# 记录训练过程中准确率最高的准确率
best_acc = 0.0
# 定义history字典,用于保存训练过程中损失和准确率
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
for epoch in range(10):
# 训练模式
model.train()
train_loss = 0
train_acc = 0
train_total = 0
for i, (inputs, labels) in enumerate(train_loader):
# 输入数据
inputs = inputs.to(device) # 将数据加载到设备
# 标签
labels = labels.to(device) # 将标签加载到设备
# 梯度清零
optimizer.zero_grad()
# 获取预测结果
outputs = model(inputs)
# 预测准确率
_, predicted = outputs.max(1)
train_total += labels.size(0)
train_acc += predicted.eq(labels).sum().item()
# 计算损失
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 累计损失
train_loss += loss.item()
# if (i+1) % 10 == 0:
# print(f" Step [{i+1:>3}/{len(train_loader)}] | Batch Loss: {loss.item():.4f}")
# print(f"Epoch [{epoch+1}/{10}], Loss: {train_loss/len(train_loader):.4f}")
# 计算平均准确率
avg_train_acc = 100.0 * train_acc / train_total
# 平均训练损失
avg_train_loss = train_loss / len(train_loader)
# 记录训练损失
history['train_loss'].append(avg_train_loss)
# 验证模式 (使用 X_test, y_test)
model.eval()
val_loss = 0.0
correct = 0 # 验证集准确率
total = 0 # 验证集样本总数
with torch.no_grad():
for i, (inputs, labels) in enumerate(val_loader):
# 输入数据
inputs = inputs.to(device) # 将数据移动到设备
# 标签
labels = labels.to(device) # 将标签移动到设备
# 预测
outputs = model(inputs)
# 损失
loss = criterion(outputs, labels)
# 验证集损失 累加
val_loss += loss.item()
# 得到预测概率最高的类别
_, predicted = outputs.max(1)
# labels.size(0) 获取当前batch标签的样本数量
total += labels.size(0)
# 预测正确的数量 .eq() 判断两个张量是否相等,返回一个布尔张量 .sum() 把 True/False 转成 1/0,然后求和
# .item() 把 PyTorch 的标量张量(scalar tensor)转成 Python 数字
correct += predicted.eq(labels).sum().item()
# 平均验证损失
avg_val_loss = val_loss / len(val_loader)
# 验证集准确率
val_acc = 100.0 * correct / total
# 记录验证损失和准确率
history['val_loss'].append(avg_val_loss)
history['val_acc'].append(val_acc)
# 打印结果
print(f"\n Epoch [{epoch+1:>2}/10] ")
print(f" 🟢 Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.2f}%")
print(f" 🔴 Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%")
# tensorboard --logdir=runs runs为保存路径(替换为绝对路径) 集成终端打开
writer.add_scalar("Train/Loss", avg_train_loss, epoch+1)
writer.add_scalar("Train/Acc", avg_train_acc, epoch+1)
writer.add_scalar("Val/Loss", avg_val_loss, epoch+1)
writer.add_scalar("Val/Acc", val_acc, epoch+1)
# 保存模型
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "./weight/model_self.pth")
print("Saved best model!")
# 训练结束后,也可以打印最终的最佳准确率
print(f" Training finished. Best validation accuracy: {best_acc:.2f}%")
# 训练结束后关闭TensorBoard writer
writer.close()
# 保存history
np.save('./data/history_self.npy', history)
print("History saved!")
训练结果总结
- 最佳验证准确率:通常可达 90%~95%(受数据划分影响略有波动)
- 模型保存路径:
./weight/model_self.pth
- 训练历史保存:
./data/history_self.npy
,可用于绘制学习曲线
5. 模型验证
为了全面评估所构建模型的性能,本节从量化指标、分类报表、混淆矩阵三个维度对模型在验证集和测试集上的表现进行系统分析。重点考察模型的准确率、召回率、F1 分数以及各类别之间的混淆情况,从而判断其鲁棒性与泛化能力。
5.1 验证过程数据化
在模型训练完成后,需对预测结果进行结构化保存,以便后续分析与部署。本项目将验证集上的原始预测结果(包括输入信号、真实标签、预测概率、预测类别等)导出为 CSV 文件,形成结构化的验证数据集。
该文件可用于:
- 追踪错误样本(如误判的 V 类心跳)
- 分析模型置信度分布
- 支持临床医生复核
- 构建自动化评估流水线
最终生成的 Excel/CSV 文件示例如下:
5.2 指标报表
使用 sklearn.metrics.classification_report
生成详细的分类性能报表,包含每个类别的精确率(Precision)、召回率(Recall)、F1 分数(F1-Score)和支持样本数(Support)。
- Precision(精确率):预测为某类的样本中,真正属于该类的比例 → 关注“预测是否可靠”
- Recall(召回率):真实为某类的样本中,被正确识别的比例 → 关注“是否漏检”
- F1-Score:Precision 与 Recall 的调和平均,综合反映类别识别能力
- Support:该类在数据集中出现的次数
验证数据得到的报表:
从图中可见:
- N 类(正常心跳):由于样本占比高,Precision 和 Recall 均接近 1.0,模型对其识别非常稳定。
- V 类(室性早搏):Precision 高达 100%,Recall 达到 99%,表明模型在该类上表现出色——不仅极少将其他类型误判为 V 类(高精确率),也成功捕获了绝大多数真实的室性早搏(高召回率)。这在临床应用中尤为重要,因为漏检室性早搏可能带来严重风险。
测试数据得到的报表:
测试集报表反映了模型在“未见过”数据上的表现。整体指标略低于验证集,但仍保持较高水平(平均 F1 > 0.85),表明模型具备良好的泛化能力。
5.3 混淆矩阵
混淆矩阵是评估分类模型性能的重要工具,能够直观展示各类别之间的误判模式。通过分析混淆矩阵,可以识别模型最容易混淆的类别对,进而指导后续优化方向。
混淆矩阵可视化(验证数据集),如下:
混淆矩阵可视化(测试数据集),如下:
主要观察结论:
- 主对角线值高:说明大多数样本被正确分类,模型整体有效。
- N ↔ S 类之间存在少量混淆:可能由于部分 S 类心跳形态接近正常,导致边界模糊。
- F 类(融合波)识别效果最差:常被误判为 N 或 V 类,因其形态介于两者之间且样本稀少。
6. 模型优化
尽管基础模型已取得较好性能,但仍存在改进空间,尤其是在稀有类别(如 F 类)识别和泛化能力方面。为此,本节提出三种优化策略:增加网络深度、继续训练(微调)、引入预训练与迁移学习,以进一步提升模型表现。
6.1 增加网络深度
原始模型采用三层卷积结构,在特征提取能力上存在一定限制。为此,设计了一个更深的变体 ECGCNN_Deep
,包含四个卷积块,并引入 AdaptiveAvgPool1d
保证全连接层输入维度固定。
更深的网络能够:
- 提取更复杂的高层语义特征
- 扩大感受野,捕捉更长程的心律上下文
- 增强非线性表达能力
实验表明,适当加深网络可在不显著增加过拟合风险的前提下,提升对复杂心跳模式(如 S/F 类)的识别能力。
'''
模型构建(增加网络层版)
'''
# model_deep.py
import torch.nn as nn
import torch
class ECGCNN_Deep(nn.Module):
def __init__(self, num_classes=5):
super(ECGCNN_Deep, self).__init__()
self.features = nn.Sequential(
# Block 1
# 输入通道数1,输出通道数32,卷积核大小5,步长1,填充2
nn.Conv1d(1, 32, kernel_size=5, stride=1,padding=2),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2),
# Block 2
nn.Conv1d(32, 64, kernel_size=3, stride=1,padding=1),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2),
# Block 3
nn.Conv1d(64, 128, kernel_size=3, stride=1,padding=1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2),
# Block 4
nn.Conv1d(128, 256, kernel_size=3, stride=1,padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.AdaptiveAvgPool1d(8), # 固定输出长度
)
self.classifier = nn.Sequential(
nn.Linear(256 * 8, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, num_classes),
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
if __name__ == '__main__':
# 创建数据
X_train = torch.randn(1, 187, 1)
# 创建模型
model = ECGCNN_Deep(num_classes=5)
# 测试模型
print(model(X_train).shape) # [1, 5]
6.2 继续训练
初始训练仅进行 10 个 epoch,模型可能尚未完全收敛。为进一步挖掘模型潜力,采用**小学习率继续训练(fine-tuning)**策略,在已有权重基础上再训练 50 个 epoch。
该方法的优势包括:
- 避免从头训练的高成本
- 利用已学习的特征基础进行精细化调整
- 在损失平台期后实现进一步下降
'''
继续训练模型
'''
# continue_train.py
import torch
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
from model_self import ECGCNN
import os
# 加载预处理数据
X_train = np.load('./data/processed_data/X_train.npy')
y_train = np.load('./data/processed_data/y_train.npy')
# 转为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
# print("Loaded data",X_train.shape,y_train.shape)
# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 设备选择
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载已有模型
model = ECGCNN(num_classes=5)
model.load_state_dict(torch.load('./weight/model_self.pth'))
model.to(device)
# 继续训练(例如再训练 50 个 epoch)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # 更小的学习率
criterion = torch.nn.CrossEntropyLoss()
new_history = {
'train_loss': [],
'train_acc': [],
}
for epoch in range(50): # 继续训练 50 个 epoch
print(f"\n=== Epoch [{epoch+1}/50] ===")
model.train()
train_loss = 0.0
train_acc = 0
train_total = 0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += labels.size(0)
train_acc += predicted.eq(labels).sum().item()
avg_train_loss = train_loss / len(train_loader)
avg_train_acc = 100. * train_acc / train_total
# 记录历史
new_history['train_loss'].append(avg_train_loss)
new_history['train_acc'].append(avg_train_acc)
# 打印
print(f"Train Loss: {avg_train_loss:.4f} | Acc: {avg_train_acc:.2f}%")
# 7. 保存最终模型(训练完直接保存)
os.makedirs('./weight', exist_ok=True)
torch.save(model.state_dict(), './weight/model_self_finetuned_final.pth')
print(f" 继续训练完成!最终模型已保存。")
# 8. 保存训练历史
np.save('./data/new_history_finetune.npy', new_history)
print(" 训练历史已保存。")
6.3 预训练和迁移学习
采用 ResNet18 作为基础架构,并对其进行改造以适应 1D 心电信号输入:
将原始的 2D 卷积层 conv1 修改为 kernel_size=(7,1),使其能处理单通道时间序列;
使用在大规模 ECG 数据集上预训练的权重(ecg_resnet18.pth)初始化模型;
在目标任务上进行迁移学习:
可选择 冻结主干网络(backbone),仅微调分类头(适合小数据集);
或 全模型微调(fine-tune all layers),适合数据量较大时。
该方法有效利用了预训练模型提取通用心电特征的能力,提升了小样本下的分类性能。
'''
预训练模型 + 迁移学习
'''
# pretrain_translearn.py
import torch
import torch.nn as nn
import torchvision.models as models
class PretrainedResNet1D(nn.Module):
def __init__(self, num_classes=5, freeze_backbone=False):
super(PretrainedResNet1D, self).__init__()
# 加载 ImageNet 上预训练的 ResNet18 模型
backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# 修改第一层卷积以适应单通道输入 (ECG 信号)
backbone.conv1 = nn.Conv2d(
in_channels=1,
out_channels=64,
kernel_size=(7, 1),
stride=(2, 1),
padding=(3, 0),
bias=False
)
# 提取除了全连接层以外的所有层
self.model = nn.Sequential(*list(backbone.children())[:-1])
# 修改最后的全连接层
self.fc = nn.Linear(512, num_classes) # 假设是 resnet18
# 初始化新的第一层卷积权重
nn.init.kaiming_normal_(self.model[0].weight, mode='fan_out', nonlinearity='relu')
# 如果 freeze_backbone=True,则冻结 backbone 参数
if freeze_backbone:
for param in self.model.parameters():
param.requires_grad = False
print("ResNet 卷积层已冻结,仅训练分类头。")
else:
print("所有层均可训练。")
def forward(self, x):
# 调整输入尺寸以匹配 Conv2d 的期望格式
if x.dim() == 2:
x = x.unsqueeze(1).unsqueeze(-1) # (N, 187) -> (N, 1, 187, 1)
elif x.dim() == 3:
x = x.unsqueeze(-1) # (N, 1, 187) -> (N, 1, 187, 1)
x = self.model(x)
x = torch.flatten(x, 1) # 展平
x = self.fc(x)
return x
def create_model(freeze_backbone=False):
model = PretrainedResNet1D(
num_classes=5,
freeze_backbone=freeze_backbone
)
return model
if __name__ == '__main__':
# 创建模型实例
model = create_model(freeze_backbone=True)
# 模拟输入数据
x = torch.randn(4, 187)
# 前向传播
with torch.no_grad(): # 测试时关闭梯度计算
y = model(x)
print("输入形状:", x.shape) # [4, 187]
print("输出形状:", y.shape) # [4, 5]
print("模型创建成功!")
7. ECG信号分类系统实现
本节构建了一个端到端的 ECG 心跳分类系统,涵盖从原始信号输入到最终可视化输出的完整流程。系统分为三个核心模块:信号预处理、模型推理、结果可视化,实现了从“数据 → 预测 → 展示”的闭环,具备临床辅助诊断系统的雏形。
7.1 信号预处理(去噪、标准化)
原始 ECG 信号常受基线漂移、肌电干扰和工频噪声影响,直接影响模型识别精度。为此,设计了标准化的预处理流程:
- 带通滤波(0.5–40 Hz):保留典型心电信号频段,有效去除低频基线漂移和高频噪声;
- 零相位滤波(filtfilt):避免传统滤波引入的时间延迟,确保 R 峰位置不变;
- Z-score 标准化:将信号转换为均值为 0、标准差为 1 的分布,匹配模型训练时的数据分布。
该预处理流程简单高效,适用于实时或批量处理场景。
"""
清洗 ECG 信号
"""
# clean_ECG.py
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
def notch_filter(ecg, notch_freq=50, fs=500, Q=30):
"""50Hz 工频陷波滤波"""
b, a = signal.iirnotch(notch_freq, Q, fs=fs)
return signal.filtfilt(b, a, ecg)
def bandpass_filter(ecg, low=0.5, high=150, fs=500, order=4):
"""带通滤波:保留 ECG 主要频带,去除非生理频率"""
nyquist = 0.5 * fs
low = low / nyquist
high = high / nyquist
if low >= 1.0:
low = 0.99
if high >= 1.0:
high = 0.99
b, a = signal.butter(order, [low, high], btype='band')
return signal.filtfilt(b, a, ecg)
def normalize(ecg):
"""标准化"""
return (ecg - np.mean(ecg)) / np.std(ecg)
def preprocess_ecg_signal(raw_signal, fs=500):
"""完整预处理流程"""
# 步骤1:去除 50Hz 工频干扰
cleaned = notch_filter(raw_signal, notch_freq=50, fs=fs)
# 步骤2:带通滤波(保留 0.5 - 150 Hz)
cleaned = bandpass_filter(cleaned, low=0.5, high=150, fs=fs, order=4)
# 步骤3:标准化
cleaned = normalize(cleaned)
return cleaned
# ======================
# 测试代码:导入并清洗信号
# ======================
if __name__ == '__main__':
try:
from Noisy_ECG_Signal import generate_noisy_ecg
except ImportError:
raise ImportError("请确保 'Noisy_ECG_Signal.py' 与本文件在同一目录下")
print(" 正在生成污染 ECG 信号...")
t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=True)
print(" 正在清洗信号...")
ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)
# 可视化:清洗前后对比(前 2 秒)
t_plot = t[:1000] # 前 2 秒 (500 * 2 = 1000)
clean_plot = ecg_clean[:1000]
noisy_plot = ecg_noisy[:1000]
cleaned_plot = ecg_cleaned[:1000]
plt.figure(figsize=(14, 8))
plt.subplot(3, 1, 1)
plt.plot(t_plot, clean_plot, color='blue', linewidth=1.2)
plt.title("1. Clean ECG Signal (Ground Truth)")
plt.ylabel("Amplitude (mV)")
plt.grid(True, alpha=0.3)
plt.subplot(3, 1, 2)
plt.plot(t_plot, noisy_plot, color='red', linewidth=1.2)
plt.title("2. Noisy ECG Signal (with Baseline, EMG, 50Hz)")
plt.ylabel("Amplitude (mV)")
plt.grid(True, alpha=0.3)
plt.subplot(3, 1, 3)
plt.plot(t_plot, cleaned_plot, color='green', linewidth=1.2)
plt.title("3. Cleaned ECG Signal (After Preprocessing)")
plt.xlabel("Time (s)")
plt.ylabel("Normalized Amplitude")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(" 测试完成!去噪效果如图所示。")
print(f" 原始 SNR 估计: {10*np.log10(np.var(ecg_clean)/np.var(ecg_noisy-ecg_clean)):.2f} dB")
print(f" 去噪后 SNR: {10*np.log10(np.var(ecg_clean)/np.var(ecg_cleaned-ecg_clean)):.2f} dB")
污染信号生成代码(Nosiy_ECG_Signal.py)
'''
生成受污染的ECG 信号
'''
# Nosiy_ECG_Signal.py
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
def generate_noisy_ecg(fs=500, duration=10, show_plot=True):
"""
生成受基线漂移、肌电干扰和工频噪声影响的 ECG 信号
参数:
fs: 采样频率 (Hz)
duration: 信号时长 (秒)
show_plot: 是否显示生成过程的可视化
返回:
t: 时间轴
ecg_clean: 干净 ECG 信号
ecg_noisy: 污染后的 ECG 信号
"""
# ======================
# 参数设置
# ======================
t = np.linspace(0, duration, int(fs * duration), endpoint=False) # 时间轴
# ======================
# 1. 生成干净 ECG 信号(简化模型:使用周期性波形模拟 P-QRS-T)
# ======================
def generate_ecg_clean(t, fs):
# 心率(bpm)
heart_rate = 75
heart_rate_rad = 2 * np.pi * heart_rate / 60
# R 波:周期性高斯脉冲(模拟 QRS 波群)
r_peaks = np.sin(heart_rate_rad * t) # 控制心跳节奏
qrs = 2.0 * np.exp(-1000 * (t % (60 / heart_rate) - 0.02)**2) # 高斯脉冲模拟 QRS
# T 波:稍宽的正向波
t_wave = 0.4 * np.exp(-200 * (t % (60 / heart_rate) - 0.15)**2)
# P 波:小的正向波
p_wave = 0.25 * np.exp(-400 * (t % (60 / heart_rate) - 0.0)**2)
# 组合 ECG
ecg_clean = p_wave + qrs + t_wave
# 添加轻微随机变化(心跳间期变异)
jitter = 0.01 * np.random.randn(len(t))
ecg_clean = np.interp(t, t + jitter, ecg_clean)
return ecg_clean
ecg_clean = generate_ecg_clean(t, fs)
# ======================
# 2. 添加噪声
# ======================
# (1) 基线漂移(Baseline Wander): 0.1 - 0.5 Hz 的低频正弦波组合
baseline_wander = (
0.3 * np.sin(2 * np.pi * 0.1 * t) +
0.2 * np.sin(2 * np.pi * 0.3 * t) +
0.1 * np.sin(2 * np.pi * 0.5 * t)
)
# (2) 肌电干扰(EMG-like noise): 高频随机噪声(30-200 Hz)
np.random.seed(42)
emg_noise = np.random.normal(0, 0.1, len(t))
# 用带通滤波器模拟肌电信号频带(30-200 Hz)
b_emg, a_emg = signal.butter(4, [30, 200], btype='bandpass', fs=fs)
emg_noise = signal.filtfilt(b_emg, a_emg, emg_noise)
emg_noise = 0.1 * emg_noise / np.max(np.abs(emg_noise)) # 归一化并控制幅度
# (3) 工频噪声(Power-line interference): 50 Hz(中国)或 60 Hz(美国)
power_freq = 50 # 可改为 60
power_noise = 0.15 * np.sin(2 * np.pi * power_freq * t)
# ======================
# 3. 合成污染信号
# ======================
ecg_noisy = ecg_clean + baseline_wander + emg_noise + power_noise
# ======================
# 4. 可视化(保持你原有的可视化不变)
# ======================
if show_plot:
plt.figure(figsize=(14, 8))
# 子图1:原始干净 ECG
plt.subplot(3, 1, 1)
plt.plot(t[:1000], ecg_clean[:1000], color='blue', linewidth=1.2)
plt.title("Clean ECG Signal")
plt.ylabel("Amplitude (mV)")
plt.grid(True, alpha=0.3)
# 子图2:添加的噪声
plt.subplot(3, 1, 2)
plt.plot(t[:1000], baseline_wander[:1000], label='Baseline Wander', color='orange')
plt.plot(t[:1000], emg_noise[:1000], label='EMG Noise', color='red', alpha=0.7)
plt.plot(t[:1000], power_noise[:1000], label='50 Hz Noise', color='purple', alpha=0.7)
plt.title("Added Noise Components")
plt.ylabel("Amplitude")
plt.legend()
plt.grid(True, alpha=0.3)
# 子图3:最终污染信号
plt.subplot(3, 1, 3)
plt.plot(t[:1000], ecg_noisy[:1000], color='red', linewidth=1.2)
plt.title("Noisy ECG Signal (with Baseline Wander, EMG, and 50 Hz Interference)")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude (mV)")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
return t, ecg_clean, ecg_noisy # 返回信号,供其他模块使用
# ======================
# 如果直接运行此文件,则生成并显示信号
# ======================
if __name__ == "__main__":
print(" 正在生成污染 ECG 信号...")
t, clean, noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=True)
print(" 信号生成完成,长度:", len(t))
污染信号生成可视化图 ,如下:
cleaned信号处理对比图 ,如下:
7.2 模型推理(PyTorch )
模型推理是分类系统的核心环节。本模块封装了模型加载与预测逻辑,支持单个心跳或批量信号输入。
- 使用
torch.load()
加载训练好的.pth
权重文件; - 通过
model.eval()
和torch.no_grad()
关闭梯度计算,提升推理效率; - 输出包含:预测类别、置信度(最大概率值)、各类别概率分布,便于后续分析与决策。
'''
模型推理
'''
# inference.py
import torch
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
from scipy import signal
from model_self import ECGCNN
from clean_ECG import preprocess_ecg_signal
# 类别名称
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']
# ======================
# 1. 加载模型
# ======================
def load_model(model_path="./weight/model_self.pth"):
"""
加载训练好的 PyTorch 模型
"""
model = ECGCNN(num_classes=5)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()
print(" PyTorch 模型已加载")
return model
def load_onnx_model(onnx_model_path="./weight/ecg_model_self.onnx"):
"""
加载 ONNX 模型
"""
session = ort.InferenceSession(onnx_model_path)
print(" ONNX 模型已加载")
return session
# ======================
# 2. 检测 R 波(用于分割心跳)
# ======================
def detect_r_peaks(ecg, fs=500):
"""
使用简单阈值法检测 R 波(适用于干净或轻度污染信号)
返回 R 波位置索引
"""
# 使用带通滤波增强 QRS
b, a = signal.butter(2, [5, 15], btype='bandpass', fs=fs)
filtered = signal.filtfilt(b, a, ecg)
# 简单平方 + 滑动窗能量
squared = filtered ** 2
window_size = int(0.1 * fs) # 100ms 滑动窗
smoothed = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
# 阈值检测
threshold = 0.5 * np.max(smoothed)
r_peaks = signal.find_peaks(smoothed, height=threshold, distance=int(0.6 * fs))[0] # 最小间距 600ms
return r_peaks
# ======================
# 3. 提取心跳片段(长度 187)
# ======================
def extract_beats(ecg, r_peaks, fs=500, beat_length=187):
"""
以 R 峰为中心,前后截取心跳片段
beat_length: 模型输入长度(如 187)
"""
half_len = beat_length // 2
beats = []
valid_positions = []
for r in r_peaks:
start = r - half_len
end = r + (beat_length - half_len)
if start >= 0 and end <= len(ecg):
beat = ecg[start:end]
if len(beat) == beat_length:
beats.append(beat)
valid_positions.append(r)
return np.array(beats), np.array(valid_positions)
# ======================
# 4. Softmax 函数
# ======================
def softmax(x, axis=-1):
"""
Numerically stable softmax
"""
x = x - np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
# ======================
# 5. PyTorch 批量推理
# ======================
def predict_heartbeat(model, ecg_signal):
"""
ecg_signal: 已经去噪的 ECG 片段 (187,) 或 (N, 187)
返回: 预测类别, 置信度, 概率分布
"""
model.eval()
with torch.no_grad():
if ecg_signal.ndim == 1:
# 单个心跳
tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
else:
# 批量心跳 (N, 187) -> (N, 1, 187)
tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(1)
output = model(tensor) # logits: (N, 5)
prob = torch.softmax(output, dim=1).numpy() # 转为 numpy 概率
confidence = np.max(prob, axis=1)
predicted = np.argmax(prob, axis=1)
return predicted, confidence, prob
# ======================
# 6. ONNX 批量推理
# ======================
def predict_heartbeat_onnx(session, ecg_signal):
"""
使用 ONNX 模型预测单个或批量心跳
"""
if ecg_signal.ndim == 1:
input_data = ecg_signal.reshape(1, 1, -1).astype(np.float32) # (1, 1, 187)
else:
input_data = ecg_signal.reshape(-1, 1, 187).astype(np.float32) # (N, 1, 187)
result = session.run(["logits"], {"ecg_input": input_data})
logits = result[0] # (N, 5)
prob = softmax(logits, axis=1)
confidence = np.max(prob, axis=1)
predicted = np.argmax(prob, axis=1)
return predicted, confidence, prob
# ======================
# 7. 可视化函数
# ======================
def plot_ecg_comparison(t, ecg_noisy, ecg_cleaned, r_peaks, beat_positions, predictions, confidences, class_names):
"""
绘制污染信号、干净信号、R波位置、预测结果
"""
fig, axes = plt.subplots(3, 1, figsize=(16, 10), sharex=True)
# 子图 1: 原始污染信号
axes[0].plot(t, ecg_noisy, color='lightcoral', linewidth=0.8)
axes[0].set_title("Noisy ECG Signal", fontsize=14, fontweight='bold')
axes[0].set_ylabel("Amplitude")
axes[0].grid(True, alpha=0.3)
# 子图 2: 去噪后信号
axes[1].plot(t, ecg_cleaned, color='steelblue', linewidth=1.0)
# 标出 R 波位置
for r in r_peaks:
axes[1].axvline(t[r], color='red', linestyle='--', alpha=0.7)
axes[1].set_title("Denoised ECG Signal with R-Peak Detection", fontsize=14, fontweight='bold')
axes[1].set_ylabel("Amplitude")
axes[1].grid(True, alpha=0.3)
# 子图 3: 预测结果标注
axes[2].plot(t, ecg_cleaned, color='gray', linewidth=0.8, alpha=0.8)
# 颜色映射(每类不同颜色)
colors = ['green', 'orange', 'red', 'purple', 'gray']
for i, (pos, pred, conf) in enumerate(zip(beat_positions, predictions, confidences)):
if pos < len(t):
x = t[pos]
y = ecg_cleaned[pos]
class_name = class_names[pred]
color = colors[pred]
axes[2].axvline(x, color=color, alpha=0.6)
axes[2].text(x, max(ecg_cleaned)*0.9, f'{class_name}\n{conf:.2f}',
color=color, fontsize=8, ha='center', rotation=90,
bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.2))
axes[2].set_title("Predicted Heartbeat Types (Color-coded)", fontsize=14, fontweight='bold')
axes[2].set_ylabel("Amplitude")
axes[2].set_xlabel("Time (s)")
axes[2].grid(True, alpha=0.3)
# 图例说明
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))]
axes[2].legend(handles=legend_elements, bbox_to_anchor=(1.02, 1), loc='upper left')
plt.tight_layout()
plt.show()
def plot_prediction_confidence(predictions, confidences, class_names):
"""
绘制预测类别和置信度柱状图
"""
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(predictions))
colors = ['green', 'orange', 'red', 'purple', 'gray']
bar_colors = [colors[pred] for pred in predictions]
bars = ax.bar(x, confidences, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=0.5)
ax.set_xlabel("Beat Index")
ax.set_ylabel("Confidence")
ax.set_title("Prediction Confidence per Heartbeat", fontsize=14, fontweight='bold')
ax.set_ylim(0, 1.1)
ax.grid(True, axis='y', alpha=0.3)
# 在柱子上方标注类别
for i, (bar, pred) in enumerate(zip(bars, predictions)):
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
class_names[pred], ha='center', va='bottom', fontsize=9, rotation=45)
# 图例
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))]
ax.legend(handles=legend_elements, title="Classes")
plt.tight_layout()
plt.show()
# ======================
# 主测试流程
# ======================
if __name__ == "__main__":
# --- 1. 加载模型 ---
model = load_model("./weight/model_self.pth")
ort_session = load_onnx_model("./weight/ecg_model_self.onnx")
# --- 2. 生成并清洗 ECG 信号 ---
from Noisy_ECG_Signal import generate_noisy_ecg
t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=False)
# 只清洗一次
ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)
print(f"ECG 信号长度: {len(ecg_cleaned)}")
# --- 3. 检测 R 波 ---
r_peaks = detect_r_peaks(ecg_cleaned, fs=500)
print(f"检测到 {len(r_peaks)} 个 R 波")
# --- 4. 提取心跳 ---
beats, beat_positions = extract_beats(ecg_cleaned, r_peaks, fs=500, beat_length=187)
print(f"成功提取 {len(beats)} 个心跳片段")
if len(beats) == 0:
print(" 未提取到有效心跳片段")
else:
# --- 5. 批量推理 ---
pred_torch, conf_torch, prob_torch = predict_heartbeat(model, beats)
pred_onnx, conf_onnx, prob_onnx = predict_heartbeat_onnx(ort_session, beats)
# --- 6. 打印结果 ---
print("\n" + "=" * 50)
print(" 心跳分类结果对比")
print("=" * 50)
for i in range(min(10, len(beats))):
match = "✅" if pred_torch[i] == pred_onnx[i] else "❌"
print(f"心跳 {i+1:2d}: "
f"[PyTorch] {CLASS_NAMES[pred_torch[i]]} ({conf_torch[i]:.3f}) | "
f"[ONNX] {CLASS_NAMES[pred_onnx[i]]} ({conf_onnx[i]:.3f}) {match}")
# --- 7. 统计一致性 ---
accuracy = np.mean(pred_torch == pred_onnx)
print(f"\n ONNX 与 PyTorch 预测一致率: {accuracy * 100:.1f}%")
# --- 8. 可视化 ---
print("\n 正在生成可视化图表...")
# 时间轴
t = np.linspace(0, len(ecg_cleaned)/500, len(ecg_cleaned)) # fs=500
# 绘制信号对比和预测结果
plot_ecg_comparison(
t=t,
ecg_noisy=ecg_noisy,
ecg_cleaned=ecg_cleaned,
r_peaks=r_peaks,
beat_positions=beat_positions,
predictions=pred_torch, # 使用 PyTorch 预测结果
confidences=conf_torch,
class_names=CLASS_NAMES
)
# 绘制置信度图
plot_prediction_confidence(pred_torch, conf_torch, CLASS_NAMES)
7.3 结果可视化(画波形 + 打标签)
为了直观展示模型在真实 ECG 波形上的分类效果,开发了可视化模块。该模块将原始信号与预测结果融合呈现:
- 在原始 ECG 曲线上标注每个 R 峰对应的心跳类型;
- 使用颜色编码(绿色/红色)表示置信度高低(可设置阈值);
- 添加文本框显示类别名称与置信分数,提升可读性;
- 支持自定义采样率、标签位置、字体大小等参数。
可视化结果不仅有助于模型调试与错误分析,也为医生提供直观的辅助判读工具,增强人机协同效率。
"""
结果可视化
"""
# visualize.py
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Optional
from inference import load_model, predict_heartbeat # 导入你需要的函数
# 使用全局类别名(也可以传参)
from inference import CLASS_NAMES
def plot_ecg_with_labels(
ecg_signal: np.ndarray,
r_peaks: List[int],
model,
class_names: List[str] = None,
segment_length: int = 187,
sample_rate: int = 360,
confidence_threshold: float = 0.6,
title: str = "ECG 信号与心跳分类结果",
figsize: tuple = (14, 6)
):
"""
在原始 ECG 信号上绘制波形,并为每个心跳打上预测标签
参数:
ecg_signal: 完整 ECG 信号 (T,)
r_peaks: R 峰位置列表(索引)
model: 已加载的 PyTorch 模型对象
class_names: 类别名称列表(默认使用 inference 中的 CLASS_NAMES)
segment_length: 每个心跳输入长度(默认 187)
sample_rate: 采样率(Hz)
confidence_threshold: 置信度阈值(高于绿色,低于红色)
title: 图表标题
figsize: 图像大小
"""
if len(ecg_signal.shape) != 1:
raise ValueError("ecg_signal 必须是一维信号")
half_len = segment_length // 2
class_names = class_names or CLASS_NAMES # 默认使用全局类别
predictions = []
plt.figure(figsize=figsize)
t = np.arange(len(ecg_signal)) / sample_rate
plt.plot(t, ecg_signal, 'k', linewidth=0.8, label='ECG Signal')
# 对每个 R 峰进行预测并标注
for i, peak in enumerate(r_peaks):
left = peak - half_len
right = peak + half_len
if left < 0 or right >= len(ecg_signal):
predictions.append(None)
continue
# 提取单个心跳
heartbeat = ecg_signal[left:right]
try:
# 使用 inference.py 中的 predict_heartbeat 函数
pred_ids, confs, probs = predict_heartbeat(model, heartbeat)
pred_label = pred_ids[0] # 返回是 (1,) 数组
confidence = confs[0]
pred_name = class_names[pred_label]
except Exception as e:
print(f"第 {i} 个心跳预测失败: {e}")
pred_name = "Error"
confidence = 0.0
predictions.append((pred_name, confidence))
# 在 R 峰上方添加文本标签
x_time = peak / sample_rate
y_height = ecg_signal[peak] + (np.max(ecg_signal) - np.min(ecg_signal)) * 0.05
color = 'green' if confidence > confidence_threshold else 'red'
plt.text(
x_time, y_height, f"{pred_name}\n({confidence:.2f})",
fontsize=9, ha='center', va='bottom',
color=color, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7)
)
plt.title(title, fontsize=14)
plt.xlabel("时间 (秒)", fontsize=12)
plt.ylabel("幅度 (mV)", fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
return predictions
# ========================
# 示例使用(仅用于测试)
# ========================
if __name__ == "__main__":
import numpy as np
print(" 正在测试 ECG 可视化模块...")
# 模拟一段较长的 ECG 信号
T = 3000
sample_rate = 360
ecg_signal = np.zeros(T)
r_peaks = [187, 540, 910, 1280, 1650, 2020, 2400, 2780] # R 峰位置
# 生成模拟心跳(高斯脉冲)
for peak in r_peaks:
if peak < T:
start = max(0, peak - 20)
end = min(T, peak + 40)
x = np.arange(start, end)
ecg_signal[x] += np.exp(-0.05 * (x - peak)**2) * 1.8
ecg_signal += np.random.normal(0, 0.05, T) # 加噪声
# 加载模型
try:
model = load_model("./weight/model_self.pth")
print(" 模型加载成功")
except Exception as e:
print(f" 模型加载失败,请检查路径: {e}")
exit()
# 执行可视化
results = plot_ecg_with_labels(
ecg_signal=ecg_signal,
r_peaks=r_peaks,
model=model,
class_names=CLASS_NAMES,
sample_rate=sample_rate,
confidence_threshold=0.6,
title="ECG 心跳分类结果可视化"
)
# 打印结果
print("\n 分类结果:")
for i, res in enumerate(results):
if res:
print(f"心跳 {i+1}: {res[0]} (置信度: {res[1]:.3f})")
else:
print(f"心跳 {i+1}: 越界或分析失败")
8. 模型移植
为了提升模型的跨平台部署能力,降低对 PyTorch 框架的依赖,本项目采用 ONNX(Open Neural Network Exchange) 格式进行模型导出与推理验证。ONNX 支持多种运行时(如 ONNX Runtime、TensorRT、CoreML),可在服务器、移动端、边缘设备上高效运行,极大增强了模型的工程落地潜力。
8.1 导出ONNX
使用 torch.onnx.export
工具将训练好的 PyTorch 模型转换为 .onnx
文件。关键配置如下:
dummy_input
:提供示例输入张量,用于追踪计算图;opset_version=11
:确保支持常用算子(如 Conv1d、BatchNorm);dynamic_axes
:允许动态 batch size,适应不同输入规模;do_constant_folding=True
:优化常量节点,减小模型体积并提升推理速度。
导出后可通过 Netron 等工具查看模型结构,确认节点连接正确。
'''
模型导出onnx
'''
# export_onnx.py
import torch
from model_self import ECGCNN # 确保路径正确
def main():
# 1. 定义模型结构(必须和训练时一致)
model = ECGCNN(num_classes=5)
# 2. 加载你训练好的权重
model.load_state_dict(torch.load("./weight/model_self.pth", map_location='cpu'))
model.eval() # 切换到推理模式
# 3. 构造一个示例输入(shape: batch x channel x length)
dummy_input = torch.randn(1, 1, 187) # 和ECG 数据 shape 一致
# 4. 导出为 ONNX (通过输入构造一个示例输入 进行追踪推理)
torch.onnx.export(
model,
dummy_input,
"./weight/ecg_model_self.onnx", # 输出路径
export_params=True, # 保存模型参数
opset_version=11, # ONNX 版本兼容性
do_constant_folding=True, # 优化
input_names=['ecg_input'], # 输入名
output_names=['logits'], # 输出名
dynamic_axes={
'ecg_input': {0: 'batch_size'},
'logits': {0: 'batch_size'}
} # 支持动态 batch
)
print(" 成功导出 ONNX 模型到:./weight/ecg_model_self.onnx")
if __name__ == "__main__":
main()
8.2 使用ONNX推理
利用 ONNX Runtime 加载 .onnx
模型并执行推理,验证其输出是否与原始 PyTorch 模型一致。
onnxruntime.InferenceSession
提供跨平台高性能推理引擎;- 输入需按
input_names
指定的名称传入(如'ecg_input'
); - 输出返回概率分布,取最大值作为预测结果。
经测试,ONNX 模型与 PyTorch 模型的预测结果完全一致(误差 < 1e-6),说明转换成功。同时,ONNX Runtime 在 CPU 上的推理速度更快,更适合部署在资源受限设备上。
import torch
import numpy as np
import onnxruntime as ort
from model_self import ECGCNN
from preprocess import preprocess_ecg_signal
# 类别名称
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown'
# ONNX 推理函数
def predict_heartbeat_onnx(session, ecg_signal):
"""
使用 ONNX 模型预测单个或批量心跳
"""
if ecg_signal.ndim == 1:
input_data = ecg_signal.reshape(1, 1, -1).astype(np.float32) # (1, 1, 187)
else:
input_data = ecg_signal.reshape(-1, 1, 187).astype(np.float32) # (N, 1, 187)
result = session.run(["logits"], {"ecg_input": input_data})
logits = result[0] # (N, 5)
prob = softmax(logits, axis=1)
confidence = np.max(prob, axis=1)
predicted = np.argmax(prob, axis=1)
return predicted, confidence, prob
if __name__ == "__main__":
ort_session = load_onnx_model("./weight/ecg_model_self.onnx")
# 生成并清洗 ECG 信号
from Noisy_ECG_Signal import generate_noisy_ecg
t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=False)
# 只清洗一次
ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)
print(f"ECG 信号长度: {len(ecg_cleaned)}")
# 检测 R 波
r_peaks = detect_r_peaks(ecg_cleaned, fs=500)
print(f"检测到 {len(r_peaks)} 个 R 波")
# 提取心跳
beats, beat_positions = extract_beats(ecg_cleaned, r_peaks, fs=500, beat_length=187)
print(f"成功提取 {len(beats)} 个心跳片段")
# 使用 ONNX 模型推理
pred_onnx, conf_onnx, prob_onnx = predict_heartbeat_onnx(ort_session, beats)
print(f"\n【ONNX 推理】")
print(f"预测类别: {CLASS_NAMES[pred_onnx]}")
print(f"置信度: {conf_onnx:.3f}")
9. 项目总结
9.1 问题及解决办法
在进行基于CNN实现心律失常(ECG)的小颗粒度分类时,可能会遇到以下问题及解决办法:
问题 | 解决方案 |
---|---|
类别严重不平衡(N类占80%) | 使用 SMOTE 过采样平衡训练集 |
模型过拟合 | 添加 Dropout、BatchNorm |
输入维度不匹配 | reshape 为 (N, 187, 1) 适配 1D-CNN |
推理时预处理不一致 | 保存 StandardScaler 并在推理时复用 |
9.2 收获
- 掌握了 1D-CNN 在时间序列分类中的应用
- 学会了使用 SMOTE 解决类别不平衡问题
- 实践了从训练到部署的完整流程(PyTorch → ONNX)
- 提升了模型可视化与可解释性能力