@浙大疏锦行https://blog.csdn.net/weixin_45655710
知识点回顾:
- 彩色和灰度图片测试和训练的规范写法:封装在函数中
- 展平操作:除第一个维度batchsize外全部展平
- dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout
作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。
今天代码中训练 (train
) 和测试 (test
) 函数的规范写法。这套代码框架是PyTorch深度学习项目的基石,理解了它,未来接触任何复杂的模型和任务,其核心训练逻辑都是万变不离其宗的。
一个通俗的比喻来理解:把整个过程想象成“学生(模型)备考(训练)和参加期末考(测试)”。
核心框架概览
我们的代码主要由两个核心部分组成:
train
函数:学生在一个学期内(一个epoch
)反复做练习题(train_loader
中的数据)的过程。test
函数:学期结束后,用一套全新的模拟卷(test_loader
中的数据)来检验学生的真实水平。
主程序 (if __name__ == "__main__":
) 则扮演**“教务处”**的角色,负责安排学期总数 (epochs
),并协调“训练”和“测试”的进行。
一、 train
函数解析:学生的学习过程
def train(model, train_loader, ...)
这个函数的目标是让模型 (model
) 通过学习训练数据 (train_loader
) 来不断更新自己的知识(权重参数)。
它的内部逻辑可以分为两层循环:
外层循环:for epoch in range(epochs):
(一个学期)
epoch
代表一个完整的学习周期,我们称之为“轮次”。在一轮中,学生(模型)会把所有的练习册(整个train_dataset
)从头到尾做一遍。model.train()
:在每个学期开始时,学生要告诉自己:“现在是学习时间!” 这会开启一些只在学习时才用的“超能力”,比如 Dropout(为了防止死记硬背而故意忘掉一些东西)和 BatchNorm(一种让学习更稳定的技巧)。
内层循环:for batch_idx, (data, target) in enumerate(train_loader):
(做一页练习题)
train_loader
像是一本很厚的练习册,它被分成了很多页,每一页就是一批 (batch
) 数据。- 这个循环就是学生一页一页地做练习题的过程。
data
是这一页的题目(图像),target
是标准答案(标签)。 data, target = data.to(device), target.to(device)
: 把这一页练习题和答案都拿到“大脑”(GPU)里去处理,速度更快。
做一页练习题的核心四步曲:
optimizer.zero_grad()
(清空草稿纸):在做新一页题前,先把上一页的计算草稿(梯度)擦干净。PyTorch默认会累积梯度,所以每次都必须手动清零。output = model(data)
(做题):学生(模型)根据自己当前的知识水平,对这页的题目(data
)给出自己的答案(output
)。loss = criterion(output, target)
(对答案并计算差距):criterion
(损失函数) 就像一个评分老师,它会比较学生的答案(output
)和标准答案(target
),然后计算出一个差距值(loss
)。差距越大,loss
值也越大。loss.backward()
(反思总结):这是最神奇的一步,也叫反向传播。学生根据差距 (loss
),反思自己知识体系里的每一个知识点(模型参数)对这次做错题的“责任”有多大。这个“责任”就是梯度。optimizer.step()
(修正知识):optimizer
(优化器) 像一个学习方法指导老师,它根据每个知识点的“责任”(梯度),告诉学生该如何去调整、更新自己的知识(模型参数),以便下次能做得更好。
二、 test
函数解析:学生的期末考试
def test(model, test_loader, ...)
这个函数的目标是检验模型在从未见过的新数据上的表现,以评估其真实的泛化能力。
它的核心逻辑如下:
model.eval()
(进入考试模式):在考试前,学生要告诉自己:“现在是考试时间!” 这会关闭那些只在学习时才用的“超能力”(如Dropout和BatchNorm),确保每次考试的结果都是稳定、一致的。这是至关重要的一步。with torch.no_grad():
(收起草稿纸,只答题不学习):这个代码块告诉PyTorch:“接下来只进行计算,不需要记录任何‘反思过程’(梯度)”。这能大大加快计算速度,并节省显存,因为考试时不需要再学习了。- 循环与计算:
- 它会遍历测试题库 (
test_loader
) 中的每一批数据。 output = model(data)
(做题):学生用自己最终学到的知识来解答这些全新的题目。correct += ...
(计分):将学生的答案与标准答案进行比较,统计做对的总题数。
- 它会遍历测试题库 (
- 返回结果:最终计算出总的平均损失和准确率,作为这次期末考的最终成绩。
总结:一条清晰的逻辑线
将整个流程想象成一个高度自动化、目标明确的“智能教育系统”:
- 数据准备 (
DataLoader
):系统将海量的练习题和模拟卷整理成册,分门别类。 - 模型定义 (
nn.Module
):我们设计了一个“学生”的大脑结构。
主流程(if __name__ == "__main__":
) :- “教务处”宣布:“本学期共
epochs
轮学习!” - 进入每一轮学习 (
for epoch in ...
):- 首先,命令学生(模型)进入学习状态 (
model.train()
),并开始做一整本练习册 (train
函数)。 - 做完练习册后,为了检验本轮学习效果,立刻命令学生进入考试状态 (
model.eval()
),做一套期末模拟卷 (test
函数),并公布成绩。
- 首先,命令学生(模型)进入学习状态 (
- 所有学期结束后,整个培养计划完成。
- “教务处”宣布:“本学期共
这个“训练一轮,测试一轮”的循环框架,是深度学习项目中最核心、最通用的代码结构。掌握了它,就掌握了驱动所有复杂模型进行学习和评估的“引擎”。