RT-DETR代码详解(官方pytorch版)——模型加载(2)

发布于:2025-02-10 ⋅ 阅读:(104) ⋅ 点赞:(0)

 概述

这篇博客主要是找到在RT-DETR中,模型和数据集是怎么传入train_ine_epoch中进行训练的

一、train.py

二、solver/__init__.py文件

 在train.py的头文件中from src.solver import TASKS,TASKS不是文件,可以看到左侧有init.py文件。

Python中的__init__.py文件作用-CSDN博客

在init.py文件中,TASKS是一个字典类型变量,使用了 Python 的 类型注解,通过 Dict[str, BaseSolver] 表明 TASKS 是一个字典类型。

  • Dict 表示这是一个字典,其中:
    • 键(key)是 字符串类型str)。这里表示某个任务名称,如"detection"
    • 值(value)是 BaseSolver 类型。(BaseSolever是一个父类,DetSolver是其子类)

用于将任务名称(如 'detection')映射到对应的求解器类。

在train.py文件中,

solver = TASKS[cfg.yaml_cfg['task']](cfg)

cfg.yaml_cfg['task'] 用于获取任务类型,cfg.yaml_cfg 一般内容如下:

{
    "task": "detection",
    "learning_rate": 0.001,
    "batch_size": 32
}

 (可看RT-DETR代码详解(官方pytorch版)——参数配置(1)-CSDN博客

  • 这里就相当于cfg.yaml_cfg['task'] 返回 'detection'

  • TASKS[cfg.yaml_cfg['task']] 等价于 TASKS['detection']

  • 结果是 DetectionSolver(一个类)

  • DetectionSolver(cfg) 会调用类的构造方法(__init__ 方法),并传入参数 cfg

 三、solver/solver.py/BaseSolver

3.1 __init__初始化

def __init__(self, cfg: BaseConfig) -> None:
    self.cfg = cfg
  • 作用:初始化 BaseSolver 实例,接收一个配置对象 cfg,该配置对象通常包含训练所需的所有配置信息(如设备、优化器、数据加载器等)。

3.2 setup方法

def setup(self, ):
    '''Avoid instantiating unnecessary classes 
    '''
    # 配置设备和属性
    cfg = self.cfg
    device = cfg.device
    self.device = device
    self.last_epoch = cfg.last_epoch

    # 初始化模型、损失函数、后处理器
    self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)
    self.criterion = cfg.criterion.to(device)
    self.postprocessor = cfg.postprocessor

    # 加载调优状态(如果有)
    if self.cfg.tuning:
        print(f'Tuning checkpoint from {self.cfg.tuning}')
        self.load_tuning_state(self.cfg.tuning)

    # 初始化混合精度、EMA、输出目录
    self.scaler = cfg.scaler
    self.ema = cfg.ema.to(device) if cfg.ema is not None else None 

    self.output_dir = Path(cfg.output_dir)
    self.output_dir.mkdir(parents=True, exist_ok=True)

 

作用

  • 配置模型、损失函数、设备、后处理器等。
  • 支持 混合精度训练scaler)和 EMA机制(Exponential Moving Average)
  • 支持 Fine-Tuning(如果有预训练模型)。
  • 创建输出目录,用于保存模型的状态。

设计思想

  • 避免在每个方法中重复配置,使用 setup 方法集中进行初始化。
  • 支持分布式训练(dist.warp_model 和 dist.warp_loader)。

3.3 train方法

def train(self, ):
    self.setup()
    self.optimizer = self.cfg.optimizer
    self.lr_scheduler = self.cfg.lr_scheduler

    # 加载断点
    if self.cfg.resume:
        print(f'Resume checkpoint from {self.cfg.resume}')
        self.resume(self.cfg.resume)

    # 数据加载器
    self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, shuffle=self.cfg.train_dataloader.shuffle)
    self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, shuffle=self.cfg.val_dataloader.shuffle)

作用

  • 调用 setup() 方法完成初始化。
  • 配置优化器和学习率调度器。
  • 支持从断点恢复训练。
  • 配置训练和验证数据加载器,并支持分布式。

设计思想

  • 提供训练流程的模板,确保训练前的各种组件正确配置。

 3.4 eval方法

def eval(self, ):
    self.setup()
    self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, shuffle=self.cfg.val_dataloader.shuffle)

    if self.cfg.resume:
        print(f'resume from {self.cfg.resume}')
        self.resume(self.cfg.resume)
  • 作用
    • 配置验证数据加载器。
    • 支持从断点恢复,以便进行测试或验证。

 3.5 模型保存与加载

3.5.1 state_dict()

def state_dict(self, last_epoch):
    '''state dict
    '''
    state = {}
    state['model'] = dist.de_parallel(self.model).state_dict()
    state['date'] = datetime.now().isoformat()
    state['last_epoch'] = last_epoch

    if self.optimizer is not None:
        state['optimizer'] = self.optimizer.state_dict()

    if self.lr_scheduler is not None:
        state['lr_scheduler'] = self.lr_scheduler.state_dict()

    if self.ema is not None:
        state['ema'] = self.ema.state_dict()

    if self.scaler is not None:
        state['scaler'] = self.scaler.state_dict()

    return state
  • 作用
    • 将模型、优化器、学习率调度器、EMA 和混合精度等状态保存为字典,以便后续加载。

3.5.2 load_state_dict()

def load_state_dict(self, state):
    '''load state dict
    '''
    # 加载模型、优化器、调度器等状态
    if getattr(self, 'last_epoch', None) and 'last_epoch' in state:
        self.last_epoch = state['last_epoch']
        print('Loading last_epoch')

    if getattr(self, 'model', None) and 'model' in state:
        if dist.is_parallel(self.model):
            self.model.module.load_state_dict(state['model'])
        else:
            self.model.load_state_dict(state['model'])
        print('Loading model.state_dict')

    if getattr(self, 'ema', None) and 'ema' in state:
        self.ema.load_state_dict(state['ema'])
        print('Loading ema.state_dict')

    if getattr(self, 'optimizer', None) and 'optimizer' in state:
        self.optimizer.load_state_dict(state['optimizer'])
        print('Loading optimizer.state_dict')

    if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state:
        self.lr_scheduler.load_state_dict(state['lr_scheduler'])
        print('Loading lr_scheduler.state_dict')

    if getattr(self, 'scaler', None) and 'scaler' in state:
        self.scaler.load_state_dict(state['scaler'])
        print('Loading scaler.state_dict')
  • 作用
    • 从保存的状态字典中恢复模型和训练过程。
    • 支持分布式训练环境。

 3.5.3 save 和 resume

def save(self, path):
    '''save state
    '''
    state = self.state_dict()
    dist.save_on_master(state, path)

def resume(self, path):
    '''load resume
    '''
    state = torch.load(path, map_location='cpu')
    self.load_state_dict(state)
  • 作用
    • 保存和加载模型的断点状态,支持分布式的保存。

3.6 load_tuning_state方法

def load_tuning_state(self, path):
    """only load model for tuning and skip missed/dismatched keys
    """
    if 'http' in path:
        state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        state = torch.load(path, map_location='cpu')

    module = dist.de_parallel(self.model)

    if 'ema' in state:
        stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
    else:
        stat, infos = self._matched_state(module.state_dict(), state['model'])

    module.load_state_dict(stat, strict=False)
    print(f'Load model.state_dict, {infos}')
  • 作用
    • 加载调优状态时,跳过不匹配或缺失的权重。
    • 用于 Fine-Tuning 场景。

 四、solver/det_solver.py/DetSolver

TASKS中detection任务对应的DetSolver类,它继承自之前定义的 BaseSolver 类,并实现了具体的训练(fit 方法)和验证(val 方法)逻辑

4.1 fit方法

fit 方法定义了目标检测模型的训练流程,包括训练、验证、保存模型状态等。

def fit(self):
    print("Start training")
    self.train()  # 初始化训练配置
    ...
    # 开始训练循环
    for epoch in range(self.last_epoch + 1, args.epoches):
        ...
        # 训练单个 epoch
        train_stats = train_one_epoch(
            self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
            args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)

        # 更新学习率调度器
        self.lr_scheduler.step()

        # 保存模型状态
        checkpoint_paths = [self.output_dir / 'checkpoint.pth']
        ...
        for checkpoint_path in checkpoint_paths:
            dist.save_on_master(self.state_dict(epoch), checkpoint_path)

        # 验证模型
        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(
            module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
        )
        ...
        # 打印和保存日志
        ...

这部分内容可以对应DETR系列中main.py文件中的部分:

 找到这部分内容就好了,知道模型和数据集是怎么传入代码中进行训练的,后面就可以根据传入的模型和数据找到对应的初始位置然后进行修改

详细解析
(1) 准备工作
self.train()
args = self.cfg
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
best_stat = {'epoch': -1, }
  • 调用 self.train() 进行初始化,包括配置优化器、学习率调度器等。
  • 计算模型的参数总数(n_parameters),便于日志记录。
  • 获取验证数据集的 COCO 接口对象(base_ds),用于后续评估。
(2) 训练循环
for epoch in range(self.last_epoch + 1, args.epoches):
    if dist.is_dist_available_and_initialized():
        self.train_dataloader.sampler.set_epoch(epoch)
    
    train_stats = train_one_epoch(
        self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
        args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)
  • 分布式支持:如果使用分布式训练(dist),为每个 epoch 设置随机采样器的种子。
  • 单次 epoch 训练
    • 调用 train_one_epoch 方法,完成模型在一个 epoch 内的训练(包括前向传播、计算损失、反向传播、权重更新等)。
    • 支持 梯度裁剪clip_max_norm)、EMA 模型更新 和 混合精度训练
(3) 更新学习率
self.lr_scheduler.step()
  • 调用学习率调度器,调整优化器的学习率。
(4) 保存模型状态
if self.output_dir:
    checkpoint_paths = [self.output_dir / 'checkpoint.pth']
    if (epoch + 1) % args.checkpoint_step == 0:
        checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
    for checkpoint_path in checkpoint_paths:
        dist.save_on_master(self.state_dict(epoch), checkpoint_path)
  • 定期保存模型的断点状态(包括模型权重、优化器状态等)。
  • 默认保存为 checkpoint.pth,并在指定训练轮数(如每 100 轮)的基础上保存额外的检查点。
(5) 验证模型
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(
    module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
)

  • 使用 evaluate 方法对模型进行验证,测量其在验证集上的性能。
  • 如果启用了 EMA 模型,则使用 EMA 模型进行验证。
(6) 更新最佳结果(best_stat
for k in test_stats.keys():
    if k in best_stat:
        best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
        best_stat[k] = max(best_stat[k], test_stats[k][0])
    else:
        best_stat['epoch'] = epoch
        best_stat[k] = test_stats[k][0]
print('best_stat: ', best_stat)

  • 将当前 epoch 的验证结果与 best_stat 进行对比,更新最佳性能记录。
(7) 记录日志
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters}

if self.output_dir and dist.is_main_process():
    with (self.output_dir / "log.txt").open("a") as f:
        f.write(json.dumps(log_stats) + "\n")

  • 记录训练和验证的统计信息,并将其保存为 JSON 格式的日志文件。
(8) 保存评估结果
if coco_evaluator is not None:
    (self.output_dir / 'eval').mkdir(exist_ok=True)
    if "bbox" in coco_evaluator.coco_eval:
        filenames = ['latest.pth']
        if epoch % 50 == 0:
            filenames.append(f'{epoch:03}.pth')
        for name in filenames:
            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                    self.output_dir / "eval" / name)

  • 保存 COCO 评估结果(如 bbox 的评估指标)。

 4.2 val方法

val 方法定义了模型的验证流程,主要用于评估模型在验证集上的性能。

def val(self):
    self.eval()  # 初始化验证配置
    base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
    module = self.ema.module if self.ema else self.model
    test_stats, coco_evaluator = evaluate(
        module, self.criterion, self.postprocessor,
        self.val_dataloader, base_ds, self.device, self.output_dir
    )
    if self.output_dir:
        dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")

五、solver/det_engine.py/train_one_epoch

det_engine.py就和detr系列中的engine.py文件内容一样了

 

 


网站公告

今日签到

点亮在社区的每一天
去签到