用Tensorflow Agents实现强化学习DQN

2023-11-01

在我之前的博客中强化学习笔记(4)-深度Q学习_gzroy的博客-CSDN博客,实现了用Tensorflow keras搭建DQN模型,解决小车上山问题。在代码里面,需要自己实现经验回放,采样等过程,比较繁琐。

Tensorflow里面有一个agents库,实现了很多强化学习的算法和工具。我尝试用agents来实现一个DQN模型来解决小车上山问题。Tensorflow网上的DQN教程是解决CartPole问题的,如果直接照搬这个代码来解决小车上山问题,则会发现模型无法收敛。经过一番研究,我发现原来是在agents里面,默认环境的回合步数是限制在200步,这样导致小车一直无法到达回合结束的位置,模型学习到的总回报一直保持不变。

以下代码是加载训练环境和评估环境,需要注意的是max_episode_steps需要设置为0,即不限制回合的最大步数:

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
from tf_agents.drivers import dynamic_step_driver
from tf_agents.policies import EpsilonGreedyPolicy
import tensorflow as tf
from tqdm import trange
from tf_agents.policies.q_policy import QPolicy
import seaborn as sns
from matplotlib.ticker import MultipleLocator
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib notebook

env_name = 'MountainCar-v0'
env = suite_gym.load(env_name)
train_py_env = suite_gym.load(env_name, max_episode_steps=0)
eval_py_env = suite_gym.load(env_name, max_episode_steps=0)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

然后我们建立一个DQN agent,这个agent包括了一个Q_net和一个target network,这两个network的结构是相同的,其中Q_net用于学习状态动作对的Q值,target network分享Q_net的权重,用于给定状态输入下找到最大Q值的动作。target_update_tau和target_update_period两个参数用于控制何时更新target network的权重,这里的设定是每一步更新target network的权重W_target = (1-0.005)*W_target + 0.005*W_q。gamma参数表示下一状态对应的Q值有多少计入到U值。epsilion_greedy用于控制有多少百分比的概率是随机挑选动作而不是根据Q值。

q_net = q_network.QNetwork(
    train_env.time_step_spec().observation,
    train_env.action_spec(),
    fc_layer_params=(64,))

target_q_net = q_network.QNetwork(
    train_env.time_step_spec().observation,
    train_env.action_spec(),
    fc_layer_params=(64,))

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    target_update_tau=0.005,
    target_update_period=1,
    gamma=0.99,
    epsilon_greedy=0.1,
    td_errors_loss_fn=common.element_wise_squared_loss,
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001))

设置一个缓冲池,用于存放和回放历史经验数据

replay_buffer_capacity = 10000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

# Add an observer that adds to the replay buffer:
replay_observer = [replay_buffer.add_batch]

先用一个随机动作策略来收集一些历史数据到缓冲池

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())
initial_driver = dynamic_step_driver.DynamicStepDriver(
      train_env,
      random_policy,
      observers=replay_observer,
      num_steps=1)
for _ in range(1):
    time_step = train_env.reset()
    step = 0
    while not time_step.is_last():
        step += 1
        if step>1000:
            break
        time_step, _ = initial_driver.run(time_step)

搜集数据之后,我们可以把replay_buffer转换为dataset来方便读取数据。这里的num_steps=2表示每次需要取两条相邻的经验数据,因为计算U值的时候需要用下一条数据的Q值来计算。

dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=128,
    num_steps=2).prefetch(3)

iterator = iter(dataset)

定义一个评估函数,用于评估训练效果:

def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        step = 0
        while not time_step.is_last():
            step += 1
            if step>1000:
                break
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

定义一个函数绘制训练过程中的每回合回报和评估回报:

class Chart:
    def __init__(self):
        self.fig, self.ax = plt.subplots(figsize = (8, 6))
        x_major_locator = MultipleLocator(1)
        self.ax.xaxis.set_major_locator(x_major_locator)
        self.ax.set_xlim(0.5, 50.5)

    def plot(self, data):
        self.ax.clear()
        sns.lineplot(data=data, x=data['episode'], y=data['reward'], hue=data['type'], ax=self.ax)
        self.fig.canvas.draw()

最后是训练和评估的代码。这里设置随着训练回合的增加,epsilion_greedy的值也逐渐减小,相当于在训练初期,随机寻找动作的概率较大,随着训练的增加,Q_net能更好的反映真实的Q值,因此随机动作的概率需要相应减小。另外要注意的是,由于初始回合里面需要通过一定的随机概率才能找到合适的动作结束回合,有可能会碰到回合经过很多步仍不能到达回合结束的条件,例如我曾经碰到第一回合运行了15000多步仍不能结束回合,这是可以重新进行训练。

train_episodes = 50
num_eval_episodes = 5
epsilon = 0.1
chart = Chart()

for episode in range(1,train_episodes):
    lr_step.assign(episode)
    learning_rate = learning_rate_fn(episode)
    episodes.append(episode)
    episode_reward = 0
    if epsilon>0.01:
        train_policy = EpsilonGreedyPolicy(agent.policy, epsilon=epsilon)
        train_driver = dynamic_step_driver.DynamicStepDriver(
              train_env,
              train_policy,
              observers=replay_observer,
              num_steps=1)
        epsilon -= 0.01
    time_step = train_env.reset()
    total_loss = 0
    step = 0
    while not time_step.is_last():
        step += 1
        time_step, _ = train_driver.run(time_step, _)
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss
        total_loss += train_loss
        episode_reward += time_step.reward.numpy()[0]
        if step%100==0:
            print("Epsiode_{}, step_{}, loss:{}".format(episode, step, total_loss/step))
    if episode==1:
        rewards_df = pd.DataFrame([[episode, episode_reward, 'train']], columns=['episode','reward','type'])
    else:
        rewards_df = rewards_df.append({'episode':episode, 'reward':episode_reward, 'type':'train'}, ignore_index=True)

    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    rewards_df = rewards_df.append({'episode':episode, 'reward':avg_return, 'type':'eval'}, ignore_index=True)
    chart.plot(rewards_df)

训练完成后,以下代码可以把训练后的策略在评估环境上运行,并生成视频,可以看到训练效果:

import imageio
import base64
import IPython

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

def create_policy_eval_video(policy, filename, num_episodes=1, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)

create_policy_eval_video(agent.policy, "trained-agent")

视频如下:

trained-agent

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

用Tensorflow Agents实现强化学习DQN 的相关文章

随机推荐

  • Oracle插入或修改数据怎么也不行的解决方法

    今天在公司操作数据库 在删除一条数据的时候忘记提交事务了 之后就去添加别的了 但是后来发现怎么也添加不上 所以觉的是事务锁住了 1 直接判断未提交事务引起的表的行锁 1 1判断哪个SESSION执行了DML Insert Update De
  • C语言-蓝桥杯-算法训练 印章

    问题描述 共有 n 种图案的印章 每种图案的出现概率相同 小A买了 m 张印章 求小A集齐 n 种印章的概率 输入格式 一行两个正整数n和m 输出格式 一个实数P表示答案 保留4位小数 样例输入 2 3 样例输出 0 7500 解题思路 共
  • PPTP穿透NAT之深入分析

    PPTP穿透NAT之深入分析 bytxl的专栏 CSDN博客大家好 现在是人静时分 我公司人员都以溜光 只有我还在面对computer 在经过不解 迷惑 结论之后 现与大家分享结果 感谢朋友Zyliday 见贤思齐的实验帮助 在研究技术原理
  • URP自定义后处理(相机滤镜)

    前言 之前做游戏一直想弄个可以实时触发相机滤镜的效果 自处找了教程和资料 想要做到自定义效果的话最好办法是在unity 内部实现 这个办法比较硬核 其实不适合我这样的小白 所以我在实现的过程中非常痛苦 我用的unity URP 模式其实自带
  • OMG!解释执行java字节码文件的命令

    美团一面 收到了HR的信息 通知我去面试 说实话真的挺紧张的 自己准备了近一个月的时间 很担心面试不过 到时候又后悔不该 裸辞 自我介绍 spring的IOC AOP原理 springmvc的工作流程 handlemapping接收的是什么
  • python中的list格式化输出

    在使用python时 我们经常会用到列表 list 由于它可以保存不同类型的数据 因此很多场景下我们都会使用它来保存数据 在写代码的过程中我们经常想要显示list的内容 直接调用print又会显得很丑 还会带着方括号 和逗号 这个太丑 又不
  • Hive数据库连接-连接池实现

    Hive数据库连接 连接池实现 通过HiveJDBC获取Hive的连接Connection 下面我们简单介绍HiveJDBC数据库连接实现 HiveJDBC配置文件 连接池配置文件hive jdbc properties 初始化连接池数 d
  • Linux运维跳槽必备的40道面试精华题

    1 什么是运维 什么是游戏运维 1 运维是指大型组织已经建立好的网络软硬件的维护 就是要保证业务的上线与运作的正常 在他运转的过程中 对他进行维护 他集合了网络 系统 数据库 开发 安全 监控于一身的技术 运维又包括很多种 有DBA运维 网
  • 鼠标点击获得opencv图像坐标和像素值

    目录 一 核心函数 二 在类中定义并且使用 1 将回调函数直接声明为友元函数 2 h 3 DW S OnMou cpp 4 main cpp 三 函数调用 1 OnMouse h 2 OnMouse cpp 一 核心函数 setMouseC
  • 如何在没有 USB 数据线的情况下使用 Android Studio 在手机中安装 Android

    背景 如何在没有 USB 数据线的情况下使用 Android Studio 在手机中安装 Android 应用程序 运行调式一个Android项目 写下必要的代码后 接下来的任务是在模拟器或手机上运行应用程序 测试应用程序是否正常 及deb
  • python numpy中对ndarry按照index(位置下标)增删改查

    在numpy中的ndarry是一个数组 因此index就是位置下标 注意下标是从0开始 增加 在插入时使用np insert 在末尾添加时使用np append 删除 需要使用np delete 修改 直接指定下标 查找 直接指定下标 示例
  • 【Shell】find文件查找

    语法格式 find 路径 选项 操作 选项参数对照表 常用选项 name 查找 etc目录下以conf结尾的文件ind etc nam iname 查找当前目录下文件名为aa的文件 不区分大小写 find iname aa user 查找文
  • [激光原理与应用-69]:激光焊接的10大常见缺陷及解决方法

    激光焊接是一种以高能量密度的激光束作为热源的高效精密焊接方法 如今 激光焊接已广泛应用于各个行业 如 电子零件 汽车制造 航空航天等工业制造领域 但是 在激光焊接的过程中 难免会出现一些缺陷或次品 只有充分了解这些缺陷并学习如何避免它们 才
  • 九轴传感器之测试篇

    关于九轴传感器的数据测试处理
  • CORS与CSRF

    本文首发于我的Github博客 本篇文章介绍了CORS和CSRF的概念 作者前几天在和带佬们聊天的时候把两个概念搞混了 所以才想要了解 简单来说 CORS Cross Origin Resource Sharing 跨域资源分享 是一种机制
  • (1)基础学习——图解pin、pad、port、IO、net 的区别

    本文内容有参考多位博主的博文 综合整理如下 仅做和人学习记录 如有专业性错误还请指正 谢谢 参考1 芯片资料中的pad和pin的区别 imxiangzi的博客 CSDN博客 pin和pad的区别 参考2
  • IntelliJ IDEA 运行卡顿解决方案

    IntelliJ IDEA 运行卡顿解决方案 1 开启IntelliJ IDEA缓慢 想要提升启动速度 则打开D JetBrains IntelliJ IDEA 2020 3 2 bin 依据实际安装路径 目录下对应文件idea64 exe
  • 对csv文件,又get了新的认知

    背景 在数据分析时 有时我们会碰到csv格式文件 需要先进行数据处理 转换成所需要的数据格式 然后才能进行分析 业务侧的同学可能对Excel文件比较熟悉 Excel可以把单个sheet直接保存为csv文件 也可以直接读取csv文件 变成Ex
  • Qt 进程间通信

    Qt进程间通信的方法 TCP IP Local Server Socket 共享内存 D Bus Unix库 QProcess 会话管理 TCP IP 使用套接字的方式 进行通信 之前介绍了 这里就不介绍了 Local Server Soc
  • 用Tensorflow Agents实现强化学习DQN

    在我之前的博客中强化学习笔记 4 深度Q学习 gzroy的博客 CSDN博客 实现了用Tensorflow keras搭建DQN模型 解决小车上山问题 在代码里面 需要自己实现经验回放 采样等过程 比较繁琐 Tensorflow里面有一个a