基于RNN模型的心脏病预测,提供tensorflow和pytorch实现

发布于:2025-02-10 ⋅ 阅读:(65) ⋅ 点赞:(0)

前言

1、数据处理

1、导入库

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader, TensorDataset
import torch 


device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'

2、导入数据

data = pd.read_csv('./heart.csv')

data.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63 1 3 145 233 1 0 150 0 2.3 0 0 1 1
1 37 1 2 130 250 0 1 187 0 3.5 0 0 2 1
2 41 0 1 130 204 0 0 172 0 1.4 2 0 2 1
3 56 1 1 120 236 0 1 178 0 0.8 2 0 2 1
4 57 0 0 120 354 0 1 163 1 0.6 2 0 2 1
  • age - 年龄
  • sex - (1 = male(男性); 0 = (女性))
  • cp - chest pain type(胸部疼痛类型)(1:典型的心绞痛-typical,2:非典型心绞痛-atypical,3:没有心绞痛-non-anginal,4:无症状-asymptomatic)
  • trestbps - 静息血压 (in mm Hg on admission to the hospital)
  • chol - 胆固醇 in mg/dl
  • fbs - (空腹血糖 > 120 mg/dl) (1 = true; 0 = false)
  • restecg - 静息心电图测量(0:普通,1:ST-T波异常,2:可能左心室肥大)
  • thalach - 最高心跳率
  • exang - 运动诱发心绞痛 (1 = yes; 0 = no)
  • oldpeak - 运动相对于休息引起的ST抑制
  • slope - 运动ST段的峰值斜率(1:上坡-upsloping,2:平的-flat,3:下坡-downsloping)
  • ca - 主要血管数目(0-4)
  • thal - 一种叫做地中海贫血的血液疾病(3 = normal; 6 = 固定的缺陷-fixed defect; 7 = 可逆的缺陷-reversable defect)
  • target - 是否患病 (1=yes, 0=no)

3、数据分析

数据初步分析
data.info()   # 数据类型分析
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       303 non-null    int64  
 1   sex       303 non-null    int64  
 2   cp        303 non-null    int64  
 3   trestbps  303 non-null    int64  
 4   chol      303 non-null    int64  
 5   fbs       303 non-null    int64  
 6   restecg   303 non-null    int64  
 7   thalach   303 non-null    int64  
 8   exang     303 non-null    int64  
 9   oldpeak   303 non-null    float64
 10  slope     303 non-null    int64  
 11  ca        303 non-null    int64  
 12  thal      303 non-null    int64  
 13  target    303 non-null    int64  
dtypes: float64(1), int64(13)
memory usage: 33.3 KB

其中分类变量为:sex、cp、fbs、restecg、exang、slope、ca、thal、target

数值型变量:age、trestbps、chol、thalach、oldpeak

data.describe()  # 描述性
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000
mean 54.366337 0.683168 0.966997 131.623762 246.264026 0.148515 0.528053 149.646865 0.326733 1.039604 1.399340 0.729373 2.313531 0.544554
std 9.082101 0.466011 1.032052 17.538143 51.830751 0.356198 0.525860 22.905161 0.469794 1.161075 0.616226 1.022606 0.612277 0.498835
min 29.000000 0.000000 0.000000 94.000000 126.000000 0.000000 0.000000 71.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
25% 47.500000 0.000000 0.000000 120.000000 211.000000 0.000000 0.000000 133.500000 0.000000 0.000000 1.000000 0.000000 2.000000 0.000000
50% 55.000000 1.000000 1.000000 130.000000 240.000000 0.000000 1.000000 153.000000 0.000000 0.800000 1.000000 0.000000 2.000000 1.000000
75% 61.000000 1.000000 2.000000 140.000000 274.500000 0.000000 1.000000 166.000000 1.000000 1.600000 2.000000 1.000000 3.000000 1.000000
max 77.000000 1.000000 3.000000 200.000000 564.000000 1.000000 2.000000 202.000000 1.000000 6.200000 2.000000 4.000000 3.000000 1.000000
  • 年纪:均值54,中位数55,标准差9,说明主要是老年人,偏大
  • 静息血压:均值131.62, 成年人一般:正常血压:收缩压 < 120 mmHg,偏大
  • 胆固醇:均值246.26,理想水平:小于 200 mg/dL,偏大
  • 最高心率:均值149.64,一般静息状态下通常是 60 到 100 次每分钟,偏大

最大值和最小值都可能发生,无异常值

缺失值
data.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
相关性分析
import seaborn as sns

plt.figure(figsize=(20, 15))

sns.heatmap(data.corr(), annot=True, cmap='Greens')

plt.show()


在这里插入图片描述

相关系数的等级划分

  • 非常弱的相关性:
    • 0.00 至 0.19 或 -0.00 至 -0.19
    • 解释:几乎不存在线性关系。
  • 弱相关性:
    • 0.20 至 0.39 或 -0.20 至 -0.39
    • 解释:存在一定的线性关系,但较弱。
  • 中等相关性:
    • 0.40 至 0.59 或 -0.40 至 -0.59
    • 解释:有明显的线性关系,但不是特别强。
  • 强相关性:
    • 0.60 至 0.79 或 -0.60 至 -0.79
    • 解释:两个变量之间有较强的线性关系。
  • 非常强的相关性:
    • 0.80 至 1.00 或 -0.80 至 -1.00
    • 解释:几乎完全线性相关,表明两个变量的变化高度一致。

target与chol、没有什么相关性,fbs是分类变量,chol胆固醇是数值型变量,但是从实际角度,这些都有影响,故不剔除特征

4、数据划分

这里先划分为:训练集:测试集 = 9:1

from sklearn.model_selection import train_test_split

X = data.iloc[:, :-1]
y = data.iloc[:, -1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

5、数据标准化

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 深度学习、用rnn模型,数据需要3通道,在图片中表示RGB,这里表示1
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

6、转化为张量数据

# 假设  y_train, y_test 是 pandas Series 或 DataFrame
# 首先将它们转换为 NumPy 数组
y_train = y_train.values.astype(np.float32)
y_test = y_test.values.astype(np.float32)

batch_size = 32

# unsqueeze  (N,) 转换为 (N, 1)
train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32).to(device), torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)) 
test_dataset = TensorDataset(torch.tensor(X_test, dtype=torch.float32).to(device), torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device))

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

2、创建模型

  • 定义一个RNN层
    rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2, nonlinearity=‘tanh’,
    bias=True, batch_first=False, dropout=0, bidirectional=False)
  • input_size: 输入的特征维度
  • hidden_size: 隐藏层的特征维度
  • num_layers: RNN 层的数量
  • nonlinearity: 非线性激活函数 (‘tanh’ 或 ‘relu’)
  • bias: 如果为 False,则内部不含偏置项,默认为 True
  • batch_first: 如果为 True,则输入和输出张量提供为 (batch, seq, feature),默认为 False (seq, batch, feature)
  • dropout: 如果非零,则除了最后一层,在每层的输出中引入一个 Dropout 层,默认为 0
  • bidirectional: 如果为 True,则将成为双向 RNN,默认为 False
import torch  
import torch.nn as nn 

# 创建模型
'''
该问题本质是二分类问题,故最后一层全连接层用激活函数为:sigmoid
模型结构:
    RNN:隐藏层200,激活函数:relu
    Linear:--> 100(relu) -> 1(sigmoid)
'''
# 创建模型
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 在 Keras 中 input_shape=(13, 1) 表示的是, 每个样本有 13 个时间步(seq_length=13),每个时间步有一个特征(input_size=1), 换句话就是一行
        self.rnn = nn.RNN(input_size=1, hidden_size=200, num_layers=1, nonlinearity='relu', batch_first=True)
        
        self.fc1 = nn.Linear(200, 100)
        self.fc2 = nn.Linear(100, 1)
        
    def forward(self, x):
        # 初始化隐藏层状态
        h0 = torch.zeros(1, x.size(0), 200).to(device)  # (num_layers, batch_size, hidden_size)
        # 构建神经网络
        x, _ = self.rnn(x, h0)  # x: (batch_size, seq_length, hidden_size)
        x = x[:, -1, :] # 最后一个时间步作为全连接层的输入, 形状变为:(batch_size, input_size)
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
    

model = Model().to(device)
        

3、模型训练

1、设置超参数

loss_fn = nn.BCELoss()
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

2、设置训练函数

def train(dataloader, model, loss_fn, optimizer):
    # 总大小
    size = len(dataloader.dataset)
    # 批次数量
    num_batches = len(dataloader)

    # 准确率和损失初始化
    correct = 0
    running_loss = 0.0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # 模型预测与误差评分
        pred = model(X).squeeze()  # 去除多余的维度以匹配目标形状
        if y.dim() == 2:  # 如果目标形状是 [batch_size, 1]
            y = y.squeeze()  # 将其转换为 [batch_size]
        loss = loss_fn(pred, y)  # 确保目标形状匹配
        
        # 梯度清零
        optimizer.zero_grad()
        
        # 反向传播与梯度更新
        loss.backward()
        optimizer.step()

        # 记录损失
        running_loss += loss.item()

        # 计算准确率, 二分类和多分类不同
        predicted_labels = (pred > 0.5).float()  # 使用 0.5 作为阈值
        correct += (predicted_labels == y).type(torch.float64).sum().item()

    # 计算平均损失和准确率
    train_loss = running_loss / num_batches
    train_acc = correct / size  

    return train_acc, train_loss

3、设置测试函数

def test(dataloader, model, loss_fn):
    # 总大小
    size = len(dataloader.dataset)
    # 批次数量
    num_batches = len(dataloader)

    # 准确率和损失初始化
    correct = 0
    running_loss = 0.0

    # 将模型设置为评估模式
    model.eval()

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # 模型预测与误差评分
            pred = model(X).squeeze()  # 去除多余的维度以匹配目标形状
            if y.dim() == 2:  # 如果目标形状是 [batch_size, 1]
                y = y.squeeze()  # 将其转换为 [batch_size]
            loss = loss_fn(pred, y)  # 确保目标形状匹配

            # 记录损失
            running_loss += loss.item()

            # 计算准确率
            predicted_labels = (pred > 0.5).float()  # 使用 0.5 作为阈值
            correct += (predicted_labels == y).type(torch.float64).sum().item()

    # 计算平均损失和准确率
    test_loss = running_loss / num_batches
    test_acc = correct / size  # 转换为百分比

    return test_acc, test_loss

4、模型训练

train_acc = []
train_loss = []
test_acc = []
test_loss = []

# 定义训练次数
epoches = 100

for epoch in range(epoches):
    # 训练
    model.train()
    epoch_trian_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    # 测试
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # 记录
    train_acc.append(epoch_trian_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_trian_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))

Epoch: 1, Train_acc:71.0%, Train_loss:0.689, Test_acc:71.0%, Test_loss:0.688
Epoch: 2, Train_acc:75.0%, Train_loss:0.686, Test_acc:71.0%, Test_loss:0.685
Epoch: 3, Train_acc:75.0%, Train_loss:0.682, Test_acc:71.0%, Test_loss:0.682
Epoch: 4, Train_acc:75.0%, Train_loss:0.678, Test_acc:71.0%, Test_loss:0.678
Epoch: 5, Train_acc:75.0%, Train_loss:0.673, Test_acc:71.0%, Test_loss:0.674
Epoch: 6, Train_acc:75.0%, Train_loss:0.669, Test_acc:71.0%, Test_loss:0.670
Epoch: 7, Train_acc:75.0%, Train_loss:0.661, Test_acc:71.0%, Test_loss:0.664
Epoch: 8, Train_acc:75.4%, Train_loss:0.657, Test_acc:67.7%, Test_loss:0.656
Epoch: 9, Train_acc:77.2%, Train_loss:0.644, Test_acc:67.7%, Test_loss:0.647
Epoch:10, Train_acc:77.6%, Train_loss:0.635, Test_acc:71.0%, Test_loss:0.632
Epoch:11, Train_acc:79.0%, Train_loss:0.615, Test_acc:74.2%, Test_loss:0.613
Epoch:12, Train_acc:79.0%, Train_loss:0.592, Test_acc:77.4%, Test_loss:0.585
Epoch:13, Train_acc:80.5%, Train_loss:0.559, Test_acc:77.4%, Test_loss:0.557
Epoch:14, Train_acc:77.9%, Train_loss:0.536, Test_acc:77.4%, Test_loss:0.532
Epoch:15, Train_acc:78.7%, Train_loss:0.508, Test_acc:74.2%, Test_loss:0.520
Epoch:16, Train_acc:77.9%, Train_loss:0.490, Test_acc:77.4%, Test_loss:0.510
Epoch:17, Train_acc:79.4%, Train_loss:0.482, Test_acc:74.2%, Test_loss:0.510
Epoch:18, Train_acc:79.0%, Train_loss:0.459, Test_acc:74.2%, Test_loss:0.505
Epoch:19, Train_acc:80.9%, Train_loss:0.440, Test_acc:74.2%, Test_loss:0.513
Epoch:20, Train_acc:79.8%, Train_loss:0.426, Test_acc:74.2%, Test_loss:0.522
Epoch:21, Train_acc:78.7%, Train_loss:0.424, Test_acc:74.2%, Test_loss:0.529
Epoch:22, Train_acc:77.6%, Train_loss:0.447, Test_acc:71.0%, Test_loss:0.538
Epoch:23, Train_acc:79.0%, Train_loss:0.441, Test_acc:74.2%, Test_loss:0.553
Epoch:24, Train_acc:80.5%, Train_loss:0.400, Test_acc:74.2%, Test_loss:0.517
Epoch:25, Train_acc:80.9%, Train_loss:0.421, Test_acc:74.2%, Test_loss:0.522
Epoch:26, Train_acc:80.1%, Train_loss:0.396, Test_acc:77.4%, Test_loss:0.539
Epoch:27, Train_acc:79.8%, Train_loss:0.393, Test_acc:77.4%, Test_loss:0.525
Epoch:28, Train_acc:81.2%, Train_loss:0.390, Test_acc:77.4%, Test_loss:0.524
Epoch:29, Train_acc:80.1%, Train_loss:0.378, Test_acc:77.4%, Test_loss:0.543
Epoch:30, Train_acc:80.1%, Train_loss:0.384, Test_acc:80.6%, Test_loss:0.521
Epoch:31, Train_acc:82.0%, Train_loss:0.392, Test_acc:77.4%, Test_loss:0.534
Epoch:32, Train_acc:81.6%, Train_loss:0.371, Test_acc:77.4%, Test_loss:0.513
Epoch:33, Train_acc:83.5%, Train_loss:0.376, Test_acc:77.4%, Test_loss:0.526
Epoch:34, Train_acc:81.6%, Train_loss:0.365, Test_acc:80.6%, Test_loss:0.511
Epoch:35, Train_acc:82.0%, Train_loss:0.383, Test_acc:77.4%, Test_loss:0.521
Epoch:36, Train_acc:83.8%, Train_loss:0.362, Test_acc:80.6%, Test_loss:0.513
Epoch:37, Train_acc:83.8%, Train_loss:0.357, Test_acc:80.6%, Test_loss:0.511
Epoch:38, Train_acc:84.2%, Train_loss:0.360, Test_acc:80.6%, Test_loss:0.511
Epoch:39, Train_acc:84.2%, Train_loss:0.354, Test_acc:80.6%, Test_loss:0.503
Epoch:40, Train_acc:84.9%, Train_loss:0.349, Test_acc:80.6%, Test_loss:0.512
Epoch:41, Train_acc:84.6%, Train_loss:0.371, Test_acc:80.6%, Test_loss:0.503
Epoch:42, Train_acc:84.6%, Train_loss:0.338, Test_acc:80.6%, Test_loss:0.510
Epoch:43, Train_acc:83.5%, Train_loss:0.353, Test_acc:80.6%, Test_loss:0.503
Epoch:44, Train_acc:83.8%, Train_loss:0.351, Test_acc:80.6%, Test_loss:0.500
Epoch:45, Train_acc:84.6%, Train_loss:0.339, Test_acc:80.6%, Test_loss:0.505
Epoch:46, Train_acc:85.7%, Train_loss:0.336, Test_acc:80.6%, Test_loss:0.500
Epoch:47, Train_acc:84.6%, Train_loss:0.358, Test_acc:80.6%, Test_loss:0.503
Epoch:48, Train_acc:84.9%, Train_loss:0.337, Test_acc:80.6%, Test_loss:0.513
Epoch:49, Train_acc:86.0%, Train_loss:0.334, Test_acc:80.6%, Test_loss:0.497
Epoch:50, Train_acc:85.3%, Train_loss:0.341, Test_acc:77.4%, Test_loss:0.513
Epoch:51, Train_acc:84.9%, Train_loss:0.337, Test_acc:80.6%, Test_loss:0.498
Epoch:52, Train_acc:84.9%, Train_loss:0.340, Test_acc:80.6%, Test_loss:0.499
Epoch:53, Train_acc:86.4%, Train_loss:0.328, Test_acc:80.6%, Test_loss:0.497
Epoch:54, Train_acc:84.9%, Train_loss:0.331, Test_acc:80.6%, Test_loss:0.502
Epoch:55, Train_acc:84.2%, Train_loss:0.343, Test_acc:77.4%, Test_loss:0.521
Epoch:56, Train_acc:84.6%, Train_loss:0.346, Test_acc:80.6%, Test_loss:0.486
Epoch:57, Train_acc:85.3%, Train_loss:0.351, Test_acc:77.4%, Test_loss:0.506
Epoch:58, Train_acc:85.7%, Train_loss:0.317, Test_acc:80.6%, Test_loss:0.491
Epoch:59, Train_acc:84.9%, Train_loss:0.327, Test_acc:77.4%, Test_loss:0.502
Epoch:60, Train_acc:86.0%, Train_loss:0.321, Test_acc:80.6%, Test_loss:0.503
Epoch:61, Train_acc:87.1%, Train_loss:0.340, Test_acc:80.6%, Test_loss:0.498
Epoch:62, Train_acc:85.3%, Train_loss:0.319, Test_acc:77.4%, Test_loss:0.501
Epoch:63, Train_acc:86.0%, Train_loss:0.317, Test_acc:77.4%, Test_loss:0.503
Epoch:64, Train_acc:86.4%, Train_loss:0.315, Test_acc:80.6%, Test_loss:0.493
Epoch:65, Train_acc:86.0%, Train_loss:0.323, Test_acc:80.6%, Test_loss:0.499
Epoch:66, Train_acc:86.8%, Train_loss:0.322, Test_acc:77.4%, Test_loss:0.518
Epoch:67, Train_acc:87.1%, Train_loss:0.308, Test_acc:80.6%, Test_loss:0.494
Epoch:68, Train_acc:86.8%, Train_loss:0.335, Test_acc:80.6%, Test_loss:0.507
Epoch:69, Train_acc:86.4%, Train_loss:0.307, Test_acc:80.6%, Test_loss:0.499
Epoch:70, Train_acc:86.4%, Train_loss:0.306, Test_acc:80.6%, Test_loss:0.505
Epoch:71, Train_acc:86.0%, Train_loss:0.314, Test_acc:77.4%, Test_loss:0.510
Epoch:72, Train_acc:86.8%, Train_loss:0.315, Test_acc:80.6%, Test_loss:0.495
Epoch:73, Train_acc:86.0%, Train_loss:0.311, Test_acc:77.4%, Test_loss:0.507
Epoch:74, Train_acc:86.8%, Train_loss:0.308, Test_acc:77.4%, Test_loss:0.512
Epoch:75, Train_acc:86.0%, Train_loss:0.316, Test_acc:80.6%, Test_loss:0.497
Epoch:76, Train_acc:85.7%, Train_loss:0.311, Test_acc:80.6%, Test_loss:0.504
Epoch:77, Train_acc:86.8%, Train_loss:0.307, Test_acc:77.4%, Test_loss:0.505
Epoch:78, Train_acc:86.4%, Train_loss:0.303, Test_acc:77.4%, Test_loss:0.508
Epoch:79, Train_acc:87.5%, Train_loss:0.296, Test_acc:80.6%, Test_loss:0.507
Epoch:80, Train_acc:87.1%, Train_loss:0.310, Test_acc:80.6%, Test_loss:0.508
Epoch:81, Train_acc:87.5%, Train_loss:0.297, Test_acc:77.4%, Test_loss:0.503
Epoch:82, Train_acc:87.5%, Train_loss:0.288, Test_acc:77.4%, Test_loss:0.527
Epoch:83, Train_acc:87.1%, Train_loss:0.293, Test_acc:80.6%, Test_loss:0.502
Epoch:84, Train_acc:87.1%, Train_loss:0.295, Test_acc:80.6%, Test_loss:0.508
Epoch:85, Train_acc:87.1%, Train_loss:0.283, Test_acc:80.6%, Test_loss:0.509
Epoch:86, Train_acc:87.1%, Train_loss:0.282, Test_acc:77.4%, Test_loss:0.514
Epoch:87, Train_acc:87.5%, Train_loss:0.278, Test_acc:80.6%, Test_loss:0.511
Epoch:88, Train_acc:87.5%, Train_loss:0.287, Test_acc:80.6%, Test_loss:0.513
Epoch:89, Train_acc:88.6%, Train_loss:0.308, Test_acc:77.4%, Test_loss:0.521
Epoch:90, Train_acc:87.9%, Train_loss:0.296, Test_acc:80.6%, Test_loss:0.512
Epoch:91, Train_acc:87.5%, Train_loss:0.287, Test_acc:77.4%, Test_loss:0.522
Epoch:92, Train_acc:87.5%, Train_loss:0.285, Test_acc:80.6%, Test_loss:0.512
Epoch:93, Train_acc:87.9%, Train_loss:0.287, Test_acc:80.6%, Test_loss:0.512
Epoch:94, Train_acc:88.2%, Train_loss:0.280, Test_acc:77.4%, Test_loss:0.530
Epoch:95, Train_acc:88.6%, Train_loss:0.283, Test_acc:80.6%, Test_loss:0.512
Epoch:96, Train_acc:89.3%, Train_loss:0.280, Test_acc:80.6%, Test_loss:0.516
Epoch:97, Train_acc:87.9%, Train_loss:0.276, Test_acc:77.4%, Test_loss:0.514
Epoch:98, Train_acc:88.6%, Train_loss:0.270, Test_acc:77.4%, Test_loss:0.526
Epoch:99, Train_acc:89.0%, Train_loss:0.269, Test_acc:80.6%, Test_loss:0.517
Epoch:100, Train_acc:88.6%, Train_loss:0.266, Test_acc:80.6%, Test_loss:0.521

5、结果展示

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epoch_length = range(epoches)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')

plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.show()


在这里插入图片描述

测试集表现不是很理想,合理尝试变化不同的批次,会有不同效果

6、模型评估

# 评估:返回的是自己在model.compile中设置,这里为accuracy
test_acc, test_loss = test(test_dl, model, loss_fn)
print("socre[loss, accuracy]: ", test_acc, test_loss) # 返回为两个,一个是loss,一个是accuracy

socre[loss, accuracy]:  0.8064516129032258 0.5212066173553467


网站公告

今日签到

点亮在社区的每一天
去签到