启动多个进程并行采样有以下几个目的:
1.减少数据采集时间,提高效率
2.多个进程从不同初始状态采样,提供更丰富的数据,增加数据多样性
3.更高的数据更新频率加快模型收敛
4.充分利用硬件资源,平衡计算负载
5.显著减少训练时间,支持复杂任务
在Python中,启动多进程的方法通常使用multiprocessing模块,
以下是基本步骤:
导入multiprocessing模块
import multiprocessing as mp
定义工作函数
定义一个工作函数worker,该函数将在每个子进程中执行,工作函数通常包含与环境交互,数据采样和模型更新的逻辑。
def worker(process_id, replay_buffer, max_steps):
print(f"Process {process_id} is running")
for step in range(max_steps):
# 模拟数据采样
data = np.random.randn(10) # 假设采样 10 维数据
replay_buffer.push(data) # 将数据存入回放缓冲区
print(f"Process {process_id} finished")
process_id:进程的ID,用于区分不同的进程
shared_data:共享的数据
other_args:其他参数
创建共享对象
在多进程中,如果需要共享数据,可以使用multiprocessing提供的共享对象。
Queue:用于进程键的数据传递
Manager:用于创建共享的复杂对象
Value和Array:用于共享简单的数据类型
例如创建一个共享的回放缓冲区
from multiprocessing import Manager
manager = Manager()
replay_buffer = manager.ReplayBuffer(capacity=1e6) # 假设 ReplayBuffer 是一个自定义类
4.创建并启动进程
使用multiprocessing.Process创建多个进程,并启动
num_workers = 4 # 工作进程的数量
processes = []
for i in range(num_workers):
process = mp.Process(
target=worker, # 工作函数
args=(i, replay_buffer, other_args) # 传递给工作函数的参数
)
process.daemon = True # 设置为守护进程,主进程结束时子进程自动结束
processes.append(process)
# 启动所有进程
[p.start() for p in processes]
target = worker,指定工作函数
args= (..)传递给工作函数的参数
process.daemon = True将进程设置为守护进程,主进程结束时子进程自动结束
等待进程结束
使用join()方法等待所有进程结束
[p.join() for p in processes]
join():阻塞主进程,直到子进程执行完毕
完整代码示例
import multiprocessing as mp
from multiprocessing import Manager
import numpy as np
# 定义工作函数
def worker(process_id, replay_buffer, max_steps):
print(f"Process {process_id} is running")
for step in range(max_steps):
# 模拟数据采样
data = np.random.randn(10) # 假设采样 10 维数据
replay_buffer.push(data) # 将数据存入回放缓冲区
print(f"Process {process_id} finished")
# 自定义回放缓冲区类
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
def push(self, data):
if len(self.buffer) >= self.capacity:
self.buffer.pop(0)
self.buffer.append(data)
def get_length(self):
return len(self.buffer)
if __name__ == '__main__':
# 创建共享的回放缓冲区
manager = Manager()
replay_buffer = manager.ReplayBuffer(capacity=1e6)
# 启动多个进程
num_workers = 4
max_steps = 100
processes = []
for i in range(num_workers):
process = mp.Process(
target=worker,
args=(i, replay_buffer, max_steps)
)
process.daemon = True
processes.append(process)
# 启动所有进程
[p.start() for p in processes]
# 等待所有进程结束
[p.join() for p in processes]
print("All processes finished")
print("Replay buffer length:", replay_buffer.get_length())
上面的工作函数有点简陋,我们来看一个比较完整的
def worker(process_id, sac_trainer, rewards_queue, replay_buffer, max_episodes, max_steps,
batch_size, explore_steps, update_itr, AUTO_ENTROPY, DETERMINISTIC, USE_DEMONS,
hidden_dim, model_path, headless):
"""
工作函数:每个子进程执行的核心逻辑。
:param process_id: 进程 ID。
:param sac_trainer: SAC 训练器,包含策略网络、Q 网络和优化器。
:param rewards_queue: 奖励队列,用于存储每个 episode 的奖励。
:param replay_buffer: 回放缓冲区,用于存储经验数据。
:param max_episodes: 最大训练 episode 数。
:param max_steps: 每个 episode 的最大步数。
:param batch_size: 每次更新时从回放缓冲区中采样的批量大小。
:param explore_steps: 探索步数,在探索阶段使用随机动作。
:param update_itr: 每次采样后更新模型的迭代次数。
:param AUTO_ENTROPY: 是否自动调整熵系数。
:param DETERMINISTIC: 是否使用确定性策略。
:param USE_DEMONS: 是否使用演示数据。
:param hidden_dim: 神经网络的隐藏层维度。
:param model_path: 模型保存路径。
:param headless: 是否以无头模式运行环境。
"""
# 初始化环境
env = GraspEnv(headless=headless)
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]
frame_idx = 0 # 当前的总步数
rewards = [] # 存储每个 episode 的奖励
# 训练循环
for eps in range(max_episodes):
episode_reward = 0 # 当前 episode 的累积奖励
state = env.reset() # 重置环境,获取初始状态
# 每隔固定 episode 重新初始化环境,避免环境问题
if eps % 20 == 0 and eps > 0:
env.reinit()
# 每个 episode 的步数循环
for step in range(max_steps):
# 选择动作
if frame_idx > explore_steps:
action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
else:
action = sac_trainer.policy_net.sample_action()
# 执行动作并获取下一状态、奖励和终止标志
try:
next_state, reward, done, _ = env.step(action)
except KeyboardInterrupt:
print('Finished')
sac_trainer.save_model(model_path)
env.shutdown()
return
# 将经验数据存入回放缓冲区
replay_buffer.push(state, action, reward, next_state, done)
# 更新状态和累积奖励
state = next_state
episode_reward += reward
frame_idx += 1
# 如果回放缓冲区中的数据量足够,则更新模型
if replay_buffer.get_length() > batch_size:
for i in range(update_itr):
sac_trainer.update(batch_size, reward_scale=10., auto_entropy=AUTO_ENTROPY,
use_demons=USE_DEMONS, target_entropy=-1. * action_dim)
# 每隔固定 episode 保存模型
if eps % 10 == 0 and eps > 0:
sac_trainer.save_model(model_path)
# 如果 episode 结束,则跳出循环
if done:
break
# 打印当前 episode 的奖励
print(f'Process {process_id}, Episode {eps}, Reward: {episode_reward}')
rewards_queue.put(episode_reward) # 将奖励存入队列
# 训练结束后保存模型并关闭环境
sac_trainer.save_model(model_path)
env.shutdown()
总结
启动多进程的方法包括以下步骤:
导入
multiprocessing
模块。定义工作函数:包含每个进程的执行逻辑。
创建共享对象:使用
Manager
、Queue
等共享数据。创建并启动进程:使用
Process
创建进程,并调用start()
启动。等待进程结束:使用
join()
等待所有进程完成。