Class9简洁实现

发布于:2025-07-18 ⋅ 阅读:(11) ⋅ 点赞:(0)

Class9简洁实现

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
# 初始化训练样本、测试样本、样本特征维度和批量大小
n_train,n_test,num_inputs,batch_size = 20,100,200,5
# 设置真实权重和偏置
true_w,true_b = torch.ones((num_inputs,1)) * 0.01,0.05
# 生成训练数据
# d2l.synthetic_data():函数生成模拟的训练数据
# synthetic_data()L返回三元组(features,labels)
train_data = d2l.synthetic_data(true_w,true_b,n_train)
# 数据封装为训练数据迭代器
# d2l.load_array():把数据打包成一个笑屁刘昂迭代器,便于后续训练
# batch_size=5:每次迭代返回5个样本
train_iter = d2l.load_array(train_data,batch_size)
# 生成测试数据
test_data = d2l.synthetic_data(true_w,true_b,n_test)
# 数据封装为测试数据迭代器
test_iter = d2l.load_array(test_data,batch_size,is_train=False)
# 实现带权重衰减(L2正则)线性回归模型训练
# wd:L2正则化系数lambd
def train_concise(wd):
    # 构建一个全连接层,输入为num_inputs,输出为1
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        # 将参数用正态分布随机初始化
        param.data.normal_()
    # 样本的均方误差不求平均
    loss = nn.MSELoss(reduction='none')
    # 定义训练轮数和学习率
    num_epochs, lr = 100, 0.003
    # 使用随机梯度下降优化器
    trainer = torch.optim.SGD([
        # 权重参数,应用L2正则
        {"params":net[0].weight,'weight_decay': wd},
        # 偏置参数,不加正则
        {"params":net[0].bias}], lr=lr)
    # 定义可视化工具
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    # 循环训练
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 清空梯度,防止梯度累加
            trainer.zero_grad()
            # 计算每个样本的MSELoss
            l = loss(net(X), y)
            # 进行反向传播
            l.mean().backward()
            # 更新模型参数
            trainer.step()
        # 每5轮评估训练集和测试集的loss损失函数
        if (epoch + 1) % 5 == 0:
            # 将当前loss加入到动态图中
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    # 打印输出L2范数
    print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
train_concise(3)

网站公告

今日签到

点亮在社区的每一天
去签到