从代码学习深度强化学习 - DQN PyTorch版

发布于:2025-06-19 ⋅ 阅读:(24) ⋅ 点赞:(0)

前言

欢迎来到深度强化学习的世界!如果你对 Q-learning 有所了解,你可能会知道它使用一个表格(Q-table)来存储每个状态-动作对的价值。然而,当状态空间变得巨大,甚至是连续的时候(比如一个小车在轨道上的位置),Q-table 就变得不切实际。这时,深度Q网络(Deep Q-Network, DQN)就闪亮登场了。

DQN 的核心思想是用一个神经网络来代替 Q-table,实现从状态到(各个动作的)Q值的映射。这使得我们能够处理具有连续或高维状态空间的环境。本文将以经典的 CartPole-v1 环境为例,通过一个完整的 PyTorch 代码实现,带你深入理解 DQN 的工作原理及其关键组成部分:神经网络近似经验回放目标网络

在这里插入图片描述

图 1 CartPole环境示意图

在 CartPole 环境中,智能体的任务是左右移动小车,以保持车上的杆子竖直不倒。这个环境的状态是连续的(车的位置、速度、杆的角度、角速度),而动作是离散的(向左或向右)。这正是DQN大显身手的完美场景。

让我们一起通过代码,揭开DQN的神秘面纱。

完整代码:下载链接

DQN 算法核心思想

在深入代码之前,我们先回顾一下 DQN 的几个关键概念。

Q-Learning 与函数近似

传统的 Q-learning 更新规则如下:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s^{\prime},a^{\prime})-Q(s,a)\right] Q(s,a)Q(s,a)+α[r+γaAmaxQ(s,a)Q(s,a)]

当状态是连续的,我们无法用表格记录所有 Q(s,a)。因此,我们引入一个带参数 w 的神经网络,即 Q-网络 Q ω ( s , a ) Q_\omega\left(s,a\right) Qω(s,a),来近似真实的 Q-函数。我们的目标是让网络预测的Q值 Q ω ( s , a ) Q_\omega\left(s,a\right) Qω(s,a) 逼近“目标Q值” r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s',a') r+γmaxaAQ(s,a)

为此,我们可以定义一个损失函数,最常见的就是均方误差(MSE Loss):
ω ∗ = arg ⁡ min ⁡ ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max ⁡ a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_\omega\frac{1}{2N}\sum_{i=1}^N\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a^{\prime}}Q_\omega\left(s_i^{\prime},a^{\prime}\right)\right)\right]^2 ω=argωmin2N1i=1N[Q


网站公告

今日签到

点亮在社区的每一天
去签到