ray.rllib 入门实践-6: 保存模型

发布于:2025-02-10 ⋅ 阅读:(52) ⋅ 点赞:(0)

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

训练模型后保存模型,比较简单,这里简单介绍。

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print


## 配置算法
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path,exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path  ## 设置过程文件的存储路径

## 构建算法
algo = config.build()

## 训练算法
for i in range(3):
    result = algo.train() 
    print(f"episode_{i}")

## 保存模型
## 方法1: 保存到默认路径下
algo.save() ## 保存到默认路径下, 一般是: ~/ray_result 文件夹下, 或 C:\Users\xxx\ray_results\ 文件夹下
            ## 上面设置的 config.output 只用于保存一些过程文件,不能决定这里的存储位置

## 方法2: 保存到默认路径下,并返回保存路径
checkpoint_dir = algo.save().checkpoint.path
print(f"Checkpoint saved in directory {checkpoint_dir}")

## 方法3: 保存到指定路径下
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir,exist_ok=True)
algo.save_checkpoint(checkpoint_dir) ## 保存到指定路径下
print(f"saved checkpoint to {checkpoint_dir}")


网站公告

今日签到

点亮在社区的每一天
去签到