def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
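    """Execution plan defining the distributed dataflow of the algorithm.

    Args:
        workers: The WorkerSet used for rollouts and training.
        config: The trainer's configuration dict.

    Returns:
        A LocalIterator over training metrics.
    """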
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}
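
    # Single-shard replay buffer held on the driver: rollouts are written into
    # it and training batches of size `train_batch_size` are sampled from it
    # once `learning_starts` timesteps have been collected.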
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        # This line must stay commented out; otherwise the code does not run,
        # probably because LocalReplayBuffer in Ray 1.4.1 does not accept this
        # argument.
        # replay_sequence_length=config["replay_sequence_length"],
        **prio_args)
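
    # Collect experiences from all rollout workers in bulk-synchronous mode
    # (wait for every worker's sample batch on each iteration).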
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Update the saver penalty on each batch of rollouts.
    rollouts = rollouts.for_each(UpdateSaverPenalty(workers))

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer. Calling
    #     next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))
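
    # After each training step, write the freshly computed TD-errors back into
    # the prioritized replay buffer so that sampling priorities stay current.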
    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                # torch/tf. Clean up these results/info dicts across
                # policies (note: fixing this in torch_policy.py will
                # break e.g. DDPPO!).
                td_error = info.get("td_error",
                                    info[LEARNER_STATS_KEY].get("td_error"))
                prio_dict[policy_id] = (samples.policy_batches[policy_id]
                                        .data.get("batch_indexes"), td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    #     returned from the Replay() iterator is passed to TrainOneStep to
    #     take an SGD step, and then we decide whether to update the target
    #     network.
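    # Optional user hook (`before_learn_on_batch` in the config) that can
    # post-process each replayed batch before the learner sees it; it defaults
    # to a pass-through.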
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2). Only return the output
    # of (2) since training metrics are not available until (2) runs.
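    # round_robin_weights sets how many (1)-steps run for every (2)-step;
    # calculate_rr_weights derives these weights from the trainer config.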
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
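
# A minimal usage sketch (an assumption, not part of the original code): an
# execution plan like the one above is typically registered with RLlib's
# trainer template. `MySaverPolicy` and `MY_DEFAULT_CONFIG` below are
# hypothetical placeholders for this project's policy class and config dict.
#
#   from ray.rllib.agents.trainer_template import build_trainer
#
#   MyTrainer = build_trainer(
#       name="SaverDQN",
#       default_policy=MySaverPolicy,
#       default_config=MY_DEFAULT_CONFIG,
#       execution_plan=execution_plan)
#
#   trainer = MyTrainer(config=MY_DEFAULT_CONFIG, env="CartPole-v0")
#   result = trainer.train()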