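- Feature description: save the best model while training with MindSpore.
- Overview of saving the best model: complex networks often need tens or even hundreds of training epochs. Before training starts, it is hard to know at which epoch the model's accuracy will reach the required level. A common approach is therefore to validate the model at fixed epoch intervals during training and save the corresponding checkpoints. Once training finishes, the accuracy recorded for each checkpoint makes it quick to pick the best one. The process is as follows:
- Define the callback EvalCallBack so that training and validation run together.
- Define the training network and run it.
- Plot the model accuracy at each epoch as a line chart and pick the best model.
- Cause analysis: the last ckpt saved during MindSpore training is not necessarily the most accurate one and may fall short of the accuracy target.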
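- Solution: define a callback class, EvalCallBack, that validates the model during training. This custom data-collecting callback does two things:
- 4.1 After each epoch, it reports the training loss and the model accuracy on the validation set.
- 4.2 It saves the model with the highest accuracy.
Model validation: the callback relies on the apply_eval function, which evaluates the model on the validation set and returns the requested metric.
def apply_eval(eval_param):
    eval_model = eval_param['model']
    eval_ds = eval_param['dataset']
    metrics_name = eval_param['metrics_name']
    # Model.eval returns a dict of metric values keyed by metric name
    res = eval_model.eval(eval_ds)
    return res[metrics_name]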
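To illustrate the contract that apply_eval depends on, here is a minimal sketch that runs without MindSpore. FakeModel and the val_pairs data are hypothetical stand-ins invented for this example; they mimic only the one thing apply_eval touches, an eval method that returns a dict keyed by metric name. The point is that metrics_name has to match a key of that dict exactly, otherwise the lookup raises KeyError.

```python
def apply_eval(eval_param):
    """Same logic as apply_eval above: evaluate and return one metric."""
    eval_model = eval_param['model']
    eval_ds = eval_param['dataset']
    metrics_name = eval_param['metrics_name']
    res = eval_model.eval(eval_ds)
    return res[metrics_name]


class FakeModel:
    """Hypothetical stand-in for a trained model; not a MindSpore class."""

    def eval(self, dataset):
        # Count how many (prediction, label) pairs agree.
        correct = sum(1 for pred, label in dataset if pred == label)
        return {"Accuracy": correct / len(dataset)}


# Hypothetical validation data: (prediction, label) pairs.
val_pairs = [(1, 1), (0, 0), (1, 0), (0, 0)]

acc = apply_eval({"model": FakeModel(), "dataset": val_pairs, "metrics_name": "Accuracy"})
print(acc)

# A metric name that is not a key of the returned dict fails loudly.
missing_key = None
try:
    apply_eval({"model": FakeModel(), "dataset": val_pairs, "metrics_name": "acc"})
except KeyError as err:
    missing_key = err.args[0]
    print("unknown metric:", missing_key)
```

In the real setup, the key is whatever name was used when the metrics were registered on the model, so it is worth checking that name before a long training run.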
import os
import stat

from mindspore import save_checkpoint
from mindspore.train.callback import Callback


class EvalCallBack(Callback):
    """Callback that collects model information during training."""

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1,
                 save_best_ckpt=True, ckpt_directory="./", best_ckpt_name="best.ckpt",
                 metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
        self.metrics_name = metrics_name

    # Delete a ckpt file (make it writable first so the removal does not fail)
    def remove_ckpoint_file(self, file_name):
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)

    # After each epoch, print the training loss and the validation accuracy,
    # and save the ckpt file with the best accuracy
    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        loss_epoch = cb_params.net_outputs
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            # cb_params.epoch_num is the total number of epochs passed to model.train
            print('Epoch {}/{}'.format(cur_epoch, cb_params.epoch_num))
            print('-' * 10)
            print('train Loss: {}'.format(loss_epoch))
            print('val Acc: {}'.format(res))
            # ">=" means a tie replaces the earlier best epoch
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                if self.save_best_ckpt:
                    if os.path.exists(self.best_ckpt_path):
                        self.remove_ckpoint_file(self.best_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)

    # After training ends, print the best accuracy and the epoch it was reached
    def end(self, run_context):
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(
            self.metrics_name, self.best_res, self.best_epoch), flush=True)
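The condition at the top of epoch_end decides which epochs get validated. Pulling that rule out on its own, with no MindSpore dependency, makes the effect of interval and eval_start_epoch easier to see. This is only a sketch of the scheduling rule; the helper name evaluated_epochs is made up for the illustration.

```python
def evaluated_epochs(num_epochs, eval_start_epoch=1, interval=1):
    """Return the epochs at which the callback's condition would trigger validation."""
    if interval < 1:
        raise ValueError("interval should >= 1.")
    return [
        epoch
        for epoch in range(1, num_epochs + 1)
        if epoch >= eval_start_epoch and (epoch - eval_start_epoch) % interval == 0
    ]


# Defaults (eval_start_epoch=1, interval=1): every epoch is validated.
every_epoch = evaluated_epochs(20)
# Start at epoch 5 and validate every 5th epoch after that.
sparse = evaluated_epochs(20, eval_start_epoch=5, interval=5)

print(every_epoch)
print(sparse)
```

A larger interval means fewer validation passes but a coarser accuracy curve, and the best checkpoint can only come from an epoch that was actually validated.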
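Training and evaluation: run the code below to start training the model.
from mindspore.train.callback import TimeMonitor

# create_dataset, model, num_epochs and the two data paths are assumed to be defined earlier
train_ds = create_dataset(train_data_path)
val_ds = create_dataset(val_data_path)
eval_param_dict = {"model": model, "dataset": val_ds, "metrics_name": "Accuracy"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict)
# Train the model
model.train(num_epochs, train_ds, callbacks=[eval_cb, TimeMonitor()], dataset_sink_mode=True)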
Result:
Epoch 1/20
----------
train Loss: 0.47486544
val Acc: 0.8333333333333334
epoch time: 8439.054 ms, per step time: 140.651 ms
Epoch 2/20
----------
train Loss: 0.20464368
val Acc: 0.8333333333333334
epoch time: 3805.755 ms, per step time: 63.429 ms
Epoch 3/20
----------
train Loss: 0.3345307
val Acc: 0.9166666666666666
epoch time: 3721.042 ms, per step time: 62.017 ms
Epoch 4/20
----------
train Loss: 0.7761406
val Acc: 0.8333333333333334
epoch time: 3302.892 ms, per step time: 55.048 ms
Epoch 5/20
----------
train Loss: 0.3566268
val Acc: 0.9
epoch time: 3375.371 ms, per step time: 56.256 ms
Epoch 6/20
----------
train Loss: 0.13434622
val Acc: 0.9333333333333333
epoch time: 4012.532 ms, per step time: 66.876 ms
Epoch 7/20
----------
train Loss: 0.20843573
val Acc: 0.85
epoch time: 3357.198 ms, per step time: 55.953 ms
Epoch 8/20
----------
train Loss: 0.96780926
val Acc: 0.95
epoch time: 3628.576 ms, per step time: 60.476 ms
Epoch 9/20
----------
train Loss: 1.4824448
val Acc: 0.8666666666666667
epoch time: 3403.053 ms, per step time: 56.718 ms
Epoch 10/20
----------
train Loss: 0.11375467
val Acc: 0.9166666666666666
epoch time: 3293.931 ms, per step time: 54.899 ms
Epoch 11/20
----------
train Loss: 0.14315866
val Acc: 0.8833333333333333
epoch time: 3308.482 ms, per step time: 55.141 ms
Epoch 12/20
----------
train Loss: 0.13462222
val Acc: 0.95
epoch time: 3922.425 ms, per step time: 65.374 ms
Epoch 13/20
----------
train Loss: 0.46668455
val Acc: 0.8666666666666667
epoch time: 3366.989 ms, per step time: 56.116 ms
Epoch 14/20
----------
train Loss: 0.18877655
val Acc: 0.9166666666666666
epoch time: 3301.854 ms, per step time: 55.031 ms
Epoch 15/20
----------
train Loss: 0.30053577
val Acc: 0.9
epoch time: 3218.894 ms, per step time: 53.648 ms
Epoch 16/20
----------
train Loss: 0.19290532
val Acc: 0.8166666666666667
epoch time: 3241.427 ms, per step time: 54.024 ms
Epoch 17/20
----------
train Loss: 0.00813961
val Acc: 0.8833333333333333
epoch time: 3317.892 ms, per step time: 55.298 ms
Epoch 18/20
----------
train Loss: 0.09142441
val Acc: 0.8166666666666667
epoch time: 3365.341 ms, per step time: 56.089 ms
Epoch 19/20
----------
train Loss: 0.89299583
val Acc: 0.9
epoch time: 3441.966 ms, per step time: 57.366 ms
Epoch 20/20
----------
train Loss: 0.29071262
val Acc: 0.8166666666666667
epoch time: 3269.289 ms, per step time: 54.488 ms
End training, the best acc is: 0.95, the best acc epoch is 12
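The log shows a validation accuracy of 0.95 at both epoch 8 and epoch 12, yet the callback reports epoch 12. That follows from the `res >= self.best_res` comparison in epoch_end, where a tie overwrites the earlier best. The sketch below replays the val Acc values from the log through the same comparison. The function name pick_best and its strict flag are assumptions added for illustration and are not part of the callback.

```python
# val Acc per epoch, copied from the log above (index 0 is epoch 1).
val_acc = [
    0.8333333333333334, 0.8333333333333334, 0.9166666666666666, 0.8333333333333334,
    0.9, 0.9333333333333333, 0.85, 0.95,
    0.8666666666666667, 0.9166666666666666, 0.8833333333333333, 0.95,
    0.8666666666666667, 0.9166666666666666, 0.9, 0.8166666666666667,
    0.8833333333333333, 0.8166666666666667, 0.9, 0.8166666666666667,
]


def pick_best(accuracies, strict=False):
    """Track the best accuracy the way epoch_end does.

    strict=False mirrors the callback (">=", the latest tie wins);
    strict=True is a hypothetical variant (">", the earliest tie wins).
    """
    best_res, best_epoch = 0, 0
    for epoch, res in enumerate(accuracies, start=1):
        improved = res > best_res if strict else res >= best_res
        if improved:
            best_res, best_epoch = res, epoch
    return best_res, best_epoch


latest_best = pick_best(val_acc)
earliest_best = pick_best(val_acc, strict=True)
print(latest_best)
print(earliest_best)
```

Because best.ckpt is rewritten whenever the comparison succeeds, the saved file in this run holds the epoch 12 weights. A strict comparison would have kept the epoch 8 weights instead.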
Use the most accurate model to run and visualize predictions on the validation set.
visualize_model('best.ckpt', val_ds)
Result: