Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

2023-12-19

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且 采用Pytorch框架对常见的强化学习算法、案例进行实现 ,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

????详情: 《Pytorch深度强化学习》


1 Q-Learning算法原理

Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法) 介绍到 时序差分强化学习 是动态规划与蒙特卡洛的折中

Q π ( s t , a t ) = n 次增量 Q π ( s t , a t ) + α ( R t − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) ⏟ 采样 \begin{aligned}Q^{\pi}\left( s_t,a_t \right) &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( R_t-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+\gamma R_{t+1}-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}{ \underset{\text{采样}}{\underbrace{Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }}}\end{aligned} Q π ( s t , a t ) n 次增量 Q π ( s t , a t ) + α ( R t Q π ( s t , a t ) ) n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 Q π ( s t , a t ) ) n 次增量 采样 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) Q π ( s t , a t ) )

其中 r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) r_{t+1}+\gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) -Q^{\pi}\left( s_t,a_t \right) r t + 1 + γ Q π ( s t + 1 , a t + 1 ) Q π ( s t , a t ) 称为 时序差分误差 。基于离轨策略的时序差分强化学习的代表性算法是 Q-learning算法 ,其算法流程如下所示。具体的策略改进算法推导请见之前的文章,本文重点在于应用Q-learning算法解决实际问题

在这里插入图片描述

我们先来看看最终实现的效果

训练前
在这里插入图片描述

训练后

在这里插入图片描述

接下来详细讲解如何一步步实现这个智能体

2 强化学习基本框架

强化学习(Reinforcement Learning, RL) 在潜在的不确定复杂环境中,训练一个最优决策 π \pi π 指导一系列行动实现目标最优化的机器学习方法 。在初始情况下,没有训练数据告诉强化学习智能体并不知道在环境中应该针对何种状态采取什么行动,而是通过不断试错得到最终结果,再反馈修正之前采取的策略,因此强化学习某种意义上可以视为具有“延迟标记信息”的监督学习问题。

在这里插入图片描述

强化学习的基本过程是:智能体对环境采取某种行动 a a a ,观察到环境状态发生转移 s 0 → s s_0\rightarrow s s 0 s ,反馈给智能体转移后的状态 s s s 和对这种转移的奖赏 r r r 。综上所述,一个强化学习任务可以用四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E = S , A , P , R 表征

  • 状态空间 S S S :每个状态 s ∈ S s \in S s S 是智能体对感知环境的描述;
  • 动作空间 A A A :每个动作 a ∈ A a \in A a A 是智能体能够采取的行动;
  • 状态转移概率 P P P :某个动作 a ∈ A a \in A a A 作用于处在某个状态 s ∈ S s \in S s S 的环境中,使环境按某种概率分布 P P P 转换到另一个状态;
  • 奖赏函数 R R R :表示智能体对状态 s ∈ S s \in S s S 下采取动作 a ∈ A a \in A a A 导致状态转移的期望度,通常 r > 0 r>0 r > 0 为期望行动, r < 0 r<0 r < 0 为非期望行动。

所以,程序上也需要依次实现四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E = S , A , P , R

3 机器人走迷宫算法

3.1 迷宫环境

我们创建的迷宫包含障碍物、起点和终点

class Maze(tk.Tk, object):
    '''
    * @breif: 迷宫环境类
    * @param[in]: None
    '''    
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze game')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self.buildMaze()

    '''
    * @breif: 创建迷宫
    '''
    def buildMaze(self):
        self.canvas = tk.Canvas(self, bg='white', height=MAZE_H * UNIT, width=MAZE_W * UNIT)
        # 网格地图
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 创建原点坐标
        origin = np.array([20, 20])

        # 创建障碍
        barrier_list = [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0),
                        (0, 6), (1, 6), (2, 6), (3, 6), (4, 6), (5, 6), (6, 6),
                        (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 1), (6, 2),
                        (6, 3), (6, 4), (6, 5), (1, 2), (2, 2), (4, 1), (5, 4),
                        (1, 4), (3, 3)]
        self.barriers = [self.creatObject(origin, *index) for index in barrier_list]

        # 创建终点
        self.terminus = self.creatObject(origin, 5, 5, 'blue')

3.2 状态、动作和奖励

机器人的状态可以设置为当前的位置坐标

s = self.canvas.coords(self.agent)

机器人的动作可以设为上、下左、右

if action == 0:   # up
      if s[1] > UNIT:
          base_action[1] -= UNIT
  elif action == 1:   # down
      if s[1] < (MAZE_H - 1) * UNIT:
          base_action[1] += UNIT
  elif action == 2:   # right
      if s[0] < (MAZE_W - 1) * UNIT:
          base_action[0] += UNIT
  elif action == 3:   # left
      if s[0] > UNIT:
          base_action[0] -= UNIT

机器人的奖励设置为以下几种:

  • 碰到障碍物 :-10分,并进入终止状态
  • 成功到达终点 : +50分,并进入终止状态
  • 未到达终点 :-1分,能量耗散惩罚,防止机器人原地振荡
if s_ in [self.canvas.coords(barrier) for barrier in self.barriers]:
   reward = -10
   done = True
   s_ = 'terminal'
elif s_ == self.canvas.coords(self.terminus):
   reward = 50
   done = True
   s_ = 'terminal'
else:
   reward = -1
   done = False

3.3 Q-Learning算法实现

根据算法流程,实现下面的Q-Learning训练函数

def train(self, env, episodes=1000, reward_curve=[], file=None):
	with tqdm(range(episodes)) as bar:
	    for _ in bar:
	        # 初始化环境和该幕累计奖赏
	        state = env.reset()
	        acc_reward = 0
	        while True:
	            # 刷新环境
	            env.render()
	            # 采样一个动作并进行状态转移
	            action = self.policySample(str(state))
	            next_state, reward, done = env.step(action)
	            acc_reward += reward
	            # 智能体学习策略
	            self.learn(str(state), action, reward, str(next_state))
	            state = next_state
	            if done:
	                reward_curve.append(acc_reward)
	                break
	# 保存策略
	if not file:
	    self.q_table.to_csv(file)
	env.destroy()

3.4 完成训练

训练过程如下所示,完成后保存权重文件

if __name__ == "__main__":
    env = Maze()
    agent = Agent(actions=list(range(env.n_actions)))
    reward_curve = []

    # 训练智能体
    env.after(100, agent.train, env, 50, reward_curve, './weight/csv')

    # 主循环
    env.mainloop()

4 算法分析

4.1 Q-Table

在Q-Learning算法中,我们需要维护一个Q-Table,用来记录各种状态和动作的价值。Q-Table是一个二维表格,其中每一行表示一个状态,每一列表示一个动作。Q-Table中的值表示某个状态下执行某个动作所获得的回报(或者预期回报)。Q-Table的更新是Q-Learning算法的核心。在每次执行动作后,我们会根据当前状态、执行的动作、获得的奖励和下一个状态,来更新Q-Table中对应的值,更新方式是

Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) Q^{\pi}\left( s_t,a_t \right) ={ {Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }} Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) Q π ( s t , a t ) )

对应代码

self.q_table.loc[state, action] += self.lr * (q_target - q_predict)

在这里插入图片描述

保存的权重文件正是Q-Table,我们可以直观地看一下,其中 0-3 指的是上下左右四个动作,每行行首则是状态值,其余数是Q-Value

,0,1,2,3
"[45.0, 45.0, 75.0, 75.0]",-3.764746051087998,-4.129632180625153,2.070923999854885,-4.129632180625153
terminal,0.0,0.0,0.0,0.0
"[85.0, 45.0, 115.0, 75.0]",-3.7017636879676745,-3.2427095093971663,6.341493354722148,-2.4376270354451357
"[125.0, 45.0, 155.0, 75.0]",-2.822694674017249,12.009385340227768,-3.10550914130922,-1.7370066390489591
"[125.0, 85.0, 155.0, 115.0]",-1.018256983413196,-2.3765728565289628,19.23732307528551,-2.602996266117196
"[165.0, 85.0, 195.0, 115.0]",-2.063857163563445,27.370237164958994,-0.7307141976318489,0.14330394709222574
"[205.0, 85.0, 235.0, 115.0]",-0.4546075907459214,-0.45498153729692925,-0.490099501,0.3662096391980347
"[165.0, 125.0, 195.0, 155.0]",0.9791630128216775,35.427315495348594,-0.28782126600827374,-1.7383137616441329
"[205.0, 45.0, 235.0, 75.0]",-0.3940399,-0.38288265597631166,-0.3940399,-0.3940399
"[205.0, 125.0, 235.0, 155.0]",-0.31765122402993484,-0.3940399,-0.3940399,1.5298899806741253
...

4.2 奖励曲线

训练过程的奖励曲线如下所示

在这里插入图片描述

完整代码联系下方博主名片获取


???? 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

????源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系????
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫 的相关文章

  • 如何在python中读取多个文件中的文本

    我的文件夹中有许多文本文件 大约有 3000 个文件 每个文件中第 193 行是唯一包含重要信息的行 我如何使用 python 将所有这些文件读入 1 个文本文件 os 模块中有一个名为 list dir 的函数 该函数返回给定目录中所有文
  • 使用 openCV 对图像中的子图像进行通用检测

    免责声明 我是计算机视觉菜鸟 我看过很多关于如何在较大图像中查找特定子图像的堆栈溢出帖子 我的用例有点不同 因为我不希望它是具体的 而且我不确定如何做到这一点 如果可能的话 但我感觉应该如此 我有大量图像数据集 有时 其中一些图像是数据集的
  • 如何在android上的python kivy中关闭应用程序后使服务继续工作

    我希望我的服务在关闭应用程序后继续工作 但我做不到 我听说我应该使用startForeground 但如何在Python中做到这一点呢 应用程序代码 from kivy app import App from kivy uix floatl
  • DreamPie 不适用于 Python 3.2

    我最喜欢的 Python shell 是DreamPie http dreampie sourceforge net 我想将它与 Python 3 2 一起使用 我使用了 添加解释器 DreamPie 应用程序并添加了 Python 3 2
  • 导入错误:没有名为 _ssl 的模块

    带 Python 2 7 的 Ubuntu Maverick 我不知道如何解决以下导入错误 gt gt gt import ssl Traceback most recent call last File
  • 如何打印没有类型的defaultdict变量?

    在下面的代码中 from collections import defaultdict confusion proba dict defaultdict float for i in xrange 10 confusion proba di
  • 运行多个 scrapy 蜘蛛的正确方法

    我只是尝试使用在同一进程中运行多个蜘蛛新的 scrapy 文档 http doc scrapy org en 1 0 topics practices html但我得到 AttributeError CrawlerProcess objec
  • IRichBolt 在storm-1.0.0 和 pyleus-0.3.0 上运行拓扑时出错

    我正在运行风暴拓扑 pyleus verbose local xyz topology jar using storm 1 0 0 pyleus 0 3 0 centos 6 6并得到错误 线程 main java lang NoClass
  • PyTorch 中的后向函数

    我对 pytorch 的后向功能有一些疑问 我认为我没有得到正确的输出 import numpy as np import torch from torch autograd import Variable a Variable torch
  • python 集合可以包含的值的数量是否有限制?

    我正在尝试使用 python 设置作为 mysql 表中 ids 的过滤器 python集存储了所有要过滤的id 现在大约有30000个 这个数字会随着时间的推移慢慢增长 我担心python集的最大容量 它可以包含的元素数量有限制吗 您最大
  • 表达式中的 Python 'in' 关键字与 for 循环中的比较 [重复]

    这个问题在这里已经有答案了 我明白什么是in运算符在此代码中执行的操作 some list 1 2 3 4 5 print 2 in some list 我也明白i将采用此代码中列表的每个值 for i in 1 2 3 4 5 print
  • 如何将 numpy.matrix 提高到非整数幂?

    The 运算符为numpy matrix不支持非整数幂 gt gt gt m matrix 1 0 0 5 0 5 gt gt gt m 2 5 TypeError exponent must be an integer 我想要的是 oct
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • Python 3 中“map”类型的对象没有 len()

    我在使用 Python 3 时遇到问题 我得到了 Python 2 7 代码 目前我正在尝试更新它 我收到错误 类型错误 map 类型的对象没有 len 在这部分 str len seed candidates 在我像这样初始化它之前 se
  • 如何在 Django 中使用并发进程记录到单个文件而不使用独占锁

    给定一个在多个服务器上同时执行的 Django 应用程序 该应用程序如何记录到单个共享日志文件 在网络共享中 而不保持该文件以独占模式永久打开 当您想要利用日志流时 这种情况适用于 Windows Azure 网站上托管的 Django 应
  • VSCode:调试配置中的 Python 路径无效

    对 Python 和 VSCode 以及 stackoverflow 非常陌生 直到最近 我已经使用了大约 3 个月 一切都很好 当尝试在调试器中运行任何基本的 Python 程序时 弹出窗口The Python path in your
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo
  • Python - 字典和列表相交

    给定以下数据结构 找出这两种数据结构共有的交集键的最有效方法是什么 dict1 2A 3A 4B list1 2A 4B Expected output 2A 4B 如果这也能产生更快的输出 我可以将列表 不是 dict1 组织到任何其他数
  • 改变字典的哈希函数

    按照此question https stackoverflow com questions 37100390 towards understanding dictionaries 我们知道两个不同的字典 dict 1 and dict 2例
  • PyAudio ErrNo 输入溢出 -9981

    我遇到了与用户相同的错误 Python 使用 Pyaudio 以 16000Hz 录制音频时出错 https stackoverflow com questions 12994981 python error audio recording

随机推荐

  • 使用Docker搭建onlyoffice

    前提 省略docker的安装步骤 1拉去onlyoffice镜像 我的是指定版本 docker pull onlyoffice documentserver 7 1 1 2 启动onlyoffice docker run itd name
  • 40V TPHR8504PL N沟道功率MOSFET具有业界领先的低导通电阻特性,有助于提高电源效率

    TPHR8504PL是一种MOSFET 金属氧化物半导体场效应晶体管 它是40 Volt N 沟道MOSFET 由N型沟道和P型衬底构成 而P 沟道MOSFET则由P型沟道和N型衬底构成 TPHR8504PL N 沟道MOSFET的工作原理
  • 用友BIP数智采购,推动阳光采购,规避合规风险

    采购管理是企业价值链管理的核心环节 与企业经营效益密切相关 国有企业的采购工作兼具公共采购和企业采购的双重属性 既关系国有企业的经济效益和发展前景 也关系国有资产的保值增值 阳光采购是国有企业合规化采购的必由之路 国有企业在采购环节融入数智
  • thinkphp+mysql+vue实验室设备报修预约管理系统

    运行环境 phpstudy wamp xammp等 开发语言 php 后端框架 Thinkphp5 前端框架 vue js 服务器 apache 数据库 mysql 数据库工具 Navicat phpmyadmin 本站是一个B S模式系统
  • 【手势识别】深度学习卷积神经网络CNN手势识别(0-9,含识别率)【含Matlab源码 3435期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • 好用的p图软件免费有哪些?帮你个性化修饰你的图片

    在信息时代的今天 很多人都喜欢将自己的所见所闻拍照下来 但是拍下来的照片难免有这样那样的问题 这就需要一款好用的p图软件来对图片进行完善了 但是市面上p图软件众多 有以功能齐全和专业性为特色的 也有以操作简单为特色的 更有些以多样化的特效为
  • 【印刷字符识别】OCR键盘数字+字母识别【含Matlab源码 807期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • R10在工业自动化-485转WiFi无线路由解决方案

    R10是钡铼技术有限公司研发的一款用于工业自动化应用的485转WiFi无线路由器解决方案 该解决方案可以将传统的RS485通信设备无线化 实现数据的远程监控和管理 下面将详细介绍R10在工业自动化中的应用 首先 R10具备RS485转WiF
  • 西南科技大学数据库实验八(自定义函数)

    一 实验目的 1 掌握用户自定义变量 2 熟悉运算符与表达式 3 掌握begin end语句块 4 掌握重置命令结束标记 5 掌握创建自定义函数的语法格式以及函数的创建与调用 二 实验任务 1 创建学生表Student 由学号 Sno 姓名
  • 系列十二、索引实战

    一 索引实战 1 1 前置说明 前边的系列文章中是基于Linux中的MySQL进行案例演示的 为了后续测试百万条数据的sql性能分析 接下来的案例将会在Windows的MySQL中进行演示 MySQL版本要求需在8 0以上 我的MySQL版
  • android 13.0 SystemUI状态栏下拉快捷添加截图快捷开关

    1 概述 在13 0的系统产品rom定制化开发中 对SystemUI的定制需求也是挺多的 在下拉状态栏中 添加截图快捷开关 也是常有的开发功能 下面就以添加 截图功能为例功能的实现 2 SystemUI 状态栏下拉快捷添加截图快捷开关的核心
  • 大众点评poi数据2023年最新

    数据名称 2023年大众点评全国全品类数据特此说明 保证价格最低 数据质量最高变量名称 Id 店名 分店 b id 一类 m id 二类 s id 三类 area reg mark 外卖 订座 排队 促销 买单 酒店预约 美容预约 推荐 星
  • uniapp-使用返回的base64转换成图片

    在实际开发的时候 需要后端实时的给我返回二维码 他给我返回的是加密后的 base64字符串 我需要利用这个base64转换到canvas画布上展示 或者以图片的形式展示在页面内 在canvas画布上展示 使用官方的uni getFileSy
  • 【手势识别】肤色静态手势识别【含Matlab源码 288期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • 视频剪辑软件哪个好用?这些软件值得收藏

    朋友 你有没有遇到过这样的情况 收到了一段精彩的视频 想要将其中的亮点剪切出来制作成短视频 或是想将长时间的录像文件分割成多个小段 以便更方便地进行编辑和管理 但是却不知道该选择哪款视频剪辑合成软件 别担心 今天我将会给大家介绍一些常见的视
  • 【手写数字识别】BP神经网络手写数字识别【含GUI Matlab源码 1118期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • 题解 | #输出某一年的各个月份的天数#

    三方寄过去了 告诉我停止24届招聘 全部毁约 牛的 he芯 毁约应届生 34316 广西北部湾银行2022年校园招聘 广西北部湾银行股份有限公司2022届校园招聘 看终端大把大把15级的 这个14级是不是终端bg的白菜了 程序员面试六战六捷
  • 配音工具哪个好?这里有你想知道的答案

    听说你还在为找不到合适的配音工具而烦恼 没关系 我这就来给你支招 其实配音不一定得找专业的录音室 现在许多在线工具也可以帮助你将文字转化为语音 而且 互联网上的配音工具可不少呢 有的可以提供多种语音风格和语调 有的则是可以快速生成语音内容
  • remote: Support for password authentication was removed on August 13, 2021.

    往 GitHub 上推送项目时 报如下错误 remote Support for password authentication was removed on August 13 2021 remote Please see https d
  • Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

    目录 0 专栏介绍 1 Q Learning算法原理 2 强化学习基本框架 3 机器人走迷宫算法 3 1 迷宫环境 3 2 状态 动作和奖励 3 3 Q Learning算法实现 3 4 完成训练