糖尿病数据集的逻辑回归
数据集
pima-indians-diabetes: 一、 数据说明: Pima Indians Diabetes Data Set(皮马印第安人糖尿病数据集) 根据现有的医疗信息预测5年内皮马印第安人糖尿病发作的概率。 数据链接:https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabetes https://gitee.com/biabianm/pima-indians-diabetespima-indians-diabetes: 一、 数据说明: Pima Indians Diabetes Data Set(皮马印第安人糖尿病数据集) 根据现有的医疗信息预测5年内皮马印第安人糖尿病发作的概率。 数据链接:https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabetes
解压后里面有一个pima-indians-diabetes.csv文件 复制粘贴到python代码统一目录下
m_features.py如下
np.loadtxt()
用于从文本加载数据。
loadtxt(fname, dtype=<class 'float'>, comments='#', delimiter=None, converters=None, skiprows=0, usecols=None, unpack=False, ndmin=0)
xy = np.loadtxt('pima-indians-diabetes.csv', delimiter=',', dtype=np.float32,skiprows = 1)
fname要读取的文件、文件名、或生成器。
dtype数据类型,默认float。
comments注释。
delimiter分隔符,默认是空格。
skiprows跳过前几行读取,默认是0,必须是int整型。
usecols要读取哪些列,0是第一列。例如,usecols = (1,4,5)将提取第2,第5和第6列。默认读取所有列。
unpack如果为True,将分列读取。
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# 加载数据集
xy = np.loadtxt('pima-indians-diabetes.csv', delimiter=',', dtype=np.float32,skiprows = 1)
# 数据预处理,包括从数据集里区分输入输出,最后把输入输出数据封装成Pytorch期望的Variable格式
X_train= torch.from_numpy(xy[:,:-1]) # 特征信息
y_train= torch.from_numpy(xy[:,[-1]]) # 目标分类
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1=torch.nn.Linear(8,6)
self.linear2= torch.nn.Linear(6,4)
self.linear3= torch.nn.Linear(4,1)
self.sigmoid=torch.nn.Sigmoid()
def forward(self,x):
x=self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model=Model()
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1)
px,py = [],[] # 记录要绘制的数据
for epoch in range(1000):
y_pred=model(X_train)
loss=criterion(y_pred,y_train)
print(epoch,loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
px.append(epoch)
py.append(loss.item())
plt.plot(px, py)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
# # 每十次迭代绘制训练动态
# if epoch% 10 == 0:
# plt.cla()
# plt.plot(px, py, 'r-', lw=1)
# plt.text(0, 0, 'Loss=%.4f' % loss.item(), fontdict={'size': 20, 'color': 'red'})
# plt.pause(0.1)
本文含有隐藏内容,请 开通VIP 后查看