[MindSpore] Saving the Best Model

Published: 2022-12-19
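  1. Feature description: save the best model while training with MindSpore.
  2. Overview: A complex network often needs tens or even hundreds of training epochs. Before training starts, it is hard to know at which epoch the accuracy will meet the requirement. A common approach is therefore to validate the model at a fixed epoch interval during training and save the corresponding checkpoint; once training finishes, the accuracy history makes it easy to pick the best model. The workflow is:
     1. Define the callback EvalCallBack so that training and validation run together.
     2. Define the training network and run it.
     3. Plot the model accuracy at each epoch as a line chart and pick the best model.
  3. Cause analysis: when training with MindSpore, keeping only the last ckpt may leave you with a model whose accuracy falls short of the target.
  4. Solution: define the function apply_eval, which validates model accuracy, and the callback EvalCallBack. Model validation: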
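    def apply_eval(eval_param):
        """Evaluate the model on the validation set and return one metric."""
        eval_model = eval_param['model']
        eval_ds = eval_param['dataset']
        metrics_name = eval_param['metrics_name']
        # Model.eval returns a dict that maps metric names to values.
        res = eval_model.eval(eval_ds)
        return res[metrics_name]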
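
To see what apply_eval expects from its argument, the sketch below calls it with a stand-in model. FakeModel, its scoring rule, and the toy dataset are hypothetical and exist only for illustration; the assumption is that a real MindSpore Model.eval likewise returns a dict keyed by metric name.

```python
def apply_eval(eval_param):
    eval_model = eval_param['model']
    eval_ds = eval_param['dataset']
    metrics_name = eval_param['metrics_name']
    res = eval_model.eval(eval_ds)
    return res[metrics_name]


class FakeModel:
    """Hypothetical stand-in for a MindSpore Model (illustration only)."""

    def eval(self, dataset):
        # Count a sample as correct when its prediction equals its label.
        correct = sum(1 for pred, label in dataset if pred == label)
        return {"Accuracy": correct / len(dataset)}


# Toy "dataset" of (prediction, label) pairs; 3 of the 4 pairs match.
fake_ds = [(0, 0), (1, 1), (1, 0), (0, 0)]
eval_param_dict = {"model": FakeModel(), "dataset": fake_ds, "metrics_name": "Accuracy"}
acc = apply_eval(eval_param_dict)
print(acc)
```

The value of "metrics_name" must match a key in the dict returned by eval, otherwise the lookup raises a KeyError.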
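    We define a custom data-collecting callback class, EvalCallBack, which does two things:
  • 4.1 After each epoch, report the training loss and the model accuracy on the validation set.
  • 4.2 Save the model with the highest accuracy.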
    import os
    import stat

    from mindspore import save_checkpoint
    from mindspore.train.callback import Callback

    class EvalCallBack(Callback):
      """
      Callback class that collects model information during training.
      """
      def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                   ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
          super(EvalCallBack, self).__init__()
          self.eval_param_dict = eval_param_dict
          self.eval_function = eval_function
          self.eval_start_epoch = eval_start_epoch
          if interval < 1:
              raise ValueError("interval should >= 1.")
          self.interval = interval
          self.save_best_ckpt = save_best_ckpt
          # Best metric value seen so far and the epoch that produced it.
          self.best_res = 0
          self.best_epoch = 0
          if not os.path.isdir(ckpt_directory):
              os.makedirs(ckpt_directory)
          self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
          self.metrics_name = metrics_name

Delete the ckpt file

      def remove_ckpoint_file(self, file_name):
          # Make the file writable first so that a read-only ckpt can be deleted.
          os.chmod(file_name, stat.S_IWRITE)
          os.remove(file_name)
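
The chmod-then-remove pattern can be tried without MindSpore. The sketch below is a standalone illustration that assumes the checkpoint may have been written read-only; remove_readonly_file and the temporary file are hypothetical names used only here.

```python
import os
import stat
import tempfile


def remove_readonly_file(file_name):
    # Restore write permission, then delete the file.
    os.chmod(file_name, stat.S_IWRITE)
    os.remove(file_name)


# Create a throwaway file and mark it read-only.
fd, path = tempfile.mkstemp(suffix=".ckpt")
os.close(fd)
os.chmod(path, stat.S_IREAD)

remove_readonly_file(path)
still_there = os.path.exists(path)
print(still_there)
```

On Windows, deleting a read-only file fails unless write permission is restored first, which is why the chmod call comes before os.remove.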

After each epoch, print the training loss and the validation accuracy, and save the ckpt file with the best accuracy

      def epoch_end(self, run_context):
          cb_params = run_context.original_args()
          cur_epoch = cb_params.cur_epoch_num
          loss_epoch = cb_params.net_outputs
          # Evaluate every `interval` epochs, starting at eval_start_epoch.
          if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
              res = self.eval_function(self.eval_param_dict)
              # cb_params.epoch_num is the total number of epochs passed to model.train.
              print('Epoch {}/{}'.format(cur_epoch, cb_params.epoch_num))
              print('-' * 10)
              print('train Loss: {}'.format(loss_epoch))
              print('val Acc: {}'.format(res))
              # ">=" means that a tie replaces the earlier best checkpoint.
              if res >= self.best_res:
                  self.best_res = res
                  self.best_epoch = cur_epoch
                  if self.save_best_ckpt:
                      if os.path.exists(self.best_ckpt_path):
                          self.remove_ckpoint_file(self.best_ckpt_path)
                      save_checkpoint(cb_params.train_network, self.best_ckpt_path)
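
The condition in epoch_end decides which epochs trigger validation. As a small framework-free sketch, the same expression is pulled out into a helper (should_evaluate is a hypothetical name, not part of the callback) to show which epochs would be evaluated for a given eval_start_epoch and interval.

```python
def should_evaluate(cur_epoch, eval_start_epoch=1, interval=1):
    """Same condition as in epoch_end, extracted into a helper for illustration."""
    return cur_epoch >= eval_start_epoch and (cur_epoch - eval_start_epoch) % interval == 0


# Epochs evaluated in a 20-epoch run that starts evaluating at epoch 5, every 5 epochs.
evaluated = [e for e in range(1, 21) if should_evaluate(e, eval_start_epoch=5, interval=5)]
print(evaluated)

# With the defaults (start at epoch 1, interval 1), every epoch is evaluated.
evaluated_default = [e for e in range(1, 21) if should_evaluate(e)]
print(len(evaluated_default))
```

A larger interval cuts validation cost, but a peak that falls between two evaluated epochs is never seen and so never saved.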

After training ends, print the best accuracy and the corresponding epoch

      def end(self, run_context):
          print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, self.best_res, self.best_epoch), flush=True)

Training and evaluation: run the code below to start training the model.

    train_ds = create_dataset(train_data_path)
    val_ds = create_dataset(val_data_path)
    eval_param_dict = {"model": model, "dataset": val_ds, "metrics_name": "Accuracy"}
    eval_cb = EvalCallBack(apply_eval, eval_param_dict)

Train the model

    from mindspore.train.callback import TimeMonitor
    model.train(num_epochs, train_ds, callbacks=[eval_cb, TimeMonitor()], dataset_sink_mode=True)

Result:

Epoch 1/20
----------
train Loss: 0.47486544
val Acc: 0.8333333333333334
epoch time: 8439.054 ms, per step time: 140.651 ms
Epoch 2/20
----------
train Loss: 0.20464368
val Acc: 0.8333333333333334
epoch time: 3805.755 ms, per step time: 63.429 ms
Epoch 3/20
----------
train Loss: 0.3345307
val Acc: 0.9166666666666666
epoch time: 3721.042 ms, per step time: 62.017 ms
Epoch 4/20
----------
train Loss: 0.7761406
val Acc: 0.8333333333333334
epoch time: 3302.892 ms, per step time: 55.048 ms
Epoch 5/20
----------
train Loss: 0.3566268
val Acc: 0.9
epoch time: 3375.371 ms, per step time: 56.256 ms
Epoch 6/20
----------
train Loss: 0.13434622
val Acc: 0.9333333333333333
epoch time: 4012.532 ms, per step time: 66.876 ms
Epoch 7/20
----------
train Loss: 0.20843573
val Acc: 0.85
epoch time: 3357.198 ms, per step time: 55.953 ms
Epoch 8/20
----------
train Loss: 0.96780926
val Acc: 0.95
epoch time: 3628.576 ms, per step time: 60.476 ms
Epoch 9/20
----------
train Loss: 1.4824448
val Acc: 0.8666666666666667
epoch time: 3403.053 ms, per step time: 56.718 ms
Epoch 10/20
----------
train Loss: 0.11375467
val Acc: 0.9166666666666666
epoch time: 3293.931 ms, per step time: 54.899 ms
Epoch 11/20
----------
train Loss: 0.14315866
val Acc: 0.8833333333333333
epoch time: 3308.482 ms, per step time: 55.141 ms
Epoch 12/20
----------
train Loss: 0.13462222
val Acc: 0.95
epoch time: 3922.425 ms, per step time: 65.374 ms
Epoch 13/20
----------
train Loss: 0.46668455
val Acc: 0.8666666666666667
epoch time: 3366.989 ms, per step time: 56.116 ms
Epoch 14/20
----------
train Loss: 0.18877655
val Acc: 0.9166666666666666
epoch time: 3301.854 ms, per step time: 55.031 ms
Epoch 15/20
----------
train Loss: 0.30053577
val Acc: 0.9
epoch time: 3218.894 ms, per step time: 53.648 ms
Epoch 16/20
----------
train Loss: 0.19290532
val Acc: 0.8166666666666667
epoch time: 3241.427 ms, per step time: 54.024 ms
Epoch 17/20
----------
train Loss: 0.00813961
val Acc: 0.8833333333333333
epoch time: 3317.892 ms, per step time: 55.298 ms
Epoch 18/20
----------
train Loss: 0.09142441
val Acc: 0.8166666666666667
epoch time: 3365.341 ms, per step time: 56.089 ms
Epoch 19/20
----------
train Loss: 0.89299583
val Acc: 0.9
epoch time: 3441.966 ms, per step time: 57.366 ms
Epoch 20/20
----------
train Loss: 0.29071262
val Acc: 0.8166666666666667
epoch time: 3269.289 ms, per step time: 54.488 ms
End training, the best acc is: 0.95, the best acc epoch is 12
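
The log shows 0.95 at both epoch 8 and epoch 12, yet the callback reports epoch 12. That follows from the `>=` comparison in epoch_end, where a tie replaces the earlier best. The sketch below replays the validation accuracies from the log above; replay_best is a hypothetical helper written only for this illustration.

```python
# Validation accuracies copied from the training log, epochs 1 to 20.
val_acc = [
    0.8333333333333334, 0.8333333333333334, 0.9166666666666666, 0.8333333333333334,
    0.9, 0.9333333333333333, 0.85, 0.95, 0.8666666666666667, 0.9166666666666666,
    0.8833333333333333, 0.95, 0.8666666666666667, 0.9166666666666666, 0.9,
    0.8166666666666667, 0.8833333333333333, 0.8166666666666667, 0.9, 0.8166666666666667,
]


def replay_best(accuracies, keep_latest_on_tie=True):
    """Replay the best-result bookkeeping of the callback over a list of accuracies."""
    best_res, best_epoch = 0, 0
    for epoch, res in enumerate(accuracies, start=1):
        is_better = res >= best_res if keep_latest_on_tie else res > best_res
        if is_better:
            best_res, best_epoch = res, epoch
    return best_res, best_epoch


latest = replay_best(val_acc)                               # ">=" as in the callback
earliest = replay_best(val_acc, keep_latest_on_tie=False)   # strict ">" variant
print(latest, earliest)
```

Switching to a strict `>` would keep the earliest checkpoint that reaches the best accuracy; either choice is defensible.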

Use the model with the best accuracy to run visualized predictions on the validation set.

visualize_model('best.ckpt', val_ds)

Result: