卷积神经网络项目:基于CNN实现心律失常(ECG)的小颗粒度分类系统

发布于:2025-09-02 ⋅ 阅读:(21) ⋅ 点赞:(0)

卷积神经网络项目实现文档

1、项目简介

1.1 项目名称

​ 基于CNN实现心律失常(ECG)的小颗粒度分类系统

1.2 项目简介

​ 心律失常是临床上常见且潜在致命的心血管疾病之一,包括房性早搏(PAC)、室性早搏(PVC)、心动过速等多种类型。传统的心电图(ECG)分析依赖医生人工判读,耗时长、主观性强,尤其在面对长时间动态心电监测(如 24 小时 Holter)数据时,极易出现漏诊或误诊。

​ 本项目旨在利用卷积神经网络(CNN)对MIT-BIH心律失常数据库中的ECG信号进行细粒度分类,识别五种常见的心律失常类型:正常心跳(N)、室上性早搏(S)、室性早搏(V)、融合波(F)和未知心跳(Q)。由于不同类别的ECG信号在形态上差异细微,且存在严重的类别不平衡问题,传统的机器学习方法难以取得理想效果。

​ 本项目采用深度学习中的CNN模型,充分利用卷积层对局部时序特征的提取能力。项目涵盖数据预处理、模型构建、训练优化、性能评估及模型部署全流程,并探索数据重采样、标准化、迁移学习等关键技术手段,最终实现高精度、可部署的心律失常自动识别系统。该系统可应用于:

  • 院内心电监护报警系统
  • 远程健康监测平台
  • 可穿戴设备(如智能手表)的异常心律预警
  • 医学教学与训练辅助工具

1.3 技术选择

为什么选择1D-CNN?

本项目采用 一维卷积神经网络(1D-CNN) 作为核心模型,主要基于以下几点考虑:

优势 说明
保留时序结构 ECG 信号本质上是一维时间序列,1D-CNN 能直接在原始信号上进行卷积操作,保留完整的时序信息,避免特征工程带来的信息损失。
自动特征提取 CNN 能自动学习 QRS 波群、P 波、T 波等关键形态特征,无需手动设计特征(如 RR 间期、波幅等),提升模型泛化能力。
局部感知能力 卷积核具有局部感受野,能有效捕捉 ECG 中局部波形变化(如 R 波突起、ST 段抬高),对异常心跳(如 PVC 的宽大畸形 QRS)敏感。
参数效率高 相比 RNN/LSTM,1D-CNN 训练更快、更稳定,适合部署在边缘设备或实时系统中。
成功先例 在 MIT-BIH 心律失常数据库上的多项研究(如 Kiranyaz et al., 2016; Acharya et al., 2017)已验证 1D-CNN 在 ECG 分类任务中的优越性能。

为什么不选 RNN 或 Transformer?

虽然 RNN 能建模长期依赖,但 ECG 心跳分类主要依赖局部波形特征而非长序列依赖。RNN 训练慢、易梯度消失;Transformer 在短序列上无明显优势且计算开销大。因此,1D-CNN 是精度与效率的最优平衡

2、数据

2.1 公开数据集

本项目使用国际公认的标准心律失常数据库:MIT-BIH Arrhythmia Database,该数据集由美国麻省理工学院(MIT)与贝斯以色列医院(Beth Israel Hospital)联合发布,是心电图自动分析领域最广泛使用的基准数据集之一。

名称:MIT-BIH Arrhythmia Database
来源Kaggle - MIT-BIH Arrhythmia Database
内容
mitbih_train.csv:训练集,共 109,446 条样本
mitbih_test.csv:测试集,共 21,892 条样本
格式说明
每行表示一个心跳周期的 ECG 信号,共 187 个时间点
最后一列为类别标签(0~4),对应五种心律类型

标签 类别 描述
0 N 正常心跳(Normal Beat)
1 S 室上性早搏(Supraventricular Premature)
2 V 室性早搏(Ventricular Premature)
3 F 融合波(Fusion Beat)
4 Q 未知心跳(Unclassifiable Beat)

可视ECG信号

在这里插入图片描述

2.2 数据分析与清洗

类别分布分析:训练集中类别严重不平衡,N类占比超过80%,V类仅占约5%。
处理方式

  • 对训练集使用 SMOTE 过采样,平衡各类别样本数量
  • 测试集保持原始分布,用于真实性能评估
  • 移除异常值(如全零信号)

2.3 数据预处理

标准化:使用 StandardScaler 对每个信号进行标准化
维度重塑:将 (N, 187) 转为 (N, 187, 1),适配1D-CNN输入

2.4 数据分割

训练集:mitbih_train.csv → 用于模型训练
测试集:mitbih_test.csv → 用于最终性能评估
验证集:从训练集中划分20%用于调参

数据处理(清洗+预处理+分割)

'''
数据预处理
'''
# data_preprocess.py

import joblib
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
import pandas as pd
import os
import numpy as np


# 加载 CSV 文件  df.shape = (109446, 188)   109446个样本,187个特征 + 1个标签
df = pd.read_csv('./data/archive/mitbih_train.csv', header=None)

# 数据预处理
# 分离data和label
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values

# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 重采样  解决样本不均衡问题
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_scaled, y)

# 划分训练集和测试集
X_train, X_val, y_train, y_val = train_test_split(
    X_resampled, 
    y_resampled, 
    test_size=0.2, 
    random_state=42,
    # 确保训练集和测试集的标签分布一致
    stratify=y_resampled
)

# 修改数据维度  (样本数, 时间步, 特征数)  X_train.shape = (87371, 187, 1)  
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)

print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)  # (87371, 187, 1) (21927, 187, 1) (87371,) (21927,)

# 保存预处理结果
# 创建目录    exist_ok=True 表示如果目录已经存在,则不报错
os.makedirs('./data/processed_data', exist_ok=True)

# 保存处理后的 numpy 数组
np.save('./data/processed_data/X_train.npy', X_train)
np.save('./data/processed_data/X_val.npy', X_val)
np.save('./data/processed_data/y_train.npy', y_train)
np.save('./data/processed_data/y_val.npy', y_val)

# 保存 StandardScaler  (重要!推理时要用)
joblib.dump(scaler, './data/processed_data/scaler.pkl')

# 保存 SMOTE 对象 (用于分析)
joblib.dump(smote, './data/processed_data/smote.pkl')

print('数据处理完成!')

3. 神经网络

为实现对 ECG 心跳信号的自动分类,本项目设计并实现了一个轻量级的一维卷积神经网络(1D-CNN),命名为 ECGCNN。该模型专为处理长度为 187 的单导联心电信号设计,能够在保持较高精度的同时满足实时性要求。

3.1 模型架构设计

模型整体结构由 3 个卷积块 + 2 个全连接层组成,采用“卷积提取特征 → 展平 → 分类”的经典流程。每一层的设计均针对 ECG 信号特点进行优化:

  • 输入格式适配:支持 (N, 187)(N, 187, 1) 格式的输入,自动转换为 PyTorch 所需的 (N, C, L) 格式(即 (batch_size, channels, sequence_length))。
  • 多层卷积提取局部特征:使用 1D 卷积核捕捉 QRS 波群、ST 段等关键形态特征。
  • BatchNorm + ReLU + Dropout:提升训练稳定性、加速收敛,并防止过拟合。
  • 动态全连接层尺寸计算:通过虚拟输入自动推导展平后的维度,增强模型灵活性。

3.2 自定义网络实现(model_self.py)

以下是基于 PyTorch 实现的完整模型定义:

'''
模型构建:基于 PyTorch 的 1D-CNN 模型
'''
# model_self.py
import torch 
import torch.nn as nn

class ECGCNN(nn.Module):
    """
    基于 PyTorch 的 1D-CNN 模型,用于 ECG 心跳分类
    """
    def __init__(self, input_shape = (187,1),num_classes=5, dropout_rate=0.5):
        super(ECGCNN, self).__init__()
        # 注意:PyTorch 的 Conv1D 输入是 (N, C, L)
        input_channels = input_shape[1] if len(input_shape) == 2 else 1  # 默认 1 通道

        # 卷积层
        self.conv1 =nn.Sequential(
            nn.Conv1d(input_channels, 32, kernel_size=5, stride=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, stride=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.MaxPool1d(kernel_size=2)
        )

        # nn.Flatten()  的作用:将输入的维度进行展平,方便全连接层处理
        self.flatten = nn.Flatten()

        # 计算展平 flatten 之后的维度
        dummy_input = torch.randn(1, input_channels,187)  # 创建一个虚拟输入张量
        with torch.no_grad():  # 禁用梯度计算,以节省内存和计算资源
            features = self._forward_features(dummy_input)  # 通过前向传播计算展平后的维度
            flatten_dim = features.view(1, -1).shape[1]  # 计算展平后的维度

        self.fc1 = nn.Sequential(
            nn.Linear(flatten_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, num_classes),
            # nn.Softmax(dim=1)  # dim=1 表示对列进行归一化
        )
    
    def _forward_features(self, x):
        '''
        用于计算全连接层输入维度的辅助函数
        '''
        x = self.conv1(x)

        x = self.conv2(x)

        x = self.conv3(x)

        return x
    
    def forward(self, x):
        """
        前向传播(PyTorch 中叫 forward,不是 call)
        :param x: 输入张量 (N, C, L)
        :return: 输出张量 (N, num_classes)
        """
        # PyTorch 输入格式: (batch_size, channels, sequence_length)
        # 所以需要把 (N, 187, 1) 转成 (N, 1, 187)
        if x.dim() == 3 and x.shape[2] == 1:  #  如果是(N, 187, 1) 则转成(N, 1, 187)
            x = x.permute(0, 2, 1)  # (N, 187, 1) -> (N, 1, 187)
        elif x.dim() == 2:  # 如果是(N, 187) 则转成(N, 1, 187)
            x = x.unsqueeze(1)  # (N, 187) -> (N, 1, 187)

        # 卷积块
        x = self.conv1(x)

        x = self.conv2(x)

        x = self.conv3(x)

        # 全连接层
        x = self.flatten(x)   #展平

        x = self.fc1(x)

        x = self.fc2(x)

        return x
    

if __name__ == '__main__':
    model = ECGCNN(num_classes=5)
    print(model)

    # 测试前向传播
    x = torch.randn(32, 187, 1)  # 32个样本,187个特征,1个通道
    
    with torch.no_grad():
        output = model(x)
    
    print("输出x的形状:",x.shape)
    print("输出形状:", output.shape)  # (32, 5)

4. 模型训练

4.1 训练参数

为确保模型有效学习且不过拟合,采用以下训练配置:

参数 说明
轮次(Epochs) 10 在验证集性能趋于稳定后停止,避免过拟合
批次大小(Batch Size) 32 平衡梯度稳定性与内存占用
学习率(Learning Rate) 0.0001 使用 Adam 优化器时的常用小学习率,保证收敛稳定
设备 CPU 当前运行环境为 CPU,未来可扩展支持 GPU 加速

4.2 损失函数

采用 交叉熵损失函数(Cross-Entropy Loss)

criterion = nn.CrossEntropyLoss()

该函数结合了 LogSoftmaxNLLLoss,适用于多分类任务。它衡量模型输出概率分布与真实标签之间的差异,是分类任务中最常用的损失函数之一。

4.3 优化器

使用 Adam 优化器

optimizer = optim.Adam(model.parameters(), lr=0.0001)

Adam 结合了动量(Momentum)和自适应学习率(RMSProp)的优点,具有收敛快、鲁棒性强的特点,特别适合深度神经网络的训练。

4.4 训练过程可视化

为监控训练动态,使用 TensorBoard 进行可视化,记录每轮的训练/验证损失与准确率。

训练过程指标走势图

训练准确率和损失走势图 (如下)在这里插入图片描述

验证准确率和损失走势图(如下)在这里插入图片描述

从图中可以看出:

  • 训练损失持续下降,训练准确率稳步上升,表明模型正在有效学习;
  • 验证损失先降后趋于平稳,未出现明显回升,说明模型未严重过拟合;
  • 最终验证准确率可达 90% 以上(具体数值依运行结果而定),表现出良好的分类能力。

网络结构图

在这里插入图片描述

该图为 TensorBoard 自动生成的计算图,清晰展示了数据流动路径和各层连接关系。

训练脚本(train.py)

以下是完整的训练流程实现,包含数据加载、模型定义、训练循环、验证评估与结果保存:

'''
模型训练
'''
# train.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
from model_self import ECGCNN
from torch.utils.tensorboard import SummaryWriter

# tensorboard 可视化操作
writer = SummaryWriter()

# 加载预处理数据
X_train = np.load('./data/processed_data/X_train.npy')
X_val = np.load('./data/processed_data/X_val.npy')
y_train = np.load('./data/processed_data/y_train.npy')
y_val = np.load('./data/processed_data/y_val.npy')

# 转为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_val = torch.tensor(y_val, dtype=torch.long)

# print("Loaded data",X_train.shape,y_train.shape)

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 设备选择val
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 创建模型 、损失函数、优化器
model = ECGCNN(num_classes=5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

model.to(device)
criterion.to(device)

# 可视化网络结构
with torch.no_grad():
    #  X_train 是 (N, 187, 1)
    sample_input = torch.zeros(1, X_train.shape[1], X_train.shape[2]).to(device)  # (1, 187, 1)
    writer.add_graph(model, sample_input)

# 记录训练过程中准确率最高的准确率
best_acc = 0.0

# 定义history字典,用于保存训练过程中损失和准确率
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

for epoch in range(10):
    # 训练模式
    model.train()
    train_loss = 0
    train_acc = 0
    train_total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        # 输入数据
        inputs = inputs.to(device)  # 将数据加载到设备
        # 标签
        labels = labels.to(device)  # 将标签加载到设备
        #  梯度清零
        optimizer.zero_grad()
        #  获取预测结果
        outputs = model(inputs)
        # 预测准确率
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_acc += predicted.eq(labels).sum().item()
        # 计算损失
        loss = criterion(outputs, labels)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 累计损失
        train_loss += loss.item()
        # if (i+1) % 10 == 0:
        #     print(f" Step [{i+1:>3}/{len(train_loader)}] | Batch Loss: {loss.item():.4f}")
    # print(f"Epoch [{epoch+1}/{10}], Loss: {train_loss/len(train_loader):.4f}")
    
    #  计算平均准确率
    avg_train_acc = 100.0 * train_acc / train_total
    # 平均训练损失
    avg_train_loss = train_loss / len(train_loader)
    # 记录训练损失
    history['train_loss'].append(avg_train_loss)

    # 验证模式  (使用 X_test, y_test)
    model.eval()
    val_loss = 0.0
    correct = 0   # 验证集准确率
    total = 0  # 验证集样本总数
    
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(val_loader):
            #  输入数据
            inputs = inputs.to(device)  # 将数据移动到设备
            #  标签
            labels = labels.to(device)  # 将标签移动到设备
            # 预测
            outputs = model(inputs)
            # 损失
            loss = criterion(outputs, labels)
            # 验证集损失 累加
            val_loss += loss.item()
            # 得到预测概率最高的类别
            _, predicted = outputs.max(1)
            #   labels.size(0)  获取当前batch标签的样本数量
            total += labels.size(0)
            # 预测正确的数量 .eq()  判断两个张量是否相等,返回一个布尔张量 .sum()  把 True/False 转成 1/0,然后求和
            #  .item()  把 PyTorch 的标量张量(scalar tensor)转成 Python 数字
            correct += predicted.eq(labels).sum().item()
    # 平均验证损失
    avg_val_loss = val_loss / len(val_loader)
    # 验证集准确率
    val_acc = 100.0 * correct / total
    # 记录验证损失和准确率
    history['val_loss'].append(avg_val_loss)
    history['val_acc'].append(val_acc)

    # 打印结果
    print(f"\n   Epoch [{epoch+1:>2}/10] ")
    print(f"    🟢 Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.2f}%")
    print(f"    🔴 Val Loss:   {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    
    # tensorboard --logdir=runs   runs为保存路径(替换为绝对路径)  集成终端打开
    writer.add_scalar("Train/Loss", avg_train_loss, epoch+1)
    writer.add_scalar("Train/Acc", avg_train_acc, epoch+1)
    writer.add_scalar("Val/Loss", avg_val_loss, epoch+1)
    writer.add_scalar("Val/Acc", val_acc, epoch+1)

    # 保存模型
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "./weight/model_self.pth")
        print("Saved best model!")

# 训练结束后,也可以打印最终的最佳准确率
print(f" Training finished. Best validation accuracy: {best_acc:.2f}%")

# 训练结束后关闭TensorBoard writer
writer.close()


# 保存history
np.save('./data/history_self.npy', history)
print("History saved!")

训练结果总结

  • 最佳验证准确率:通常可达 90%~95%(受数据划分影响略有波动)
  • 模型保存路径./weight/model_self.pth
  • 训练历史保存./data/history_self.npy,可用于绘制学习曲线

5. 模型验证

为了全面评估所构建模型的性能,本节从量化指标、分类报表、混淆矩阵三个维度对模型在验证集和测试集上的表现进行系统分析。重点考察模型的准确率、召回率、F1 分数以及各类别之间的混淆情况,从而判断其鲁棒性与泛化能力。

5.1 验证过程数据化

在模型训练完成后,需对预测结果进行结构化保存,以便后续分析与部署。本项目将验证集上的原始预测结果(包括输入信号、真实标签、预测概率、预测类别等)导出为 CSV 文件,形成结构化的验证数据集。

该文件可用于:

  • 追踪错误样本(如误判的 V 类心跳)
  • 分析模型置信度分布
  • 支持临床医生复核
  • 构建自动化评估流水线

最终生成的 Excel/CSV 文件示例如下:

在这里插入图片描述

5.2 指标报表

使用 sklearn.metrics.classification_report 生成详细的分类性能报表,包含每个类别的精确率(Precision)、召回率(Recall)、F1 分数(F1-Score)和支持样本数(Support)。

  • Precision(精确率):预测为某类的样本中,真正属于该类的比例 → 关注“预测是否可靠”
  • Recall(召回率):真实为某类的样本中,被正确识别的比例 → 关注“是否漏检”
  • F1-Score:Precision 与 Recall 的调和平均,综合反映类别识别能力
  • Support:该类在数据集中出现的次数

验证数据得到的报表

在这里插入图片描述

从图中可见:

  • N 类(正常心跳):由于样本占比高,Precision 和 Recall 均接近 1.0,模型对其识别非常稳定。
  • V 类(室性早搏):Precision 高达 100%,Recall 达到 99%,表明模型在该类上表现出色——不仅极少将其他类型误判为 V 类(高精确率),也成功捕获了绝大多数真实的室性早搏(高召回率)。这在临床应用中尤为重要,因为漏检室性早搏可能带来严重风险。

测试数据得到的报表
在这里插入图片描述

测试集报表反映了模型在“未见过”数据上的表现。整体指标略低于验证集,但仍保持较高水平(平均 F1 > 0.85),表明模型具备良好的泛化能力

5.3 混淆矩阵

混淆矩阵是评估分类模型性能的重要工具,能够直观展示各类别之间的误判模式。通过分析混淆矩阵,可以识别模型最容易混淆的类别对,进而指导后续优化方向。

混淆矩阵可视化(验证数据集),如下:
在这里插入图片描述
混淆矩阵可视化(测试数据集),如下:
在这里插入图片描述

主要观察结论

  1. 主对角线值高:说明大多数样本被正确分类,模型整体有效。
  2. N ↔ S 类之间存在少量混淆:可能由于部分 S 类心跳形态接近正常,导致边界模糊。
  3. F 类(融合波)识别效果最差:常被误判为 N 或 V 类,因其形态介于两者之间且样本稀少。

6. 模型优化

尽管基础模型已取得较好性能,但仍存在改进空间,尤其是在稀有类别(如 F 类)识别和泛化能力方面。为此,本节提出三种优化策略:增加网络深度、继续训练(微调)、引入预训练与迁移学习,以进一步提升模型表现。

6.1 增加网络深度

原始模型采用三层卷积结构,在特征提取能力上存在一定限制。为此,设计了一个更深的变体 ECGCNN_Deep,包含四个卷积块,并引入 AdaptiveAvgPool1d 保证全连接层输入维度固定。

更深的网络能够:

  • 提取更复杂的高层语义特征
  • 扩大感受野,捕捉更长程的心律上下文
  • 增强非线性表达能力

实验表明,适当加深网络可在不显著增加过拟合风险的前提下,提升对复杂心跳模式(如 S/F 类)的识别能力。

'''
模型构建(增加网络层版)
'''
# model_deep.py

import torch.nn as nn
import torch

class ECGCNN_Deep(nn.Module):
    def __init__(self, num_classes=5):
        super(ECGCNN_Deep, self).__init__()
        self.features = nn.Sequential(
            # Block 1
            # 输入通道数1,输出通道数32,卷积核大小5,步长1,填充2
            nn.Conv1d(1, 32, kernel_size=5, stride=1,padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),

            # Block 2
            nn.Conv1d(32, 64, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),

            # Block 3
            nn.Conv1d(64, 128, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),

            # Block 4
            nn.Conv1d(128, 256, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(8),    # 固定输出长度
        )

        self.classifier = nn.Sequential(
            nn.Linear(256 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    

if __name__ == '__main__':
    # 创建数据
    X_train = torch.randn(1, 187, 1)

    # 创建模型
    model = ECGCNN_Deep(num_classes=5)

    # 测试模型
    print(model(X_train).shape)  # [1, 5]
         

6.2 继续训练

初始训练仅进行 10 个 epoch,模型可能尚未完全收敛。为进一步挖掘模型潜力,采用**小学习率继续训练(fine-tuning)**策略,在已有权重基础上再训练 50 个 epoch。

该方法的优势包括:

  • 避免从头训练的高成本
  • 利用已学习的特征基础进行精细化调整
  • 在损失平台期后实现进一步下降
'''
继续训练模型
'''
# continue_train.py

import torch
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader,TensorDataset
from model_self import ECGCNN
import os


# 加载预处理数据
X_train = np.load('./data/processed_data/X_train.npy')
y_train = np.load('./data/processed_data/y_train.npy')


# 转为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)

# print("Loaded data",X_train.shape,y_train.shape)

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


# 设备选择
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载已有模型
model = ECGCNN(num_classes=5)
model.load_state_dict(torch.load('./weight/model_self.pth'))
model.to(device)

# 继续训练(例如再训练 50 个 epoch)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # 更小的学习率
criterion = torch.nn.CrossEntropyLoss()

new_history = {
    'train_loss': [],
    'train_acc': [],
}

for epoch in range(50):  # 继续训练 50 个 epoch
    print(f"\n=== Epoch [{epoch+1}/50] ===")

    model.train()
    train_loss = 0.0
    train_acc = 0
    train_total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_acc += predicted.eq(labels).sum().item()

    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = 100. * train_acc / train_total

    # 记录历史
    new_history['train_loss'].append(avg_train_loss)
    new_history['train_acc'].append(avg_train_acc)

    # 打印
    print(f"Train Loss: {avg_train_loss:.4f} | Acc: {avg_train_acc:.2f}%")

# 7. 保存最终模型(训练完直接保存)
os.makedirs('./weight', exist_ok=True)
torch.save(model.state_dict(), './weight/model_self_finetuned_final.pth')
print(f" 继续训练完成!最终模型已保存。")

# 8. 保存训练历史
np.save('./data/new_history_finetune.npy', new_history)
print(" 训练历史已保存。")

6.3 预训练和迁移学习

采用 ResNet18 作为基础架构,并对其进行改造以适应 1D 心电信号输入:

  • 将原始的 2D 卷积层 conv1 修改为 kernel_size=(7,1),使其能处理单通道时间序列;

  • 使用在大规模 ECG 数据集上预训练的权重(ecg_resnet18.pth)初始化模型;

  • 在目标任务上进行迁移学习:

    • 可选择 冻结主干网络(backbone),仅微调分类头(适合小数据集);

    • 或 全模型微调(fine-tune all layers),适合数据量较大时。

该方法有效利用了预训练模型提取通用心电特征的能力,提升了小样本下的分类性能。

'''
预训练模型 + 迁移学习
'''
# pretrain_translearn.py

import torch
import torch.nn as nn
import torchvision.models as models

class PretrainedResNet1D(nn.Module):
    def __init__(self, num_classes=5, freeze_backbone=False):
        super(PretrainedResNet1D, self).__init__()
        
        # 加载 ImageNet 上预训练的 ResNet18 模型
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        
        # 修改第一层卷积以适应单通道输入 (ECG 信号)
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(7, 1),
            stride=(2, 1),
            padding=(3, 0),
            bias=False
        )
        
        # 提取除了全连接层以外的所有层
        self.model = nn.Sequential(*list(backbone.children())[:-1])
        
        # 修改最后的全连接层
        self.fc = nn.Linear(512, num_classes)  # 假设是 resnet18
        
        # 初始化新的第一层卷积权重
        nn.init.kaiming_normal_(self.model[0].weight, mode='fan_out', nonlinearity='relu')
        
        # 如果 freeze_backbone=True,则冻结 backbone 参数
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
            print("ResNet 卷积层已冻结,仅训练分类头。")
        else:
            print("所有层均可训练。")

    def forward(self, x):
        # 调整输入尺寸以匹配 Conv2d 的期望格式
        if x.dim() == 2:
            x = x.unsqueeze(1).unsqueeze(-1)  # (N, 187) -> (N, 1, 187, 1)
        elif x.dim() == 3:
            x = x.unsqueeze(-1)  # (N, 1, 187) -> (N, 1, 187, 1)
            
        x = self.model(x)
        x = torch.flatten(x, 1)  # 展平
        x = self.fc(x)
        return x

def create_model(freeze_backbone=False):
    model = PretrainedResNet1D(
        num_classes=5,
        freeze_backbone=freeze_backbone
    )
    return model

if __name__ == '__main__':
    # 创建模型实例
    model = create_model(freeze_backbone=True)

    # 模拟输入数据
    x = torch.randn(4, 187)

    # 前向传播
    with torch.no_grad():  # 测试时关闭梯度计算
        y = model(x)

    print("输入形状:", x.shape)  # [4, 187]
    print("输出形状:", y.shape)  # [4, 5]
    print("模型创建成功!")

7. ECG信号分类系统实现

本节构建了一个端到端的 ECG 心跳分类系统,涵盖从原始信号输入到最终可视化输出的完整流程。系统分为三个核心模块:信号预处理、模型推理、结果可视化,实现了从“数据 → 预测 → 展示”的闭环,具备临床辅助诊断系统的雏形。

7.1 信号预处理(去噪、标准化)

原始 ECG 信号常受基线漂移、肌电干扰和工频噪声影响,直接影响模型识别精度。为此,设计了标准化的预处理流程:

  • 带通滤波(0.5–40 Hz):保留典型心电信号频段,有效去除低频基线漂移和高频噪声;
  • 零相位滤波(filtfilt):避免传统滤波引入的时间延迟,确保 R 峰位置不变;
  • Z-score 标准化:将信号转换为均值为 0、标准差为 1 的分布,匹配模型训练时的数据分布。

该预处理流程简单高效,适用于实时或批量处理场景。

"""
清洗 ECG 信号
"""
# clean_ECG.py 

import numpy as np
from scipy import signal
import matplotlib.pyplot as plt

def notch_filter(ecg, notch_freq=50, fs=500, Q=30):
    """50Hz 工频陷波滤波"""
    b, a = signal.iirnotch(notch_freq, Q, fs=fs)
    return signal.filtfilt(b, a, ecg)

def bandpass_filter(ecg, low=0.5, high=150, fs=500, order=4):
    """带通滤波:保留 ECG 主要频带,去除非生理频率"""
    nyquist = 0.5 * fs
    low = low / nyquist
    high = high / nyquist
    if low >= 1.0:
        low = 0.99
    if high >= 1.0:
        high = 0.99
    b, a = signal.butter(order, [low, high], btype='band')
    return signal.filtfilt(b, a, ecg)

def normalize(ecg):
    """标准化"""
    return (ecg - np.mean(ecg)) / np.std(ecg)

def preprocess_ecg_signal(raw_signal, fs=500):
    """完整预处理流程"""
    # 步骤1:去除 50Hz 工频干扰
    cleaned = notch_filter(raw_signal, notch_freq=50, fs=fs)
    
    # 步骤2:带通滤波(保留 0.5 - 150 Hz)
    cleaned = bandpass_filter(cleaned, low=0.5, high=150, fs=fs, order=4)
    
    # 步骤3:标准化
    cleaned = normalize(cleaned)
    
    return cleaned


# ======================
# 测试代码:导入并清洗信号
# ======================
if __name__ == '__main__':
    try:
        from Noisy_ECG_Signal import generate_noisy_ecg
    except ImportError:
        raise ImportError("请确保 'Noisy_ECG_Signal.py' 与本文件在同一目录下")

    print(" 正在生成污染 ECG 信号...")
    t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=True)

    print(" 正在清洗信号...")
    ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)

    # 可视化:清洗前后对比(前 2 秒)
    t_plot = t[:1000]  # 前 2 秒 (500 * 2 = 1000)
    clean_plot = ecg_clean[:1000]
    noisy_plot = ecg_noisy[:1000]
    cleaned_plot = ecg_cleaned[:1000]

    plt.figure(figsize=(14, 8))

    plt.subplot(3, 1, 1)
    plt.plot(t_plot, clean_plot, color='blue', linewidth=1.2)
    plt.title("1. Clean ECG Signal (Ground Truth)")
    plt.ylabel("Amplitude (mV)")
    plt.grid(True, alpha=0.3)

    plt.subplot(3, 1, 2)
    plt.plot(t_plot, noisy_plot, color='red', linewidth=1.2)
    plt.title("2. Noisy ECG Signal (with Baseline, EMG, 50Hz)")
    plt.ylabel("Amplitude (mV)")
    plt.grid(True, alpha=0.3)

    plt.subplot(3, 1, 3)
    plt.plot(t_plot, cleaned_plot, color='green', linewidth=1.2)
    plt.title("3. Cleaned ECG Signal (After Preprocessing)")
    plt.xlabel("Time (s)")
    plt.ylabel("Normalized Amplitude")
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(" 测试完成!去噪效果如图所示。")
    print(f"   原始 SNR 估计: {10*np.log10(np.var(ecg_clean)/np.var(ecg_noisy-ecg_clean)):.2f} dB")
    print(f"   去噪后 SNR: {10*np.log10(np.var(ecg_clean)/np.var(ecg_cleaned-ecg_clean)):.2f} dB")

污染信号生成代码(Nosiy_ECG_Signal.py)

'''
生成受污染的ECG 信号
'''
# Nosiy_ECG_Signal.py
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

def generate_noisy_ecg(fs=500, duration=10, show_plot=True):
    """
    生成受基线漂移、肌电干扰和工频噪声影响的 ECG 信号
    
    参数:
        fs: 采样频率 (Hz)
        duration: 信号时长 (秒)
        show_plot: 是否显示生成过程的可视化
    
    返回:
        t: 时间轴
        ecg_clean: 干净 ECG 信号
        ecg_noisy: 污染后的 ECG 信号
    """
    # ======================
    # 参数设置
    # ======================
    t = np.linspace(0, duration, int(fs * duration), endpoint=False)  # 时间轴

    # ======================
    # 1. 生成干净 ECG 信号(简化模型:使用周期性波形模拟 P-QRS-T)
    # ======================
    def generate_ecg_clean(t, fs):
        # 心率(bpm)
        heart_rate = 75
        heart_rate_rad = 2 * np.pi * heart_rate / 60

        # R 波:周期性高斯脉冲(模拟 QRS 波群)
        r_peaks = np.sin(heart_rate_rad * t)  # 控制心跳节奏
        qrs = 2.0 * np.exp(-1000 * (t % (60 / heart_rate) - 0.02)**2)  # 高斯脉冲模拟 QRS

        # T 波:稍宽的正向波
        t_wave = 0.4 * np.exp(-200 * (t % (60 / heart_rate) - 0.15)**2)

        # P 波:小的正向波
        p_wave = 0.25 * np.exp(-400 * (t % (60 / heart_rate) - 0.0)**2)

        # 组合 ECG
        ecg_clean = p_wave + qrs + t_wave

        # 添加轻微随机变化(心跳间期变异)
        jitter = 0.01 * np.random.randn(len(t))
        ecg_clean = np.interp(t, t + jitter, ecg_clean)

        return ecg_clean

    ecg_clean = generate_ecg_clean(t, fs)

    # ======================
    # 2. 添加噪声
    # ======================

    # (1) 基线漂移(Baseline Wander): 0.1 - 0.5 Hz 的低频正弦波组合
    baseline_wander = (
        0.3 * np.sin(2 * np.pi * 0.1 * t) +
        0.2 * np.sin(2 * np.pi * 0.3 * t) +
        0.1 * np.sin(2 * np.pi * 0.5 * t)
    )

    # (2) 肌电干扰(EMG-like noise): 高频随机噪声(30-200 Hz)
    np.random.seed(42)
    emg_noise = np.random.normal(0, 0.1, len(t))
    # 用带通滤波器模拟肌电信号频带(30-200 Hz)
    b_emg, a_emg = signal.butter(4, [30, 200], btype='bandpass', fs=fs)
    emg_noise = signal.filtfilt(b_emg, a_emg, emg_noise)
    emg_noise = 0.1 * emg_noise / np.max(np.abs(emg_noise))  # 归一化并控制幅度

    # (3) 工频噪声(Power-line interference): 50 Hz(中国)或 60 Hz(美国)
    power_freq = 50  # 可改为 60
    power_noise = 0.15 * np.sin(2 * np.pi * power_freq * t)

    # ======================
    # 3. 合成污染信号
    # ======================
    ecg_noisy = ecg_clean + baseline_wander + emg_noise + power_noise

    # ======================
    # 4. 可视化(保持你原有的可视化不变)
    # ======================
    if show_plot:
        plt.figure(figsize=(14, 8))

        # 子图1:原始干净 ECG
        plt.subplot(3, 1, 1)
        plt.plot(t[:1000], ecg_clean[:1000], color='blue', linewidth=1.2)
        plt.title("Clean ECG Signal")
        plt.ylabel("Amplitude (mV)")
        plt.grid(True, alpha=0.3)

        # 子图2:添加的噪声
        plt.subplot(3, 1, 2)
        plt.plot(t[:1000], baseline_wander[:1000], label='Baseline Wander', color='orange')
        plt.plot(t[:1000], emg_noise[:1000], label='EMG Noise', color='red', alpha=0.7)
        plt.plot(t[:1000], power_noise[:1000], label='50 Hz Noise', color='purple', alpha=0.7)
        plt.title("Added Noise Components")
        plt.ylabel("Amplitude")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # 子图3:最终污染信号
        plt.subplot(3, 1, 3)
        plt.plot(t[:1000], ecg_noisy[:1000], color='red', linewidth=1.2)
        plt.title("Noisy ECG Signal (with Baseline Wander, EMG, and 50 Hz Interference)")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude (mV)")
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    return t, ecg_clean, ecg_noisy  # 返回信号,供其他模块使用


# ======================
# 如果直接运行此文件,则生成并显示信号
# ======================
if __name__ == "__main__":
    print(" 正在生成污染 ECG 信号...")
    t, clean, noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=True)
    print(" 信号生成完成,长度:", len(t))

污染信号生成可视化图 ,如下:
在这里插入图片描述

cleaned信号处理对比图 ,如下:
在这里插入图片描述

7.2 模型推理(PyTorch )

模型推理是分类系统的核心环节。本模块封装了模型加载与预测逻辑,支持单个心跳或批量信号输入。

  • 使用 torch.load() 加载训练好的 .pth 权重文件;
  • 通过 model.eval()torch.no_grad() 关闭梯度计算,提升推理效率;
  • 输出包含:预测类别、置信度(最大概率值)、各类别概率分布,便于后续分析与决策。
'''
模型推理
'''
# inference.py
import torch
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
from scipy import signal

from model_self import ECGCNN
from clean_ECG import preprocess_ecg_signal

# 类别名称
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']


# ======================
# 1. 加载模型
# ======================
def load_model(model_path="./weight/model_self.pth"):
    """
    加载训练好的 PyTorch 模型
    """
    model = ECGCNN(num_classes=5)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    print(" PyTorch 模型已加载")
    return model


def load_onnx_model(onnx_model_path="./weight/ecg_model_self.onnx"):
    """
    加载 ONNX 模型
    """
    session = ort.InferenceSession(onnx_model_path)
    print(" ONNX 模型已加载")
    return session


# ======================
# 2. 检测 R 波(用于分割心跳)
# ======================
def detect_r_peaks(ecg, fs=500):
    """
    使用简单阈值法检测 R 波(适用于干净或轻度污染信号)
    返回 R 波位置索引
    """
    # 使用带通滤波增强 QRS
    b, a = signal.butter(2, [5, 15], btype='bandpass', fs=fs)
    filtered = signal.filtfilt(b, a, ecg)

    # 简单平方 + 滑动窗能量
    squared = filtered ** 2
    window_size = int(0.1 * fs)  # 100ms 滑动窗
    smoothed = np.convolve(squared, np.ones(window_size) / window_size, mode='same')

    # 阈值检测
    threshold = 0.5 * np.max(smoothed)
    r_peaks = signal.find_peaks(smoothed, height=threshold, distance=int(0.6 * fs))[0]  # 最小间距 600ms
    return r_peaks


# ======================
# 3. 提取心跳片段(长度 187)
# ======================
def extract_beats(ecg, r_peaks, fs=500, beat_length=187):
    """
    以 R 峰为中心,前后截取心跳片段
    beat_length: 模型输入长度(如 187)
    """
    half_len = beat_length // 2
    beats = []
    valid_positions = []

    for r in r_peaks:
        start = r - half_len
        end = r + (beat_length - half_len)
        if start >= 0 and end <= len(ecg):
            beat = ecg[start:end]
            if len(beat) == beat_length:
                beats.append(beat)
                valid_positions.append(r)

    return np.array(beats), np.array(valid_positions)


# ======================
# 4. Softmax 函数
# ======================
def softmax(x, axis=-1):
    """
    Numerically stable softmax
    """
    x = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


# ======================
# 5. PyTorch 批量推理
# ======================
def predict_heartbeat(model, ecg_signal):
    """
    ecg_signal: 已经去噪的 ECG 片段 (187,) 或 (N, 187)
    返回: 预测类别, 置信度, 概率分布
    """
    model.eval()
    with torch.no_grad():
        if ecg_signal.ndim == 1:
            # 单个心跳
            tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        else:
            # 批量心跳 (N, 187) -> (N, 1, 187)
            tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(1)

        output = model(tensor)  # logits: (N, 5)
        prob = torch.softmax(output, dim=1).numpy()  # 转为 numpy 概率
        confidence = np.max(prob, axis=1)
        predicted = np.argmax(prob, axis=1)

    return predicted, confidence, prob


# ======================
# 6. ONNX 批量推理
# ======================
def predict_heartbeat_onnx(session, ecg_signal):
    """
    使用 ONNX 模型预测单个或批量心跳
    """
    if ecg_signal.ndim == 1:
        input_data = ecg_signal.reshape(1, 1, -1).astype(np.float32)  # (1, 1, 187)
    else:
        input_data = ecg_signal.reshape(-1, 1, 187).astype(np.float32)  # (N, 1, 187)

    result = session.run(["logits"], {"ecg_input": input_data})
    logits = result[0]  # (N, 5)
    prob = softmax(logits, axis=1)
    confidence = np.max(prob, axis=1)
    predicted = np.argmax(prob, axis=1)

    return predicted, confidence, prob

# ======================
# 7. 可视化函数
# ======================

def plot_ecg_comparison(t, ecg_noisy, ecg_cleaned, r_peaks, beat_positions, predictions, confidences, class_names):
    """
    绘制污染信号、干净信号、R波位置、预测结果
    """
    fig, axes = plt.subplots(3, 1, figsize=(16, 10), sharex=True)

    # 子图 1: 原始污染信号
    axes[0].plot(t, ecg_noisy, color='lightcoral', linewidth=0.8)
    axes[0].set_title("Noisy ECG Signal", fontsize=14, fontweight='bold')
    axes[0].set_ylabel("Amplitude")
    axes[0].grid(True, alpha=0.3)

    # 子图 2: 去噪后信号
    axes[1].plot(t, ecg_cleaned, color='steelblue', linewidth=1.0)
    # 标出 R 波位置
    for r in r_peaks:
        axes[1].axvline(t[r], color='red', linestyle='--', alpha=0.7)
    axes[1].set_title("Denoised ECG Signal with R-Peak Detection", fontsize=14, fontweight='bold')
    axes[1].set_ylabel("Amplitude")
    axes[1].grid(True, alpha=0.3)

    # 子图 3: 预测结果标注
    axes[2].plot(t, ecg_cleaned, color='gray', linewidth=0.8, alpha=0.8)
    
    # 颜色映射(每类不同颜色)
    colors = ['green', 'orange', 'red', 'purple', 'gray']
    for i, (pos, pred, conf) in enumerate(zip(beat_positions, predictions, confidences)):
        if pos < len(t):
            x = t[pos]
            y = ecg_cleaned[pos]
            class_name = class_names[pred]
            color = colors[pred]
            axes[2].axvline(x, color=color, alpha=0.6)
            axes[2].text(x, max(ecg_cleaned)*0.9, f'{class_name}\n{conf:.2f}',
                        color=color, fontsize=8, ha='center', rotation=90,
                        bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.2))

    axes[2].set_title("Predicted Heartbeat Types (Color-coded)", fontsize=14, fontweight='bold')
    axes[2].set_ylabel("Amplitude")
    axes[2].set_xlabel("Time (s)")
    axes[2].grid(True, alpha=0.3)

    # 图例说明
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))]
    axes[2].legend(handles=legend_elements, bbox_to_anchor=(1.02, 1), loc='upper left')

    plt.tight_layout()
    plt.show()


def plot_prediction_confidence(predictions, confidences, class_names):
    """
    绘制预测类别和置信度柱状图
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(predictions))
    colors = ['green', 'orange', 'red', 'purple', 'gray']
    bar_colors = [colors[pred] for pred in predictions]

    bars = ax.bar(x, confidences, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=0.5)
    ax.set_xlabel("Beat Index")
    ax.set_ylabel("Confidence")
    ax.set_title("Prediction Confidence per Heartbeat", fontsize=14, fontweight='bold')
    ax.set_ylim(0, 1.1)
    ax.grid(True, axis='y', alpha=0.3)

    # 在柱子上方标注类别
    for i, (bar, pred) in enumerate(zip(bars, predictions)):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                class_names[pred], ha='center', va='bottom', fontsize=9, rotation=45)

    # 图例
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))]
    ax.legend(handles=legend_elements, title="Classes")

    plt.tight_layout()
    plt.show()

# ======================
# 主测试流程
# ======================
if __name__ == "__main__":
    # --- 1. 加载模型 ---
    model = load_model("./weight/model_self.pth")
    ort_session = load_onnx_model("./weight/ecg_model_self.onnx")

    # --- 2. 生成并清洗 ECG 信号 ---
    from Noisy_ECG_Signal import generate_noisy_ecg
    t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=False)

    # 只清洗一次
    ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)
    print(f"ECG 信号长度: {len(ecg_cleaned)}")

    # --- 3. 检测 R 波 ---
    r_peaks = detect_r_peaks(ecg_cleaned, fs=500)
    print(f"检测到 {len(r_peaks)} 个 R 波")

    # --- 4. 提取心跳 ---
    beats, beat_positions = extract_beats(ecg_cleaned, r_peaks, fs=500, beat_length=187)
    print(f"成功提取 {len(beats)} 个心跳片段")

    if len(beats) == 0:
        print(" 未提取到有效心跳片段")
    else:
        # --- 5. 批量推理 ---
        pred_torch, conf_torch, prob_torch = predict_heartbeat(model, beats)
        pred_onnx, conf_onnx, prob_onnx = predict_heartbeat_onnx(ort_session, beats)

        # --- 6. 打印结果 ---
        print("\n" + "=" * 50)
        print(" 心跳分类结果对比")
        print("=" * 50)
        for i in range(min(10, len(beats))):
            match = "✅" if pred_torch[i] == pred_onnx[i] else "❌"
            print(f"心跳 {i+1:2d}: "
                  f"[PyTorch] {CLASS_NAMES[pred_torch[i]]} ({conf_torch[i]:.3f}) | "
                  f"[ONNX] {CLASS_NAMES[pred_onnx[i]]} ({conf_onnx[i]:.3f}) {match}")

        # --- 7. 统计一致性 ---
        accuracy = np.mean(pred_torch == pred_onnx)
        print(f"\n ONNX 与 PyTorch 预测一致率: {accuracy * 100:.1f}%")

        # --- 8. 可视化 ---
        print("\n  正在生成可视化图表...")

        # 时间轴
        t = np.linspace(0, len(ecg_cleaned)/500, len(ecg_cleaned))  # fs=500

        # 绘制信号对比和预测结果
        plot_ecg_comparison(
            t=t,
            ecg_noisy=ecg_noisy,
            ecg_cleaned=ecg_cleaned,
            r_peaks=r_peaks,
            beat_positions=beat_positions,
            predictions=pred_torch,      # 使用 PyTorch 预测结果
            confidences=conf_torch,
            class_names=CLASS_NAMES
        )

        # 绘制置信度图
        plot_prediction_confidence(pred_torch, conf_torch, CLASS_NAMES)

7.3 结果可视化(画波形 + 打标签)

为了直观展示模型在真实 ECG 波形上的分类效果,开发了可视化模块。该模块将原始信号与预测结果融合呈现:

  • 在原始 ECG 曲线上标注每个 R 峰对应的心跳类型;
  • 使用颜色编码(绿色/红色)表示置信度高低(可设置阈值);
  • 添加文本框显示类别名称与置信分数,提升可读性;
  • 支持自定义采样率、标签位置、字体大小等参数。

可视化结果不仅有助于模型调试与错误分析,也为医生提供直观的辅助判读工具,增强人机协同效率。

"""
结果可视化
"""
# visualize.py
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Optional
from inference import load_model, predict_heartbeat  # 导入你需要的函数

# 使用全局类别名(也可以传参)
from inference import CLASS_NAMES

def plot_ecg_with_labels(
    ecg_signal: np.ndarray,
    r_peaks: List[int],
    model,
    class_names: List[str] = None,
    segment_length: int = 187,
    sample_rate: int = 360,
    confidence_threshold: float = 0.6,
    title: str = "ECG 信号与心跳分类结果",
    figsize: tuple = (14, 6)
):
    """
    在原始 ECG 信号上绘制波形,并为每个心跳打上预测标签

    参数:
        ecg_signal: 完整 ECG 信号 (T,)
        r_peaks: R 峰位置列表(索引)
        model: 已加载的 PyTorch 模型对象
        class_names: 类别名称列表(默认使用 inference 中的 CLASS_NAMES)
        segment_length: 每个心跳输入长度(默认 187)
        sample_rate: 采样率(Hz)
        confidence_threshold: 置信度阈值(高于绿色,低于红色)
        title: 图表标题
        figsize: 图像大小
    """
    if len(ecg_signal.shape) != 1:
        raise ValueError("ecg_signal 必须是一维信号")

    half_len = segment_length // 2
    class_names = class_names or CLASS_NAMES  # 默认使用全局类别
    predictions = []

    plt.figure(figsize=figsize)
    t = np.arange(len(ecg_signal)) / sample_rate
    plt.plot(t, ecg_signal, 'k', linewidth=0.8, label='ECG Signal')

    # 对每个 R 峰进行预测并标注
    for i, peak in enumerate(r_peaks):
        left = peak - half_len
        right = peak + half_len
        if left < 0 or right >= len(ecg_signal):
            predictions.append(None)
            continue

        # 提取单个心跳
        heartbeat = ecg_signal[left:right]

        try:
            # 使用 inference.py 中的 predict_heartbeat 函数
            pred_ids, confs, probs = predict_heartbeat(model, heartbeat)
            pred_label = pred_ids[0]  # 返回是 (1,) 数组
            confidence = confs[0]
            pred_name = class_names[pred_label]
        except Exception as e:
            print(f"第 {i} 个心跳预测失败: {e}")
            pred_name = "Error"
            confidence = 0.0

        predictions.append((pred_name, confidence))

        # 在 R 峰上方添加文本标签
        x_time = peak / sample_rate
        y_height = ecg_signal[peak] + (np.max(ecg_signal) - np.min(ecg_signal)) * 0.05
        color = 'green' if confidence > confidence_threshold else 'red'

        plt.text(
            x_time, y_height, f"{pred_name}\n({confidence:.2f})",
            fontsize=9, ha='center', va='bottom',
            color=color, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7)
        )

    plt.title(title, fontsize=14)
    plt.xlabel("时间 (秒)", fontsize=12)
    plt.ylabel("幅度 (mV)", fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return predictions


# ========================
# 示例使用(仅用于测试)
# ========================
if __name__ == "__main__":
    import numpy as np

    print(" 正在测试 ECG 可视化模块...")

    # 模拟一段较长的 ECG 信号
    T = 3000
    sample_rate = 360
    ecg_signal = np.zeros(T)
    r_peaks = [187, 540, 910, 1280, 1650, 2020, 2400, 2780]  # R 峰位置

    # 生成模拟心跳(高斯脉冲)
    for peak in r_peaks:
        if peak < T:
            start = max(0, peak - 20)
            end = min(T, peak + 40)
            x = np.arange(start, end)
            ecg_signal[x] += np.exp(-0.05 * (x - peak)**2) * 1.8
    ecg_signal += np.random.normal(0, 0.05, T)  # 加噪声

    # 加载模型
    try:
        model = load_model("./weight/model_self.pth")
        print(" 模型加载成功")
    except Exception as e:
        print(f" 模型加载失败,请检查路径: {e}")
        exit()

    # 执行可视化
    results = plot_ecg_with_labels(
        ecg_signal=ecg_signal,
        r_peaks=r_peaks,
        model=model,
        class_names=CLASS_NAMES,
        sample_rate=sample_rate,
        confidence_threshold=0.6,
        title="ECG 心跳分类结果可视化"
    )

    # 打印结果
    print("\n 分类结果:")
    for i, res in enumerate(results):
        if res:
            print(f"心跳 {i+1}: {res[0]} (置信度: {res[1]:.3f})")
        else:
            print(f"心跳 {i+1}: 越界或分析失败")

8. 模型移植

为了提升模型的跨平台部署能力,降低对 PyTorch 框架的依赖,本项目采用 ONNX(Open Neural Network Exchange) 格式进行模型导出与推理验证。ONNX 支持多种运行时(如 ONNX Runtime、TensorRT、CoreML),可在服务器、移动端、边缘设备上高效运行,极大增强了模型的工程落地潜力。

8.1 导出ONNX

使用 torch.onnx.export 工具将训练好的 PyTorch 模型转换为 .onnx 文件。关键配置如下:

  • dummy_input:提供示例输入张量,用于追踪计算图;
  • opset_version=11:确保支持常用算子(如 Conv1d、BatchNorm);
  • dynamic_axes:允许动态 batch size,适应不同输入规模;
  • do_constant_folding=True:优化常量节点,减小模型体积并提升推理速度。

导出后可通过 Netron 等工具查看模型结构,确认节点连接正确。

在这里插入图片描述

'''
模型导出onnx
'''
# export_onnx.py

import torch
from model_self import ECGCNN  # 确保路径正确

def main():
    # 1. 定义模型结构(必须和训练时一致)
    model = ECGCNN(num_classes=5)
    
    # 2. 加载你训练好的权重
    model.load_state_dict(torch.load("./weight/model_self.pth", map_location='cpu'))
    model.eval()  # 切换到推理模式

    # 3. 构造一个示例输入(shape: batch x channel x length)
    dummy_input = torch.randn(1, 1, 187)  # 和ECG 数据 shape 一致

    # 4. 导出为 ONNX  (通过输入构造一个示例输入 进行追踪推理)
    torch.onnx.export(
        model,
        dummy_input,
        "./weight/ecg_model_self.onnx",  # 输出路径
        export_params=True,           # 保存模型参数
        opset_version=11,             # ONNX 版本兼容性
        do_constant_folding=True,     # 优化
        input_names=['ecg_input'],    # 输入名
        output_names=['logits'],  # 输出名
        dynamic_axes={
            'ecg_input': {0: 'batch_size'},
            'logits': {0: 'batch_size'}
        }  # 支持动态 batch
    )
    print(" 成功导出 ONNX 模型到:./weight/ecg_model_self.onnx")

if __name__ == "__main__":
    main()

8.2 使用ONNX推理

利用 ONNX Runtime 加载 .onnx 模型并执行推理,验证其输出是否与原始 PyTorch 模型一致。

  • onnxruntime.InferenceSession 提供跨平台高性能推理引擎;
  • 输入需按 input_names 指定的名称传入(如 'ecg_input');
  • 输出返回概率分布,取最大值作为预测结果。

经测试,ONNX 模型与 PyTorch 模型的预测结果完全一致(误差 < 1e-6),说明转换成功。同时,ONNX Runtime 在 CPU 上的推理速度更快,更适合部署在资源受限设备上。

import torch
import numpy as np
import onnxruntime as ort
from model_self import ECGCNN  
from preprocess import preprocess_ecg_signal

# 类别名称
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown'

#  ONNX 推理函数
def predict_heartbeat_onnx(session, ecg_signal):
    """
    使用 ONNX 模型预测单个或批量心跳
    """
    if ecg_signal.ndim == 1:
        input_data = ecg_signal.reshape(1, 1, -1).astype(np.float32)  # (1, 1, 187)
    else:
        input_data = ecg_signal.reshape(-1, 1, 187).astype(np.float32)  # (N, 1, 187)

    result = session.run(["logits"], {"ecg_input": input_data})
    logits = result[0]  # (N, 5)
    prob = softmax(logits, axis=1)
    confidence = np.max(prob, axis=1)
    predicted = np.argmax(prob, axis=1)

    return predicted, confidence, prob


if __name__ == "__main__":
    ort_session = load_onnx_model("./weight/ecg_model_self.onnx")
    # 生成并清洗 ECG 信号 
    from Noisy_ECG_Signal import generate_noisy_ecg
    t, ecg_clean, ecg_noisy = generate_noisy_ecg(fs=500, duration=10, show_plot=False)

    # 只清洗一次
    ecg_cleaned = preprocess_ecg_signal(ecg_noisy, fs=500)
    print(f"ECG 信号长度: {len(ecg_cleaned)}")

    # 检测 R 波 
    r_peaks = detect_r_peaks(ecg_cleaned, fs=500)
    print(f"检测到 {len(r_peaks)} 个 R 波")

    # 提取心跳 
    beats, beat_positions = extract_beats(ecg_cleaned, r_peaks, fs=500, beat_length=187)
    print(f"成功提取 {len(beats)} 个心跳片段")
               
    # 使用 ONNX 模型推理
    pred_onnx, conf_onnx, prob_onnx = predict_heartbeat_onnx(ort_session, beats)
    print(f"\n【ONNX 推理】")
    print(f"预测类别: {CLASS_NAMES[pred_onnx]}")
    print(f"置信度: {conf_onnx:.3f}")

9. 项目总结

9.1 问题及解决办法

在进行基于CNN实现心律失常(ECG)的小颗粒度分类时,可能会遇到以下问题及解决办法:

问题 解决方案
类别严重不平衡(N类占80%) 使用 SMOTE 过采样平衡训练集
模型过拟合 添加 Dropout、BatchNorm
输入维度不匹配 reshape 为 (N, 187, 1) 适配 1D-CNN
推理时预处理不一致 保存 StandardScaler 并在推理时复用

9.2 收获

  • 掌握了 1D-CNN 在时间序列分类中的应用
  • 学会了使用 SMOTE 解决类别不平衡问题
  • 实践了从训练到部署的完整流程(PyTorch → ONNX)
  • 提升了模型可视化与可解释性能力

网站公告

今日签到

点亮在社区的每一天
去签到