Combining warm-up with CosineAnnealingLR

Published: 2024-05-14

When combining warm-up with a cosine annealing scheduler, the initial learning rate of the warm-up phase and the maximum learning rate of the cosine annealing scheduler do not have to be the same; in general, the two can be set to different values.

In practice, you can set the learning rates of the two phases according to your task so that training converges well and reaches good performance. Typically the warm-up phase starts from a relatively low learning rate, which helps the model update its parameters more stably at the beginning of training, while the maximum learning rate of the cosine annealing phase can be set higher to optimize the model more aggressively afterwards. The example below shows how, in PyTorch, the warm-up learning rate and the maximum learning rate of the cosine annealing phase can be set to different values:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# Define the model and optimizer. YourModel, train_loader and loss_function
# are placeholders for your own model, data loader and loss.
model = YourModel()
warmup_lr = 0.01        # learning rate at the start of warm-up
cosine_max_lr = 0.2     # maximum learning rate of the cosine annealing phase
warmup_epochs = 5
cosine_epochs = 50
total_epochs = warmup_epochs + cosine_epochs

# The optimizer starts at the cosine-phase maximum; CosineAnnealingLR anneals
# from this value down to eta_min over cosine_epochs epochs.
optimizer = optim.SGD(model.parameters(), lr=cosine_max_lr)
scheduler = CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=0)

# Training loop
for epoch in range(total_epochs):
    # Update the learning rate
    if epoch < warmup_epochs:
        # Linear warm-up from warmup_lr towards cosine_max_lr
        new_lr = warmup_lr + (cosine_max_lr - warmup_lr) * (epoch / warmup_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
    else:
        # Cosine annealing phase
        scheduler.step()

    # Training code
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
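
If you are on a recent PyTorch version (1.12 or later, where SequentialLR is available), the same warm-up-then-cosine schedule can also be expressed with the built-in LinearLR and SequentialLR schedulers instead of the manual loop above. A minimal sketch, with YourModel again standing in as a placeholder:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = YourModel()  # placeholder, as above
warmup_lr = 0.01
cosine_max_lr = 0.2
warmup_epochs = 5
cosine_epochs = 50

# The optimizer's lr is the peak (cosine-phase maximum) learning rate.
optimizer = optim.SGD(model.parameters(), lr=cosine_max_lr)

# Linear warm-up: ramp from warmup_lr up to cosine_max_lr over warmup_epochs.
warmup = LinearLR(optimizer, start_factor=warmup_lr / cosine_max_lr,
                  end_factor=1.0, total_iters=warmup_epochs)
# Cosine annealing from cosine_max_lr down to 0 over the remaining epochs.
cosine = CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=0)
# Chain the two: switch from warm-up to cosine after warmup_epochs steps.
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine],
                         milestones=[warmup_epochs])

for epoch in range(warmup_epochs + cosine_epochs):
    # ... one epoch of training with optimizer.step() per batch ...
    scheduler.step()  # one scheduler step per epoch covers both phases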

The following script plots the learning-rate curve of various built-in schedulers (one at a time):

import torch
from torch.optim.lr_scheduler import (LambdaLR, StepLR, MultiStepLR, ExponentialLR,
                                      CosineAnnealingLR, CyclicLR,
                                      CosineAnnealingWarmRestarts)
from torchvision.models import resnet50
import matplotlib.pyplot as plt
# from lr_scheduler.scheduler import GradualWarmupScheduler


model = resnet50(weights=None)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

# A selection of built-in schedulers; swap the one used in the loop below
# to plot its curve.
scheduler1 = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)   # multiplicative decay per epoch
scheduler2 = StepLR(optimizer, step_size=10, gamma=0.1)                   # drop by 10x every 10 epochs
scheduler3 = MultiStepLR(optimizer, milestones=[5, 10, 15, 20, 25], gamma=0.1)  # drop at fixed milestones
scheduler4 = ExponentialLR(optimizer, gamma=0.8)                          # exponential decay
scheduler5 = CosineAnnealingLR(optimizer, T_max=5, eta_min=0.05)          # cosine annealing
scheduler6 = CyclicLR(optimizer, base_lr=0.01, max_lr=0.2, step_size_up=10, step_size_down=5)  # triangular cycles
scheduler7 = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=0.01)  # cosine with warm restarts
# scheduler8 = GradualWarmupScheduler(optimizer, 1, 5, scheduler2)

plt.figure()
max_epoch = 30
cur_lr_list = []
for epoch in range(max_epoch):
    optimizer.step()        # step the optimizer before the scheduler
    scheduler5.step()       # replace with any of the schedulers above to compare curves
    cur_lr = optimizer.param_groups[-1]['lr']
    cur_lr_list.append(cur_lr)
    print('Current lr:', cur_lr)
x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
plt.savefig('cosine_annealing_lr.png')  # save before show(), otherwise the saved figure is empty
plt.show()
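
To see the combined warm-up-plus-cosine curve from the first example on a plot like the one above, here is a small self-contained sketch using LambdaLR; the multiplier function, the tiny Linear stand-in model, and the output filename are illustrative choices, not part of the original post, though the warm-up/peak values mirror the ones used earlier.

import math
import torch
from torch.optim.lr_scheduler import LambdaLR
import matplotlib.pyplot as plt

warmup_epochs, cosine_epochs = 5, 50
base_lr, warmup_start_factor = 0.2, 0.05   # peak lr 0.2, warm-up starts at 0.2 * 0.05 = 0.01

def warmup_cosine(epoch):
    """LR multiplier: linear warm-up to 1.0, then cosine decay towards 0."""
    if epoch < warmup_epochs:
        return warmup_start_factor + (1 - warmup_start_factor) * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / cosine_epochs
    return 0.5 * (1 + math.cos(math.pi * progress))

# A tiny stand-in model; any parameters will do for plotting the schedule.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)

lrs = []
for epoch in range(warmup_epochs + cosine_epochs):
    optimizer.step()
    lrs.append(optimizer.param_groups[0]['lr'])  # record the lr used for this epoch
    scheduler.step()

plt.plot(range(len(lrs)), lrs)
plt.xlabel('epoch')
plt.ylabel('learning rate')
plt.savefig('warmup_cosine_lr.png')   # filename is arbitrary
plt.show()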