PyTorch深度学习实战(7)—— 线性回归

发布于:2024-08-10 ⋅ 阅读:(104) ⋅ 点赞:(0)

线性回归是机器学习的入门内容,应用十分广泛。线性回归利用数理统计中的回归分析来确定两种或两种以上变量间相互依赖的定量关系,其表达形式为$y = wx+b+e$。其中,$x$和$y$是输入输出数据,$w$和$b$是可学习参数,误差$e$服从均值为0的正态分布。线性回归的损失函数如式(3.2)所示。

本节利用随机梯度下降法更新参数$w$ 和$b$ 来最小化损失函数,最终学得$w$ 和$b$ 的数值。

In: import torch as t
    %matplotlib inline
    from matplotlib import pyplot as plt
    from IPython import display
    
    device = t.device('cpu') #如果使用GPU,则改成t.device('cuda:0')

In: # 设置随机数种子,保证在不同机器上运行时下面的输出一致
    t.manual_seed(2021) 
    
    def get_fake_data(batch_size=8):
        ''' 产生随机数据:y=2x+3,加上了一些噪声'''
        x = t.rand(batch_size, 1, device=device) * 5
        y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)
        return x, y

In: # 来看看产生的x-y分布
    x, y = get_fake_data(batch_size=16)
    plt.scatter(x.squeeze().cpu().numpy(), y.squeeze().cpu().numpy())
 
 Out:<matplotlib.collections.PathCollection at 0x7fcd24179c88>

In: # 随机初始化参数
    w = t.rand(1, 1).to(device)
    b = t.zeros(1, 1).to(device)
    
    lr = 0.02 # 学习率learning rate
    
    for ii in range(500):
        x, y = get_fake_data(batch_size=4)
        
        # forward:计算loss
        y_pred = x.mm(w) + b.expand_as(y) # expand_as用到了广播法则
        loss = 0.5 * (y_pred - y) ** 2 # 均方误差
        loss = loss.mean()
        
        # backward:手动计算梯度
        dloss = 1
        dy_pred = dloss * (y_pred - y)
        
        dw = x.t().mm(dy_pred)
        db = dy_pred.sum()
        
        # 更新参数
        w.sub_(lr * dw) # inplace函数
        b.sub_(lr * db)
        
        if ii % 50 == 0:
            # 画图
            display.clear_output(wait=True)
            x = t.arange(0, 6).float().view(-1, 1)
            y = x.mm(w) + b.expand_as(x)
            plt.plot(x.cpu().numpy(), y.cpu().numpy()) # 线性回归的结果
            
            x2, y2 = get_fake_data(batch_size=32) 
            plt.scatter(x2.numpy(), y2.numpy()) # 真实的数据
            
            plt.xlim(0, 5)
            plt.ylim(0, 13)
            plt.show()
            plt.pause(0.5)
            
    print(f'w: {w.item():.3f}, b: {b.item():.3f}')

Out:w: 1.911 b: 3.044

可见程序已经基本学出$w=2$、$b=3$,并且图中直线和数据已经实现较好的拟合。

上面提到了Tensor的许多操作,这里不要求读者全部掌握,今后使用时可以再查阅这部分内容或者查阅官方文档,在此读者只需有个基本印象即可。


网站公告

今日签到

点亮在社区的每一天
去签到