使用 DQN 增加 Cartpole-v0 损失

2024-01-01

您好,我正在尝试训练 DQN 来解决健身房的 Cartpole 问题。 由于某种原因Loss https://i.stack.imgur.com/uHxpR.png看起来像这样(橙色线)。你们能看一下我的代码并帮忙解决这个问题吗?我已经对超参数进行了相当多的研究,所以我认为它们不是这里的问题。

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.linear1 = nn.Linear(input_dim, 16)
        self.linear2 = nn.Linear(16, 32)
        self.linear3 = nn.Linear(32, 32)
        self.linear4 = nn.Linear(32, output_dim)


    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.linear4(x)


final_epsilon = 0.05
initial_epsilon = 1
epsilon_decay = 5000
global steps_done
steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = final_epsilon + (initial_epsilon - final_epsilon) * \
                    math.exp(-1. * steps_done / epsilon_decay)
    if sample > eps_threshold:
        with torch.no_grad():
            state = torch.Tensor(state)
            steps_done += 1
            q_calc = model(state)
            node_activated = int(torch.argmax(q_calc))
            return node_activated
    else:
        node_activated = random.randint(0,1)
        steps_done += 1
        return node_activated


class ReplayMemory(object): # Stores [state, reward, action, next_state, done]

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = [[],[],[],[],[]]

    def push(self, data):
        """Saves a transition."""
        for idx, point in enumerate(data):
            #print("Col {} appended {}".format(idx, point))
            self.memory[idx].append(point)

    def sample(self, batch_size):
        rows = random.sample(range(0, len(self.memory[0])), batch_size)
        experiences = [[],[],[],[],[]]
        for row in rows:
            for col in range(5):
                experiences[col].append(self.memory[col][row])
        return experiences

    def __len__(self):
        return len(self.memory[0])


input_dim, output_dim = 4, 2
model = DQN(input_dim, output_dim)
target_net = DQN(input_dim, output_dim)
target_net.load_state_dict(model.state_dict())
target_net.eval()
tau = 2
discount = 0.99

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

memory = ReplayMemory(65536)
BATCH_SIZE = 128


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return 0
    experiences = memory.sample(BATCH_SIZE)
    state_batch = torch.Tensor(experiences[0])
    action_batch = torch.LongTensor(experiences[1]).unsqueeze(1)
    reward_batch = torch.Tensor(experiences[2])
    next_state_batch = torch.Tensor(experiences[3])
    done_batch = experiences[4]

    pred_q = model(state_batch).gather(1, action_batch)

    next_state_q_vals = torch.zeros(BATCH_SIZE)

    for idx, next_state in enumerate(next_state_batch):
        if done_batch[idx] == True:
            next_state_q_vals[idx] = -1
        else:
            # .max in pytorch returns (values, idx), we only want vals
            next_state_q_vals[idx] = (target_net(next_state_batch[idx]).max(0)[0]).detach()


    better_pred = (reward_batch + next_state_q_vals).unsqueeze(1)

    loss = F.smooth_l1_loss(pred_q, better_pred)
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss


points = []
losspoints = []

#save_state = torch.load("models/DQN_target_11.pth")
#model.load_state_dict(save_state['state_dict'])
#optimizer.load_state_dict(save_state['optimizer'])



env = gym.make('CartPole-v0')
for i_episode in range(5000):
    observation = env.reset()
    episode_loss = 0
    if episode % tau == 0:
        target_net.load_state_dict(model.state_dict())
    for t in range(1000):
        #env.render()
        state = observation
        action = select_action(observation)
        observation, reward, done, _ = env.step(action)

        if done:
            next_state = [0,0,0,0]
        else:
            next_state = observation

        memory.push([state, action, reward, next_state, done])
        episode_loss = episode_loss + float(optimize_model(i_episode))
        if done:
            points.append((i_episode, t+1))
            print("Episode {} finished after {} timesteps".format(i_episode, t+1))
            print("Avg Loss: ", episode_loss / (t+1))
            losspoints.append((i_episode, episode_loss / (t+1)))
            if (i_episode % 100 == 0):
                eps = final_epsilon + (initial_epsilon - final_epsilon) * \
                    math.exp(-1. * steps_done / epsilon_decay)
                print(eps)
            if ((i_episode+1) % 5001 == 0):
                save = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(save, "models/DQN_target_" + str(i_episode // 5000) + ".pth")
            break
env.close()




x = [coord[0] * 100 for coord in points]
y = [coord[1] for coord in points]

x2 = [coord[0] * 100 for coord in losspoints]
y2 = [coord[1] for coord in losspoints]

plt.plot(x, y)
plt.plot(x2, y2)
plt.show()

我基本上遵循了 pytorch 的教程,除了使用 env 返回的状态而不是像素。我还更改了重播内存,因为我在那里遇到了问题。除此之外,其他一切我都保持原样。

Edit:

我尝试过拟合一小批,损失看起来像this https://i.stack.imgur.com/ZwjrU.png无需更新目标网络和this https://i.stack.imgur.com/VOOPM.png更新的时候

Edit 2:

这绝对是目标网络的问题,我尝试将其删除,损失似乎并没有呈指数级增长


Your tau值太小,目标网络更新小导致DQN训练不稳定。您可以尝试使用 1000(OpenAI Baselines DQN 示例)或 10000(Deepmind 的 Nature 论文)。

Deepmind 2015 年 Nature 论文中指出:

The second modification to online Q-learning aimed at further improving the stability of our method with neural networks is to use a separate network for generating the traget yj in the Q-learning update. More precisely, every C updates we clone the network Q to obtain a target network Q' and use Q' for generating the Q-learning targets yj for the following C updates to Q. This modification makes the algorithm more stable compared to standard online Q-learning, where an update that increases Q(st,at) often also increases Q(st+1, a) for all a and hence also increases the target yj, possibly leading to oscillations or divergence of the policy. Generating the targets using the older set of parameters adds a delay between the time an update to Q is made and the time the update affects the targets yj, making divergence or oscillations much more unlikely.

通过深度强化实现人性化控制 学习,Mnih 等,2015 https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf

我已经使用以下设置运行了您的代码tau=2, tau=10, tau=100, tau=1000 and tau=10000。更新频率为tau=100解决问题(达到最大步数 200)。

tau=2

tau=10 enter image description here

tau=100 enter image description here

tau=1000 enter image description here

tau=10000 enter image description here

以下是您的代码的修改版本。

import random
import math
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import gym

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.linear1 = nn.Linear(input_dim, 16)
        self.linear2 = nn.Linear(16, 32)
        self.linear3 = nn.Linear(32, 32)
        self.linear4 = nn.Linear(32, output_dim)


    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.linear4(x)


final_epsilon = 0.05
initial_epsilon = 1
epsilon_decay = 5000
global steps_done
steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = final_epsilon + (initial_epsilon - final_epsilon) * \
                    math.exp(-1. * steps_done / epsilon_decay)
    if sample > eps_threshold:
        with torch.no_grad():
            state = torch.Tensor(state)
            steps_done += 1
            q_calc = model(state)
            node_activated = int(torch.argmax(q_calc))
            return node_activated
    else:
        node_activated = random.randint(0,1)
        steps_done += 1
        return node_activated


class ReplayMemory(object): # Stores [state, reward, action, next_state, done]

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = [[],[],[],[],[]]

    def push(self, data):
        """Saves a transition."""
        for idx, point in enumerate(data):
            #print("Col {} appended {}".format(idx, point))
            self.memory[idx].append(point)

    def sample(self, batch_size):
        rows = random.sample(range(0, len(self.memory[0])), batch_size)
        experiences = [[],[],[],[],[]]
        for row in rows:
            for col in range(5):
                experiences[col].append(self.memory[col][row])
        return experiences

    def __len__(self):
        return len(self.memory[0])


input_dim, output_dim = 4, 2
model = DQN(input_dim, output_dim)
target_net = DQN(input_dim, output_dim)
target_net.load_state_dict(model.state_dict())
target_net.eval()
tau = 100
discount = 0.99

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

memory = ReplayMemory(65536)
BATCH_SIZE = 128


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return 0
    experiences = memory.sample(BATCH_SIZE)
    state_batch = torch.Tensor(experiences[0])
    action_batch = torch.LongTensor(experiences[1]).unsqueeze(1)
    reward_batch = torch.Tensor(experiences[2])
    next_state_batch = torch.Tensor(experiences[3])
    done_batch = experiences[4]

    pred_q = model(state_batch).gather(1, action_batch)

    next_state_q_vals = torch.zeros(BATCH_SIZE)

    for idx, next_state in enumerate(next_state_batch):
        if done_batch[idx] == True:
            next_state_q_vals[idx] = -1
        else:
            # .max in pytorch returns (values, idx), we only want vals
            next_state_q_vals[idx] = (target_net(next_state_batch[idx]).max(0)[0]).detach()


    better_pred = (reward_batch + next_state_q_vals).unsqueeze(1)

    loss = F.smooth_l1_loss(pred_q, better_pred)
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss


points = []
losspoints = []

#save_state = torch.load("models/DQN_target_11.pth")
#model.load_state_dict(save_state['state_dict'])
#optimizer.load_state_dict(save_state['optimizer'])



env = gym.make('CartPole-v0')
for i_episode in range(5000):
    observation = env.reset()
    episode_loss = 0
    if i_episode % tau == 0:
        target_net.load_state_dict(model.state_dict())
    for t in range(1000):
        #env.render()
        state = observation
        action = select_action(observation)
        observation, reward, done, _ = env.step(action)

        if done:
            next_state = [0,0,0,0]
        else:
            next_state = observation

        memory.push([state, action, reward, next_state, done])
        episode_loss = episode_loss + float(optimize_model())
        if done:
            points.append((i_episode, t+1))
            print("Episode {} finished after {} timesteps".format(i_episode, t+1))
            print("Avg Loss: ", episode_loss / (t+1))
            losspoints.append((i_episode, episode_loss / (t+1)))
            if (i_episode % 100 == 0):
                eps = final_epsilon + (initial_epsilon - final_epsilon) * \
                    math.exp(-1. * steps_done / epsilon_decay)
                print(eps)
            if ((i_episode+1) % 5001 == 0):
                save = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
                torch.save(save, "models/DQN_target_" + str(i_episode // 5000) + ".pth")
            break
env.close()




x = [coord[0] * 100 for coord in points]
y = [coord[1] for coord in points]

x2 = [coord[0] * 100 for coord in losspoints]
y2 = [coord[1] for coord in losspoints]

plt.plot(x, y)
plt.plot(x2, y2)
plt.show()

这是您的绘图代码的结果。

tau=100 enter image description here

tau=10000 Blue is the average period and orange is the loss

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 DQN 增加 Cartpole-v0 损失 的相关文章

  • Django REST序列化器:创建对象而不保存

    我已经开始使用 Django REST 框架 我想做的是使用一些 JSON 发布请求 从中创建一个 Django 模型对象 然后使用该对象而不保存它 我的 Django 模型称为 SearchRequest 我所拥有的是 api view
  • 如何在python中读取多个文件中的文本

    我的文件夹中有许多文本文件 大约有 3000 个文件 每个文件中第 193 行是唯一包含重要信息的行 我如何使用 python 将所有这些文件读入 1 个文本文件 os 模块中有一个名为 list dir 的函数 该函数返回给定目录中所有文
  • Python PAM 模块的安全问题?

    我有兴趣编写一个 PAM 模块 该模块将利用流行的 Unix 登录身份验证机制 我过去的大部分编程经验都是使用 Python 进行的 并且我正在交互的系统已经有一个 Python API 我用谷歌搜索发现pam python http pa
  • 如何使用固定的 pandas 数据框进行动态 matplotlib 绘图?

    我有一个名为的数据框benchmark returns and strategy returns 两者具有相同的时间跨度 我想找到一种方法以漂亮的动画风格绘制数据点 以便它显示逐渐加载的所有点 我知道有一个matplotlib animat
  • 如何生成给定范围内的回文数列表?

    假设范围是 1 X 120 这是我尝试过的 gt gt gt def isPalindrome s check if a number is a Palindrome s str s return s s 1 gt gt gt def ge
  • Pycharm Python 控制台不打印输出

    我有一个从 Pycharm python 控制台调用的函数 但没有显示输出 In 2 def problem1 6 for i in range 1 101 2 print i end In 3 problem1 6 In 4 另一方面 像
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • 从列表中的数据框列中搜索部分字符串匹配 - Pandas - Python

    我有一个清单 things A1 B2 C3 我有一个 pandas 数据框 其中有一列包含用分号分隔的值 某些行将包含与上面列表中的一项的匹配 它不会是完美的匹配 因为它在其中包含字符串的其他部分 该列 例如 该列中的一行可能有 哇 这里
  • python 集合可以包含的值的数量是否有限制?

    我正在尝试使用 python 设置作为 mysql 表中 ids 的过滤器 python集存储了所有要过滤的id 现在大约有30000个 这个数字会随着时间的推移慢慢增长 我担心python集的最大容量 它可以包含的元素数量有限制吗 您最大
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • Geopandas 设置几何图形:MultiPolygon“等于 len 键和值”的 ValueError

    我有 2 个带有几何列的地理数据框 我将一些几何图形从 1 个复制到另一个 这对于多边形效果很好 但对于任何 有效 多多边形都会返回 ValueError 请指教如何解决这个问题 我不知道是否 如何 为什么应该更改 MultiPolygon
  • Python:尝试检查有效的电话号码

    我正在尝试编写一个接受以下格式的电话号码的程序XXX XXX XXXX并将条目中的任何字母翻译为其相应的数字 现在我有了这个 如果启动不正确 它将允许您重新输入正确的数字 然后它会翻译输入的原始数字 我该如何解决 def main phon
  • 循环中断打破tqdm

    下面的简单代码使用tqdm https github com tqdm tqdm在循环迭代时显示进度条 import tqdm for f in tqdm tqdm range 100000000 if f gt 100000000 4 b
  • 从 pygame 获取 numpy 数组

    我想通过 python 访问我的网络摄像头 不幸的是 由于网络摄像头的原因 openCV 无法工作 Pygame camera 使用以下代码就像魅力一样 from pygame import camera display camera in
  • Nuitka 未使用 nuitka --recurse-all hello.py [错误] 编译 exe

    我正在尝试通过 nuitka 创建一个简单的 exe 这样我就可以在我的笔记本电脑上运行它 而无需安装 Python 我在 Windows 10 上并使用 Anaconda Python 3 我输入 nuitka recurse all h
  • 在 Pandas DataFrame Python 中添加新列[重复]

    这个问题在这里已经有答案了 例如 我在 Pandas 中有数据框 Col1 Col2 A 1 B 2 C 3 现在 如果我想再添加一个名为 Col3 的列 并且该值基于 Col2 式中 如果Col2 gt 1 则Col3为0 否则为1 所以
  • 从 Python 中的类元信息对 __init__ 函数进行类型提示

    我想做的是复制什么SQLAlchemy确实 以其DeclarativeMeta班级 有了这段代码 from sqlalchemy import Column Integer String from sqlalchemy ext declar
  • Python:元类属性有时会覆盖类属性?

    下面代码的结果让我感到困惑 class MyClass type property def a self return 1 class MyObject object metaclass MyClass a 2 print MyObject
  • 改变字典的哈希函数

    按照此question https stackoverflow com questions 37100390 towards understanding dictionaries 我们知道两个不同的字典 dict 1 and dict 2例
  • Python 分析:“‘select.poll’对象的‘poll’方法”是什么?

    我已经使用 python 分析了我的 python 代码cProfile模块并得到以下结果 ncalls tottime percall cumtime percall filename lineno function 13937860 9

随机推荐