Playing the CartPole Game with a Deep Q-Learning (DQN) Network Based on Reinforcement Learning

Published: 2024-08-08

1. Environment setup: this example uses gym version 0.26.2
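
As a quick sanity check, the short sketch below confirms the gym version and the 0.26-style API (two return values from reset(), five from step()); the pip command in the comment is an assumption based on the version stated above.

# A quick environment check, assuming gym was installed with: pip install gym==0.26.2
import gym

env = gym.make('CartPole-v1')
print(gym.__version__)  # expected: 0.26.2
state, info = env.reset()  # gym 0.26: reset() returns (observation, info)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())  # step() returns 5 values
env.close()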

2. Write the network code

# Import the required libraries
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random


# Define the DQN network (maps a state vector to one Q-value per action)
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        # Three fully connected layers
        self.fc1 = nn.Linear(state_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Define the DQN agent
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # experience replay buffer
        self.gamma = 0.95  # discount factor
        self.epsilon = 1.0  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
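        # ε starts at 1.0 (fully random actions) and is multiplied by 0.995 after each
        # replay() call, down to a floor of 0.01, shifting from exploration to exploitation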
        self.learning_rate = 0.001
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_size, action_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def remember(self, state, action, reward, next_state, done):
        # Store a transition in the replay buffer
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Select an action with an ε-greedy policy
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        act_values = self.model(state)
        return np.argmax(act_values.cpu().data.numpy())

    def replay(self, batch_size):
        # Sample a random minibatch from the replay buffer and learn from it
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_state = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    target = reward + self.gamma * self.model(next_state).max().item()
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            # Build the training target: keep the network's current estimates for the
            # other actions and replace the chosen action's value with the TD target.
            # detach() keeps the target out of the gradient computation.
            target_f = self.model(state).detach()
            target_f[0][action] = target
            self.optimizer.zero_grad()
            loss = nn.MSELoss()(self.model(state), target_f)
            loss.backward()
            self.optimizer.step()
        # Decay the exploration rate
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        # Load model weights
        self.model.load_state_dict(torch.load(name))

    def save(self, name):
        # Save model weights
        torch.save(self.model.state_dict(), name)


# Training function
def train_dqn():
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    episodes = 1000
    batch_size = 32

    for e in range(episodes):
        state, _ = env.reset()  # reset the environment; gym 0.26 reset() returns (observation, info)
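        # CartPole-v1 caps episodes at 500 steps, so 500 is the maximum score per episode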
        for time in range(500):
            action = agent.act(state)
            # gym 0.26: step() returns (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            reward = reward if not done else -10  # simple reward shaping: penalize the episode ending
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"episode: {e}/{episodes}, score: {time}, epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        if e % 100 == 0:
            agent.save(f"cartpole-dqn-{e}.pth")  # 每100回合保存一次模型


# Play the game with the trained model
def play_cartpole():
    env = gym.make('CartPole-v1', render_mode="human")  # render_mode is required for on-screen rendering in gym 0.26
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    agent.load("cartpole-dqn-900.pth")  # 加载训练好的模型

    for e in range(10):  # play 10 episodes
        state, _ = env.reset()
        for time in range(500):
            # with render_mode="human", the environment renders itself on every step()
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = next_state
            if done:
                print(f"episode: {e}, score: {time}")
                break
    env.close()

if __name__ == '__main__':
    # Uncomment the following line to train the model
    # train_dqn()
    # Uncomment the following line to play with a trained model
    play_cartpole()

For more detailed explanations, see:

https://zhuanlan.zhihu.com/p/29283993

https://zhuanlan.zhihu.com/p/29213893

