25/1/7 算法笔记<强化学习> 多进程方法

发布于:2025-02-10 ⋅ 阅读:(29) ⋅ 点赞:(0)

启动多个进程并行采样有以下几个目的:

1.减少数据采集时间,提高效率

2.多个进程从不同初始状态采样,提供更丰富的数据,增加数据多样性

3.更高的数据更新频率加快模型收敛

4.充分利用硬件资源,平衡计算负载

5.显著减少训练时间,支持复杂任务

在Python中,启动多进程的方法通常使用multiprocessing模块,

以下是基本步骤:

导入multiprocessing模块

import multiprocessing as mp

定义工作函数

定义一个工作函数worker,该函数将在每个子进程中执行,工作函数通常包含与环境交互,数据采样和模型更新的逻辑。

def worker(process_id, replay_buffer, max_steps):
    print(f"Process {process_id} is running")
    for step in range(max_steps):
        # 模拟数据采样
        data = np.random.randn(10)  # 假设采样 10 维数据
        replay_buffer.push(data)  # 将数据存入回放缓冲区
    print(f"Process {process_id} finished")

process_id:进程的ID,用于区分不同的进程

shared_data:共享的数据

other_args:其他参数

创建共享对象

在多进程中,如果需要共享数据,可以使用multiprocessing提供的共享对象。

Queue:用于进程键的数据传递

Manager:用于创建共享的复杂对象

Value和Array:用于共享简单的数据类型

例如创建一个共享的回放缓冲区

from multiprocessing import Manager

manager = Manager()
replay_buffer = manager.ReplayBuffer(capacity=1e6)  # 假设 ReplayBuffer 是一个自定义类

4.创建并启动进程

使用multiprocessing.Process创建多个进程,并启动

num_workers = 4  # 工作进程的数量
processes = []

for i in range(num_workers):
    process = mp.Process(
        target=worker,  # 工作函数
        args=(i, replay_buffer, other_args)  # 传递给工作函数的参数
    )
    process.daemon = True  # 设置为守护进程,主进程结束时子进程自动结束
    processes.append(process)

# 启动所有进程
[p.start() for p in processes]

target = worker,指定工作函数

args= (..)传递给工作函数的参数

process.daemon = True将进程设置为守护进程,主进程结束时子进程自动结束

等待进程结束

使用join()方法等待所有进程结束

[p.join() for p in processes]

join():阻塞主进程,直到子进程执行完毕

完整代码示例

import multiprocessing as mp
from multiprocessing import Manager
import numpy as np

# 定义工作函数
def worker(process_id, replay_buffer, max_steps):
    print(f"Process {process_id} is running")
    for step in range(max_steps):
        # 模拟数据采样
        data = np.random.randn(10)  # 假设采样 10 维数据
        replay_buffer.push(data)  # 将数据存入回放缓冲区
    print(f"Process {process_id} finished")

# 自定义回放缓冲区类
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []

    def push(self, data):
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append(data)

    def get_length(self):
        return len(self.buffer)

if __name__ == '__main__':
    # 创建共享的回放缓冲区
    manager = Manager()
    replay_buffer = manager.ReplayBuffer(capacity=1e6)

    # 启动多个进程
    num_workers = 4
    max_steps = 100
    processes = []

    for i in range(num_workers):
        process = mp.Process(
            target=worker,
            args=(i, replay_buffer, max_steps)
        )
        process.daemon = True
        processes.append(process)

    # 启动所有进程
    [p.start() for p in processes]

    # 等待所有进程结束
    [p.join() for p in processes]

    print("All processes finished")
    print("Replay buffer length:", replay_buffer.get_length())

上面的工作函数有点简陋,我们来看一个比较完整的

def worker(process_id, sac_trainer, rewards_queue, replay_buffer, max_episodes, max_steps,
           batch_size, explore_steps, update_itr, AUTO_ENTROPY, DETERMINISTIC, USE_DEMONS,
           hidden_dim, model_path, headless):
    """
    工作函数:每个子进程执行的核心逻辑。
    :param process_id: 进程 ID。
    :param sac_trainer: SAC 训练器,包含策略网络、Q 网络和优化器。
    :param rewards_queue: 奖励队列,用于存储每个 episode 的奖励。
    :param replay_buffer: 回放缓冲区,用于存储经验数据。
    :param max_episodes: 最大训练 episode 数。
    :param max_steps: 每个 episode 的最大步数。
    :param batch_size: 每次更新时从回放缓冲区中采样的批量大小。
    :param explore_steps: 探索步数,在探索阶段使用随机动作。
    :param update_itr: 每次采样后更新模型的迭代次数。
    :param AUTO_ENTROPY: 是否自动调整熵系数。
    :param DETERMINISTIC: 是否使用确定性策略。
    :param USE_DEMONS: 是否使用演示数据。
    :param hidden_dim: 神经网络的隐藏层维度。
    :param model_path: 模型保存路径。
    :param headless: 是否以无头模式运行环境。
    """
    # 初始化环境
    env = GraspEnv(headless=headless)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]

    frame_idx = 0  # 当前的总步数
    rewards = []  # 存储每个 episode 的奖励

    # 训练循环
    for eps in range(max_episodes):
        episode_reward = 0  # 当前 episode 的累积奖励
        state = env.reset()  # 重置环境,获取初始状态

        # 每隔固定 episode 重新初始化环境,避免环境问题
        if eps % 20 == 0 and eps > 0:
            env.reinit()

        # 每个 episode 的步数循环
        for step in range(max_steps):
            # 选择动作
            if frame_idx > explore_steps:
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
            else:
                action = sac_trainer.policy_net.sample_action()

            # 执行动作并获取下一状态、奖励和终止标志
            try:
                next_state, reward, done, _ = env.step(action)
            except KeyboardInterrupt:
                print('Finished')
                sac_trainer.save_model(model_path)
                env.shutdown()
                return

            # 将经验数据存入回放缓冲区
            replay_buffer.push(state, action, reward, next_state, done)

            # 更新状态和累积奖励
            state = next_state
            episode_reward += reward
            frame_idx += 1

            # 如果回放缓冲区中的数据量足够,则更新模型
            if replay_buffer.get_length() > batch_size:
                for i in range(update_itr):
                    sac_trainer.update(batch_size, reward_scale=10., auto_entropy=AUTO_ENTROPY,
                                      use_demons=USE_DEMONS, target_entropy=-1. * action_dim)

            # 每隔固定 episode 保存模型
            if eps % 10 == 0 and eps > 0:
                sac_trainer.save_model(model_path)

            # 如果 episode 结束,则跳出循环
            if done:
                break

        # 打印当前 episode 的奖励
        print(f'Process {process_id}, Episode {eps}, Reward: {episode_reward}')
        rewards_queue.put(episode_reward)  # 将奖励存入队列

    # 训练结束后保存模型并关闭环境
    sac_trainer.save_model(model_path)
    env.shutdown()

总结

启动多进程的方法包括以下步骤:

  1. 导入 multiprocessing 模块

  2. 定义工作函数:包含每个进程的执行逻辑。

  3. 创建共享对象:使用 ManagerQueue 等共享数据。

  4. 创建并启动进程:使用 Process 创建进程,并调用 start() 启动。

  5. 等待进程结束:使用 join() 等待所有进程完成。


网站公告

今日签到

点亮在社区的每一天
去签到