DQN(Deep Q-Network)算法详解
DQN(Deep Q-Network)算法详解:深度强化学习的里程碑
在强化学习的浩瀚宇宙中,DQN(Deep Q-Network,简称DQN)算法无疑是一座璀璨的里程碑,它首次将深度学习的强大功能引入Q学习,为解决高维状态空间中的复杂决策问题打开了新纪元。本文将深入解析DQN算法的内在原理,探讨其为何能在众多领域中引发变革,并通过Python代码实例,带领你亲手构建一个DQN模型,亲历深度强化学习的奥秘。
DQN算法原理
DQN算法的核心思想在于利用神经网络近似Q函数,即Q值函数,而不是传统Q学习中的Q表。这使得算法能处理状态空间巨大乃至连续的问题,因为神经网络能够学习到状态的抽象特征表示。算法主要包括以下关键组件:
- 经验回放缓冲(Experience Replay):存储过往的经验(状态、动作、奖励、新状态、是否终止标志),并在训练时随机抽取样本,减少数据的相关性,稳定学习过程。
- 固定Q-targets(Fixed Q-targets):保持目标网络参数固定一段时间,减缓训练波动,优化更稳定。
- 神经网络:作为Q值函数的近似器,输入为状态,输出为在该状态下每个动作的Q值。
代码实现
以经典的CartPole平衡任务为例,我们使用Keras框架实现一个基本的DQN模型。
import numpy as np
import gym
from collections import deque
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
# 环境设置
env = gym.make('CartPole-v1')
state_space = env.observation_space.shape[0]
action_space = env.action_space.n
# DQN参数
buffer_size = 10000
batch_size = 32
gamma = 0.95
eps_start = 1.0
eps_end = 0.1
exploration_fraction = 0.1
target_update_freq = 100
learning_rate = 0.001
# 经验回放缓冲
memory = deque(maxlen(buffer_size))
# 主网络与目标网络
def build_model():
model = Sequential()
model.add(Flatten(input_shape=(1,) + (state_space,))
model.add(Dense(24, activation='relu'))
model.add(Dense(24, activation='relu'))
model.add(Dense(action_space, activation='linear'))
return model
main_model = build_model()
target_model = build_model()
target_model.set_weights(main_model.get_weights())
# 训练习函数
def train(batch_size):
minibatch = random.sample(memory, batch_size)
states, actions, rewards, next_states, dones = zip(*minibatch)
states = np.array(states)
next_states = np.array(next_states)
q_values = main_model.predict_on_batch(states)
next_q_values = target_model.predict_on_batch(next_states)
max_next_q = np.max(next_q_values, axis=1)
targets = rewards + gamma * (1 - dones) * max_next_q
# 更新Q值
q_values[np.arange(len(states), actions] = targets
main_model.train_on_batch(states, q_values)
# 主循环
for episode in range(num_episodes):
state = env.reset()
done = False
episode_reward = 0
while not done:
if np.random.rand() < eps or episode < exploration_fraction * num_episodes:
action = env.action_space.sample()
else:
q_values = main_model.predict(np.expand_dims(state, axis=0))
action = np.argmax(q_values)
next_state, reward, done, _ = env.step(action)
memory.append((state, action, reward, next_state, done))
episode_reward += reward
# 经验回放缓冲满时训练
if len(memory) > batch_size:
train(batch_size)
state = next_state
# 定期更新目标网络
if episode % target_update_freq == 0:
target_model.set_weights(main_model.get_weights())
print(f"Episode {episode}: Reward: {episode_reward}")
env.close()
结语
通过上述代码示例,我们不仅理解了DQN算法的精髓,还亲自构建了一个简单的DQN模型解决CartPole平衡问题。DQN算法的成功在于其创新地结合了深度学习的表达力与Q学习的决策框架,为强化学习领域的突破性进展铺平了道路。随着研究的深入,诸如Double DQN、Dueling DQN等进一步优化了原始模型,强化学习的边界不断被拓宽。未来,DQN及其变种将在更广泛的领域,如自动驾驶、机器人控制、游戏AI等,发挥关键作用,持续推动智能系统的进步。