Pytorch学习笔记

发布于:2022-07-26 ⋅ 阅读:(307) ⋅ 点赞:(0)

记录一下自己的学习笔记

看了B站刘二大人的pytorch视频,想复现一下代码来加深对深度学习的理解,以下是代码和注解,以及测试结果

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]  #创建两个测试数据集列表
y_data = [2.0, 4.0, 6.0]

'''w是我们要通过测试找出来的那个误差最小时w的值'''
def forword(x):
    return x * w #需要下面for循环来定义,并赋值给w

'''计算MSE(平均平方误差)'''
def loss(x,y):
    y_pred = forword(x)
    return (y - y_pred)*(y - y_pred) #或者可以写成(y-y_pred)**2

'''training'''
w_list = []  #创建两个空列表,用来存放模拟数据集
mse_list = []
'''遍历'''
'''for循环中,同一个循环中的语句的前端要对齐,当有两个嵌套循环时,注意各语句所属位置'''
for w in np.arange(0.0, 4.0, 0.1): #以0.0为初始值循环赋值给w,每次增加0.1,终止值为4.0
    print('w=', w) #输出w的值
    l_sum = 0
    for x_val,y_val in zip(x_data,y_data): #通过zip函数使得数据集能够拼成所测试函数
        y_pred_val = forword(x_val) #计算预算值并打印输出
        loss_val = loss(x_val,y_val)   #loss = (x_val - y_pred_val)**2
        l_sum += loss_val
        print('\t', x_val,y_val,y_pred_val,loss_val)
    print('mse', l_sum/3)  #由于只有三个测试数据,所以平均平方误差要除以三
    # matplotlib实现可视化
    w_list.append(w)  #空列表添加数据用append
    mse_list.append(l_sum/3)

plt.plot(w_list,mse_list)
plt.xlabel('w')
plt.ylabel('loss')
plt.show()

![这是在知道w为2的时候得到的数据模型为线性函数y=w*x,
一般数据模型需要我们自己去推导(https://img-blog.csdnimg.cn/c7c348a053e1478ea10ae7524575436b.png#pic_center)














本文含有隐藏内容,请 开通VIP 后查看

网站公告

今日签到

点亮在社区的每一天
去签到