基于RNN模型的心脏病预测(tensorflow实现)

发布于:2025-02-10 ⋅ 阅读:(74) ⋅ 点赞:(0)

前言

1、数据处理

1、导入库

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 

2、导入数据

data = pd.read_csv('./heart.csv')

data.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63 1 3 145 233 1 0 150 0 2.3 0 0 1 1
1 37 1 2 130 250 0 1 187 0 3.5 0 0 2 1
2 41 0 1 130 204 0 0 172 0 1.4 2 0 2 1
3 56 1 1 120 236 0 1 178 0 0.8 2 0 2 1
4 57 0 0 120 354 0 1 163 1 0.6 2 0 2 1
  • age - 年龄
  • sex - (1 = male(男性); 0 = (女性))
  • cp - chest pain type(胸部疼痛类型)(1:典型的心绞痛-typical,2:非典型心绞痛-atypical,3:没有心绞痛-non-anginal,4:无症状-asymptomatic)
  • trestbps - 静息血压 (in mm Hg on admission to the hospital)
  • chol - 胆固醇 in mg/dl
  • fbs - (空腹血糖 > 120 mg/dl) (1 = true; 0 = false)
  • restecg - 静息心电图测量(0:普通,1:ST-T波异常,2:可能左心室肥大)
  • thalach - 最高心跳率
  • exang - 运动诱发心绞痛 (1 = yes; 0 = no)
  • oldpeak - 运动相对于休息引起的ST抑制
  • slope - 运动ST段的峰值斜率(1:上坡-upsloping,2:平的-flat,3:下坡-downsloping)
  • ca - 主要血管数目(0-4)
  • thal - 一种叫做地中海贫血的血液疾病(3 = normal; 6 = 固定的缺陷-fixed defect; 7 = 可逆的缺陷-reversable defect)
  • target - 是否患病 (1=yes, 0=no)

3、数据分析

数据初步分析
data.info()   # 数据类型分析
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       303 non-null    int64  
 1   sex       303 non-null    int64  
 2   cp        303 non-null    int64  
 3   trestbps  303 non-null    int64  
 4   chol      303 non-null    int64  
 5   fbs       303 non-null    int64  
 6   restecg   303 non-null    int64  
 7   thalach   303 non-null    int64  
 8   exang     303 non-null    int64  
 9   oldpeak   303 non-null    float64
 10  slope     303 non-null    int64  
 11  ca        303 non-null    int64  
 12  thal      303 non-null    int64  
 13  target    303 non-null    int64  
dtypes: float64(1), int64(13)
memory usage: 33.3 KB

其中分类变量为:sex、cp、fbs、restecg、exang、slope、ca、thal、target

数值型变量:age、trestbps、chol、thalach、oldpeak

data.describe()  # 描述性
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000
mean 54.366337 0.683168 0.966997 131.623762 246.264026 0.148515 0.528053 149.646865 0.326733 1.039604 1.399340 0.729373 2.313531 0.544554
std 9.082101 0.466011 1.032052 17.538143 51.830751 0.356198 0.525860 22.905161 0.469794 1.161075 0.616226 1.022606 0.612277 0.498835
min 29.000000 0.000000 0.000000 94.000000 126.000000 0.000000 0.000000 71.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
25% 47.500000 0.000000 0.000000 120.000000 211.000000 0.000000 0.000000 133.500000 0.000000 0.000000 1.000000 0.000000 2.000000 0.000000
50% 55.000000 1.000000 1.000000 130.000000 240.000000 0.000000 1.000000 153.000000 0.000000 0.800000 1.000000 0.000000 2.000000 1.000000
75% 61.000000 1.000000 2.000000 140.000000 274.500000 0.000000 1.000000 166.000000 1.000000 1.600000 2.000000 1.000000 3.000000 1.000000
max 77.000000 1.000000 3.000000 200.000000 564.000000 1.000000 2.000000 202.000000 1.000000 6.200000 2.000000 4.000000 3.000000 1.000000
  • 年纪:均值54,中位数55,标准差9,说明主要是老年人,偏大
  • 静息血压:均值131.62, 成年人一般:正常血压:收缩压 < 120 mmHg,偏大
  • 胆固醇:均值246.26,理想水平:小于 200 mg/dL,偏大
  • 最高心率:均值149.64,一般静息状态下通常是 60 到 100 次每分钟,偏大

最大值和最小值都可能发生,无异常值

缺失值
data.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
相关性分析
import seaborn as sns

plt.figure(figsize=(20, 15))

sns.heatmap(data.corr(), annot=True, cmap='Greens')

plt.show()


在这里插入图片描述

相关系数的等级划分

  • 非常弱的相关性:
    • 0.00 至 0.19 或 -0.00 至 -0.19
    • 解释:几乎不存在线性关系。
  • 弱相关性:
    • 0.20 至 0.39 或 -0.20 至 -0.39
    • 解释:存在一定的线性关系,但较弱。
  • 中等相关性:
    • 0.40 至 0.59 或 -0.40 至 -0.59
    • 解释:有明显的线性关系,但不是特别强。
  • 强相关性:
    • 0.60 至 0.79 或 -0.60 至 -0.79
    • 解释:两个变量之间有较强的线性关系。
  • 非常强的相关性:
    • 0.80 至 1.00 或 -0.80 至 -1.00
    • 解释:几乎完全线性相关,表明两个变量的变化高度一致。

target与chol、没有什么相关性,fbs是分类变量,chol胆固醇是数值型变量,但是从实际角度,这些都有影响,故不剔除特征

4、数据划分

这里先划分为:训练集:测试集 = 9:1

from sklearn.model_selection import train_test_split

X = data.iloc[:, :-1]
y = data.iloc[:, -1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

5、数据标准化

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 深度学习、用rnn模型,数据需要3通道,在图片中表示RGB,这里表示1
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

2、创建模型

  • RNN的API
    • tf.keras.layers.SimpleRNN(units,activation=‘tanh’,use_bias=True,kernel_initializer=‘glorot_uniform’,
      recurrent_initializer=‘orthogonal’,bias_initializer=‘zeros’,kernel_regularizer=None,recurrent_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,recurrent_constraint=None,
      bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,
      go_backwards=False,stateful=False,unroll=False,**kwargs)
  • 参数
    • units: 正整数,表示该层输出空间的维度,即隐藏状态的大小。
    • activation: 激活函数,默认是 ‘tanh’。可以使用其他激活函数如 ‘relu’ 或自定义激活函数。
    • input_shape=(13, 1): 指定输入数据的形状。对于这个 RNN 层,每个样本包含长度为 13 的时间序列,每个时间步有一个特征(即每个时间点的数据维度是 1)。请注意,当你在模型的第一层使用 input_shape 参数时,你不需要指定批量大小(batch size),它默认是 None,意味着批量大小可以是任意值。
    • use_bias: 布尔值,默认为 True,指示是否使用偏置向量。
    • kernel_initializer: 权重矩阵的初始化方法,默认是 ‘glorot_uniform’。
    • recurrent_initializer: 循环核的初始化方法,默认是 ‘orthogonal’。
    • bias_initializer: 偏置向量的初始化方法,默认是 ‘zeros’。
    • kernel_regularizer: 权重矩阵的正则化方法。
    • recurrent_regularizer: 循环核的正则化方法。
    • bias_regularizer: 偏置向量的正则化方法。
    • activity_regularizer: 输出的正则化方法。
    • kernel_constraint: 对权重矩阵施加约束的方法。
    • recurrent_constraint: 对循环核施加约束的方法。
    • bias_constraint: 对偏置向量施加约束的方法。
    • dropout: 浮点数,介于 0 和 1 之间,指输入单元的丢弃比例,默认为 0.0。
    • recurrent_dropout: 浮点数,介于 0 和 1 之间,指循环状态的丢弃比例,默认为 0.0。
    • return_sequences: 布尔值,默认为 False。如果设置为 True,则整个序列会被返回;否则,只返回最后一个输出。
    • return_state: 布尔值,默认为 False。如果设置为 True,则除了输出外还会返回最后一个状态。
    • go_backwards: 布尔值,默认为 False。如果设置为 True,则会反向处理输入序列。
    • stateful: 布尔值,默认为 False。如果设置为 True,则批次间的状态会被保留下来。
    • unroll: 布尔值,默认为 False。如果设置为 True,则网络将被展开。当且仅当输入序列长度有限时适用,可以加速计算但占用更多内存。
import tensorflow as tf  
from tensorflow.keras.models import Sequential 
from tensorflow.keras.layers import SimpleRNN, Dense

# 创建模型
'''
该问题本质是二分类问题,故最后一层全连接层用激活函数为:sigmoid
模型结构:
    RNN:隐藏层200,激活函数:relu
    Dense:--> 100(relu) -> 1(sigmoid)
'''
# 创建模型
model = Sequential()
model.add(SimpleRNN(units=200, input_shape=(13, 1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simple_rnn (SimpleRNN)      (None, 200)               40400     
                                                                 
 dense (Dense)               (None, 100)               20100     
                                                                 
 dense_1 (Dense)             (None, 1)                 101       
                                                                 
=================================================================
Total params: 60,601
Trainable params: 60,601
Non-trainable params: 0
_________________________________________________________________

3、设置超参数

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(
    optimizer=opt,
    loss='binary_crossentropy',  # 二分类问题
    metrics=['accuracy']
)

4、模型训练

epochs = 100

history = model.fit(
    X_train, y_train,
    epochs=epochs,
    batch_size=32,
    validation_data=(X_test, y_test),
    verbose=1
)
Epoch 1/100
9/9 [==============================] - 1s 46ms/step - loss: 0.6821 - accuracy: 0.5551 - val_loss: 0.6679 - val_accuracy: 0.6774
Epoch 2/100
9/9 [==============================] - 0s 8ms/step - loss: 0.6549 - accuracy: 0.7059 - val_loss: 0.6460 - val_accuracy: 0.7097
Epoch 3/100
9/9 [==============================] - 0s 8ms/step - loss: 0.6299 - accuracy: 0.7904 - val_loss: 0.6265 - val_accuracy: 0.7419
Epoch 4/100
9/9 [==============================] - 0s 8ms/step - loss: 0.6051 - accuracy: 0.7978 - val_loss: 0.6062 - val_accuracy: 0.7097
Epoch 5/100
9/9 [==============================] - 0s 8ms/step - loss: 0.5784 - accuracy: 0.8015 - val_loss: 0.5835 - val_accuracy: 0.7097
Epoch 6/100
9/9 [==============================] - 0s 8ms/step - loss: 0.5484 - accuracy: 0.8051 - val_loss: 0.5573 - val_accuracy: 0.7097
Epoch 7/100
9/9 [==============================] - 0s 8ms/step - loss: 0.5103 - accuracy: 0.8125 - val_loss: 0.5264 - val_accuracy: 0.7419
Epoch 8/100
9/9 [==============================] - 0s 8ms/step - loss: 0.4676 - accuracy: 0.8162 - val_loss: 0.5022 - val_accuracy: 0.7742
Epoch 9/100
9/9 [==============================] - 0s 8ms/step - loss: 0.4247 - accuracy: 0.8088 - val_loss: 0.4968 - val_accuracy: 0.8065
Epoch 10/100
9/9 [==============================] - 0s 8ms/step - loss: 0.4020 - accuracy: 0.8088 - val_loss: 0.5068 - val_accuracy: 0.7742
Epoch 11/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3937 - accuracy: 0.8051 - val_loss: 0.5095 - val_accuracy: 0.7742
Epoch 12/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3824 - accuracy: 0.8235 - val_loss: 0.5062 - val_accuracy: 0.7742
Epoch 13/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3706 - accuracy: 0.8162 - val_loss: 0.5138 - val_accuracy: 0.7742
Epoch 14/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3667 - accuracy: 0.8199 - val_loss: 0.5076 - val_accuracy: 0.7742
Epoch 15/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3528 - accuracy: 0.8346 - val_loss: 0.5169 - val_accuracy: 0.7742
Epoch 16/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3472 - accuracy: 0.8272 - val_loss: 0.5167 - val_accuracy: 0.7742
Epoch 17/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3414 - accuracy: 0.8493 - val_loss: 0.5150 - val_accuracy: 0.7742
Epoch 18/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3462 - accuracy: 0.8235 - val_loss: 0.5171 - val_accuracy: 0.7742
Epoch 19/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3344 - accuracy: 0.8566 - val_loss: 0.5133 - val_accuracy: 0.7742
Epoch 20/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3226 - accuracy: 0.8529 - val_loss: 0.5268 - val_accuracy: 0.8065
Epoch 21/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3192 - accuracy: 0.8566 - val_loss: 0.5237 - val_accuracy: 0.7742
Epoch 22/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3127 - accuracy: 0.8676 - val_loss: 0.5270 - val_accuracy: 0.7742
Epoch 23/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3071 - accuracy: 0.8640 - val_loss: 0.5354 - val_accuracy: 0.8065
Epoch 24/100
9/9 [==============================] - 0s 8ms/step - loss: 0.3029 - accuracy: 0.8713 - val_loss: 0.5337 - val_accuracy: 0.8065
Epoch 25/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2931 - accuracy: 0.8824 - val_loss: 0.5310 - val_accuracy: 0.8065
Epoch 26/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2906 - accuracy: 0.8897 - val_loss: 0.5291 - val_accuracy: 0.8065
Epoch 27/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2833 - accuracy: 0.8934 - val_loss: 0.5333 - val_accuracy: 0.8065
Epoch 28/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2777 - accuracy: 0.8897 - val_loss: 0.5417 - val_accuracy: 0.8065
Epoch 29/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2725 - accuracy: 0.8897 - val_loss: 0.5342 - val_accuracy: 0.8065
Epoch 30/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2696 - accuracy: 0.9044 - val_loss: 0.5417 - val_accuracy: 0.8065
Epoch 31/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2626 - accuracy: 0.8897 - val_loss: 0.5420 - val_accuracy: 0.8065
Epoch 32/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2552 - accuracy: 0.9007 - val_loss: 0.5424 - val_accuracy: 0.8065
Epoch 33/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2506 - accuracy: 0.9044 - val_loss: 0.5456 - val_accuracy: 0.8065
Epoch 34/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2482 - accuracy: 0.9044 - val_loss: 0.5500 - val_accuracy: 0.8065
Epoch 35/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2437 - accuracy: 0.9044 - val_loss: 0.5552 - val_accuracy: 0.8065
Epoch 36/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2425 - accuracy: 0.9191 - val_loss: 0.5511 - val_accuracy: 0.8065
Epoch 37/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2383 - accuracy: 0.9081 - val_loss: 0.5523 - val_accuracy: 0.8065
Epoch 38/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2253 - accuracy: 0.9191 - val_loss: 0.5765 - val_accuracy: 0.8065
Epoch 39/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2265 - accuracy: 0.9191 - val_loss: 0.5664 - val_accuracy: 0.8065
Epoch 40/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2197 - accuracy: 0.9265 - val_loss: 0.5732 - val_accuracy: 0.8065
Epoch 41/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2151 - accuracy: 0.9338 - val_loss: 0.5716 - val_accuracy: 0.8065
Epoch 42/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2199 - accuracy: 0.9154 - val_loss: 0.5718 - val_accuracy: 0.8065
Epoch 43/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2208 - accuracy: 0.9338 - val_loss: 0.5830 - val_accuracy: 0.8065
Epoch 44/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2010 - accuracy: 0.9412 - val_loss: 0.5761 - val_accuracy: 0.8065
Epoch 45/100
9/9 [==============================] - 0s 8ms/step - loss: 0.2038 - accuracy: 0.9265 - val_loss: 0.5897 - val_accuracy: 0.8065
Epoch 46/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1971 - accuracy: 0.9412 - val_loss: 0.5865 - val_accuracy: 0.8065
Epoch 47/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1941 - accuracy: 0.9412 - val_loss: 0.5939 - val_accuracy: 0.8065
Epoch 48/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1917 - accuracy: 0.9375 - val_loss: 0.5984 - val_accuracy: 0.8065
Epoch 49/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1890 - accuracy: 0.9449 - val_loss: 0.5874 - val_accuracy: 0.8065
Epoch 50/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1880 - accuracy: 0.9449 - val_loss: 0.5964 - val_accuracy: 0.8065
Epoch 51/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1799 - accuracy: 0.9485 - val_loss: 0.6004 - val_accuracy: 0.8065
Epoch 52/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1762 - accuracy: 0.9449 - val_loss: 0.6068 - val_accuracy: 0.8065
Epoch 53/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1719 - accuracy: 0.9485 - val_loss: 0.6046 - val_accuracy: 0.8065
Epoch 54/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1720 - accuracy: 0.9522 - val_loss: 0.6117 - val_accuracy: 0.8065
Epoch 55/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1691 - accuracy: 0.9485 - val_loss: 0.6201 - val_accuracy: 0.8065
Epoch 56/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1614 - accuracy: 0.9522 - val_loss: 0.6132 - val_accuracy: 0.8387
Epoch 57/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1597 - accuracy: 0.9559 - val_loss: 0.6295 - val_accuracy: 0.8065
Epoch 58/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1583 - accuracy: 0.9559 - val_loss: 0.6428 - val_accuracy: 0.8065
Epoch 59/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1542 - accuracy: 0.9632 - val_loss: 0.6260 - val_accuracy: 0.8387
Epoch 60/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1515 - accuracy: 0.9559 - val_loss: 0.6527 - val_accuracy: 0.8065
Epoch 61/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1495 - accuracy: 0.9596 - val_loss: 0.6550 - val_accuracy: 0.8065
Epoch 62/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1468 - accuracy: 0.9559 - val_loss: 0.6562 - val_accuracy: 0.8065
Epoch 63/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1452 - accuracy: 0.9596 - val_loss: 0.6574 - val_accuracy: 0.8387
Epoch 64/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1463 - accuracy: 0.9522 - val_loss: 0.6606 - val_accuracy: 0.8065
Epoch 65/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1406 - accuracy: 0.9632 - val_loss: 0.6614 - val_accuracy: 0.8387
Epoch 66/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1322 - accuracy: 0.9706 - val_loss: 0.6803 - val_accuracy: 0.8065
Epoch 67/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1306 - accuracy: 0.9669 - val_loss: 0.6647 - val_accuracy: 0.8387
Epoch 68/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1239 - accuracy: 0.9706 - val_loss: 0.6856 - val_accuracy: 0.8387
Epoch 69/100
9/9 [==============================] - 0s 10ms/step - loss: 0.1195 - accuracy: 0.9743 - val_loss: 0.6805 - val_accuracy: 0.8387
Epoch 70/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1164 - accuracy: 0.9743 - val_loss: 0.7036 - val_accuracy: 0.8387
Epoch 71/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1154 - accuracy: 0.9706 - val_loss: 0.7068 - val_accuracy: 0.8387
Epoch 72/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1107 - accuracy: 0.9706 - val_loss: 0.7011 - val_accuracy: 0.8387
Epoch 73/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1081 - accuracy: 0.9706 - val_loss: 0.7218 - val_accuracy: 0.8387
Epoch 74/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1031 - accuracy: 0.9706 - val_loss: 0.7341 - val_accuracy: 0.8387
Epoch 75/100
9/9 [==============================] - 0s 8ms/step - loss: 0.1045 - accuracy: 0.9706 - val_loss: 0.7233 - val_accuracy: 0.8387
Epoch 76/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0986 - accuracy: 0.9669 - val_loss: 0.7459 - val_accuracy: 0.8387
Epoch 77/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0955 - accuracy: 0.9743 - val_loss: 0.7471 - val_accuracy: 0.8387
Epoch 78/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0900 - accuracy: 0.9743 - val_loss: 0.7459 - val_accuracy: 0.8387
Epoch 79/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0916 - accuracy: 0.9743 - val_loss: 0.7714 - val_accuracy: 0.8387
Epoch 80/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0845 - accuracy: 0.9743 - val_loss: 0.7712 - val_accuracy: 0.8387
Epoch 81/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0817 - accuracy: 0.9743 - val_loss: 0.7707 - val_accuracy: 0.8387
Epoch 82/100
9/9 [==============================] - 0s 10ms/step - loss: 0.0827 - accuracy: 0.9779 - val_loss: 0.7993 - val_accuracy: 0.8387
Epoch 83/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0750 - accuracy: 0.9779 - val_loss: 0.7947 - val_accuracy: 0.8387
Epoch 84/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0738 - accuracy: 0.9743 - val_loss: 0.8213 - val_accuracy: 0.8387
Epoch 85/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0713 - accuracy: 0.9779 - val_loss: 0.8187 - val_accuracy: 0.8387
Epoch 86/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0670 - accuracy: 0.9816 - val_loss: 0.8190 - val_accuracy: 0.8387
Epoch 87/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0643 - accuracy: 0.9816 - val_loss: 0.8394 - val_accuracy: 0.8387
Epoch 88/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0623 - accuracy: 0.9816 - val_loss: 0.8506 - val_accuracy: 0.8387
Epoch 89/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0569 - accuracy: 0.9890 - val_loss: 0.8615 - val_accuracy: 0.8387
Epoch 90/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0551 - accuracy: 0.9890 - val_loss: 0.8653 - val_accuracy: 0.8387
Epoch 91/100
9/9 [==============================] - 0s 12ms/step - loss: 0.0518 - accuracy: 0.9890 - val_loss: 0.8789 - val_accuracy: 0.8387
Epoch 92/100
9/9 [==============================] - 0s 11ms/step - loss: 0.0506 - accuracy: 0.9890 - val_loss: 0.8979 - val_accuracy: 0.8387
Epoch 93/100
9/9 [==============================] - 0s 10ms/step - loss: 0.0475 - accuracy: 0.9853 - val_loss: 0.9083 - val_accuracy: 0.8387
Epoch 94/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0458 - accuracy: 0.9926 - val_loss: 0.8964 - val_accuracy: 0.8387
Epoch 95/100
9/9 [==============================] - 0s 9ms/step - loss: 0.0430 - accuracy: 0.9926 - val_loss: 0.9234 - val_accuracy: 0.8387
Epoch 96/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0422 - accuracy: 0.9926 - val_loss: 0.9358 - val_accuracy: 0.8387
Epoch 97/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0390 - accuracy: 0.9890 - val_loss: 0.9299 - val_accuracy: 0.8387
Epoch 98/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0367 - accuracy: 0.9926 - val_loss: 0.9745 - val_accuracy: 0.8387
Epoch 99/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0348 - accuracy: 0.9926 - val_loss: 0.9798 - val_accuracy: 0.8387
Epoch 100/100
9/9 [==============================] - 0s 8ms/step - loss: 0.0343 - accuracy: 0.9963 - val_loss: 0.9618 - val_accuracy: 0.8387

5、结果展示

train_acc = history.history['accuracy']
train_loss = history.history['loss']

test_acc = history.history['val_accuracy']
test_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Train_acc')
plt.plot(epochs_range, test_acc, label='Test_acc')
plt.legend(loc='lower right')
plt.title("Accuracy")

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Train_loss')
plt.plot(epochs_range, test_loss, label='Test_loss')
plt.legend(loc='upper right')
plt.title("Loss")

plt.show()


在这里插入图片描述

6、模型评估

# 评估:返回的是自己在model.compile中设置,这里为accuracy
score = model.evaluate(X_test, y_test, verbose=0)
print("socre[loss, accuracy]: ", score) # 返回为两个,一个是loss,一个是accuracy
socre[loss, accuracy]:  [0.9617615938186646, 0.8387096524238586]

网站公告

今日签到

点亮在社区的每一天
去签到