Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

2023-12-19

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且 采用Pytorch框架对常见的强化学习算法、案例进行实现 ,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

????详情: 《Pytorch深度强化学习》


1 Q-Learning算法原理

Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法) 介绍到 时序差分强化学习 是动态规划与蒙特卡洛的折中

Q π ( s t , a t ) = n 次增量 Q π ( s t , a t ) + α ( R t − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) ⏟ 采样 \begin{aligned}Q^{\pi}\left( s_t,a_t \right) &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( R_t-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+\gamma R_{t+1}-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}{ \underset{\text{采样}}{\underbrace{Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }}}\end{aligned} Q π ( s t , a t ) n 次增量 Q π ( s t , a t ) + α ( R t Q π ( s t , a t ) ) n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 Q π ( s t , a t ) ) n 次增量 采样 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) Q π ( s t , a t ) )

其中 r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) r_{t+1}+\gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) -Q^{\pi}\left( s_t,a_t \right) r t + 1 + γ Q π ( s t + 1 , a t + 1 ) Q π ( s t , a t ) 称为 时序差分误差 。基于离轨策略的时序差分强化学习的代表性算法是 Q-learning算法 ,其算法流程如下所示。具体的策略改进算法推导请见之前的文章,本文重点在于应用Q-learning算法解决实际问题

在这里插入图片描述

我们先来看看最终实现的效果

训练前
在这里插入图片描述

训练后

在这里插入图片描述

接下来详细讲解如何一步步实现这个智能体

2 强化学习基本框架

强化学习(Reinforcement Learning, RL) 在潜在的不确定复杂环境中,训练一个最优决策 π \pi π 指导一系列行动实现目标最优化的机器学习方法 。在初始情况下,没有训练数据告诉强化学习智能体并不知道在环境中应该针对何种状态采取什么行动,而是通过不断试错得到最终结果,再反馈修正之前采取的策略,因此强化学习某种意义上可以视为具有“延迟标记信息”的监督学习问题。

在这里插入图片描述

强化学习的基本过程是:智能体对环境采取某种行动 a a a ,观察到环境状态发生转移 s 0 → s s_0\rightarrow s s 0 s ,反馈给智能体转移后的状态 s s s 和对这种转移的奖赏 r r r 。综上所述,一个强化学习任务可以用四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E = S , A , P , R 表征

  • 状态空间 S S S :每个状态 s ∈ S s \in S s S 是智能体对感知环境的描述;
  • 动作空间 A A A :每个动作 a ∈ A a \in A a A 是智能体能够采取的行动;
  • 状态转移概率 P P P :某个动作 a ∈ A a \in A a A 作用于处在某个状态 s ∈ S s \in S s S 的环境中,使环境按某种概率分布 P P P 转换到另一个状态;
  • 奖赏函数 R R R :表示智能体对状态 s ∈ S s \in S s S 下采取动作 a ∈ A a \in A a A 导致状态转移的期望度,通常 r > 0 r>0 r > 0 为期望行动, r < 0 r<0 r < 0 为非期望行动。

所以,程序上也需要依次实现四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E = S , A , P , R

3 机器人走迷宫算法

3.1 迷宫环境

我们创建的迷宫包含障碍物、起点和终点

class Maze(tk.Tk, object):
    '''
    * @breif: 迷宫环境类
    * @param[in]: None
    '''    
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze game')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self.buildMaze()

    '''
    * @breif: 创建迷宫
    '''
    def buildMaze(self):
        self.canvas = tk.Canvas(self, bg='white', height=MAZE_H * UNIT, width=MAZE_W * UNIT)
        # 网格地图
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 创建原点坐标
        origin = np.array([20, 20])

        # 创建障碍
        barrier_list = [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0),
                        (0, 6), (1, 6), (2, 6), (3, 6), (4, 6), (5, 6), (6, 6),
                        (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 1), (6, 2),
                        (6, 3), (6, 4), (6, 5), (1, 2), (2, 2), (4, 1), (5, 4),
                        (1, 4), (3, 3)]
        self.barriers = [self.creatObject(origin, *index) for index in barrier_list]

        # 创建终点
        self.terminus = self.creatObject(origin, 5, 5, 'blue')

3.2 状态、动作和奖励

机器人的状态可以设置为当前的位置坐标

s = self.canvas.coords(self.agent)

机器人的动作可以设为上、下左、右

if action == 0:   # up
      if s[1] > UNIT:
          base_action[1] -= UNIT
  elif action == 1:   # down
      if s[1] < (MAZE_H - 1) * UNIT:
          base_action[1] += UNIT
  elif action == 2:   # right
      if s[0] < (MAZE_W - 1) * UNIT:
          base_action[0] += UNIT
  elif action == 3:   # left
      if s[0] > UNIT:
          base_action[0] -= UNIT

机器人的奖励设置为以下几种:

  • 碰到障碍物 :-10分,并进入终止状态
  • 成功到达终点 : +50分,并进入终止状态
  • 未到达终点 :-1分,能量耗散惩罚,防止机器人原地振荡
if s_ in [self.canvas.coords(barrier) for barrier in self.barriers]:
   reward = -10
   done = True
   s_ = 'terminal'
elif s_ == self.canvas.coords(self.terminus):
   reward = 50
   done = True
   s_ = 'terminal'
else:
   reward = -1
   done = False

3.3 Q-Learning算法实现

根据算法流程,实现下面的Q-Learning训练函数

def train(self, env, episodes=1000, reward_curve=[], file=None):
	with tqdm(range(episodes)) as bar:
	    for _ in bar:
	        # 初始化环境和该幕累计奖赏
	        state = env.reset()
	        acc_reward = 0
	        while True:
	            # 刷新环境
	            env.render()
	            # 采样一个动作并进行状态转移
	            action = self.policySample(str(state))
	            next_state, reward, done = env.step(action)
	            acc_reward += reward
	            # 智能体学习策略
	            self.learn(str(state), action, reward, str(next_state))
	            state = next_state
	            if done:
	                reward_curve.append(acc_reward)
	                break
	# 保存策略
	if not file:
	    self.q_table.to_csv(file)
	env.destroy()

3.4 完成训练

训练过程如下所示,完成后保存权重文件

if __name__ == "__main__":
    env = Maze()
    agent = Agent(actions=list(range(env.n_actions)))
    reward_curve = []

    # 训练智能体
    env.after(100, agent.train, env, 50, reward_curve, './weight/csv')

    # 主循环
    env.mainloop()

4 算法分析

4.1 Q-Table

在Q-Learning算法中,我们需要维护一个Q-Table,用来记录各种状态和动作的价值。Q-Table是一个二维表格,其中每一行表示一个状态,每一列表示一个动作。Q-Table中的值表示某个状态下执行某个动作所获得的回报(或者预期回报)。Q-Table的更新是Q-Learning算法的核心。在每次执行动作后,我们会根据当前状态、执行的动作、获得的奖励和下一个状态,来更新Q-Table中对应的值,更新方式是

Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) Q^{\pi}\left( s_t,a_t \right) ={ {Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }} Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) Q π ( s t , a t ) )

对应代码

self.q_table.loc[state, action] += self.lr * (q_target - q_predict)

在这里插入图片描述

保存的权重文件正是Q-Table,我们可以直观地看一下,其中 0-3 指的是上下左右四个动作,每行行首则是状态值,其余数是Q-Value

,0,1,2,3
"[45.0, 45.0, 75.0, 75.0]",-3.764746051087998,-4.129632180625153,2.070923999854885,-4.129632180625153
terminal,0.0,0.0,0.0,0.0
"[85.0, 45.0, 115.0, 75.0]",-3.7017636879676745,-3.2427095093971663,6.341493354722148,-2.4376270354451357
"[125.0, 45.0, 155.0, 75.0]",-2.822694674017249,12.009385340227768,-3.10550914130922,-1.7370066390489591
"[125.0, 85.0, 155.0, 115.0]",-1.018256983413196,-2.3765728565289628,19.23732307528551,-2.602996266117196
"[165.0, 85.0, 195.0, 115.0]",-2.063857163563445,27.370237164958994,-0.7307141976318489,0.14330394709222574
"[205.0, 85.0, 235.0, 115.0]",-0.4546075907459214,-0.45498153729692925,-0.490099501,0.3662096391980347
"[165.0, 125.0, 195.0, 155.0]",0.9791630128216775,35.427315495348594,-0.28782126600827374,-1.7383137616441329
"[205.0, 45.0, 235.0, 75.0]",-0.3940399,-0.38288265597631166,-0.3940399,-0.3940399
"[205.0, 125.0, 235.0, 155.0]",-0.31765122402993484,-0.3940399,-0.3940399,1.5298899806741253
...

4.2 奖励曲线

训练过程的奖励曲线如下所示

在这里插入图片描述

完整代码联系下方博主名片获取


???? 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

????源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系????
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫 的相关文章

随机推荐

  • 使用Docker搭建onlyoffice

    前提 省略docker的安装步骤 1拉去onlyoffice镜像 我的是指定版本 docker pull onlyoffice documentserver 7 1 1 2 启动onlyoffice docker run itd name
  • 40V TPHR8504PL N沟道功率MOSFET具有业界领先的低导通电阻特性,有助于提高电源效率

    TPHR8504PL是一种MOSFET 金属氧化物半导体场效应晶体管 它是40 Volt N 沟道MOSFET 由N型沟道和P型衬底构成 而P 沟道MOSFET则由P型沟道和N型衬底构成 TPHR8504PL N 沟道MOSFET的工作原理
  • 用友BIP数智采购,推动阳光采购,规避合规风险

    采购管理是企业价值链管理的核心环节 与企业经营效益密切相关 国有企业的采购工作兼具公共采购和企业采购的双重属性 既关系国有企业的经济效益和发展前景 也关系国有资产的保值增值 阳光采购是国有企业合规化采购的必由之路 国有企业在采购环节融入数智
  • thinkphp+mysql+vue实验室设备报修预约管理系统

    运行环境 phpstudy wamp xammp等 开发语言 php 后端框架 Thinkphp5 前端框架 vue js 服务器 apache 数据库 mysql 数据库工具 Navicat phpmyadmin 本站是一个B S模式系统
  • 【手势识别】深度学习卷积神经网络CNN手势识别(0-9,含识别率)【含Matlab源码 3435期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • 好用的p图软件免费有哪些?帮你个性化修饰你的图片

    在信息时代的今天 很多人都喜欢将自己的所见所闻拍照下来 但是拍下来的照片难免有这样那样的问题 这就需要一款好用的p图软件来对图片进行完善了 但是市面上p图软件众多 有以功能齐全和专业性为特色的 也有以操作简单为特色的 更有些以多样化的特效为
  • 【印刷字符识别】OCR键盘数字+字母识别【含Matlab源码 807期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • R10在工业自动化-485转WiFi无线路由解决方案

    R10是钡铼技术有限公司研发的一款用于工业自动化应用的485转WiFi无线路由器解决方案 该解决方案可以将传统的RS485通信设备无线化 实现数据的远程监控和管理 下面将详细介绍R10在工业自动化中的应用 首先 R10具备RS485转WiF
  • 西南科技大学数据库实验八(自定义函数)

    一 实验目的 1 掌握用户自定义变量 2 熟悉运算符与表达式 3 掌握begin end语句块 4 掌握重置命令结束标记 5 掌握创建自定义函数的语法格式以及函数的创建与调用 二 实验任务 1 创建学生表Student 由学号 Sno 姓名
  • 系列十二、索引实战

    一 索引实战 1 1 前置说明 前边的系列文章中是基于Linux中的MySQL进行案例演示的 为了后续测试百万条数据的sql性能分析 接下来的案例将会在Windows的MySQL中进行演示 MySQL版本要求需在8 0以上 我的MySQL版
  • android 13.0 SystemUI状态栏下拉快捷添加截图快捷开关

    1 概述 在13 0的系统产品rom定制化开发中 对SystemUI的定制需求也是挺多的 在下拉状态栏中 添加截图快捷开关 也是常有的开发功能 下面就以添加 截图功能为例功能的实现 2 SystemUI 状态栏下拉快捷添加截图快捷开关的核心
  • 大众点评poi数据2023年最新

    数据名称 2023年大众点评全国全品类数据特此说明 保证价格最低 数据质量最高变量名称 Id 店名 分店 b id 一类 m id 二类 s id 三类 area reg mark 外卖 订座 排队 促销 买单 酒店预约 美容预约 推荐 星
  • uniapp-使用返回的base64转换成图片

    在实际开发的时候 需要后端实时的给我返回二维码 他给我返回的是加密后的 base64字符串 我需要利用这个base64转换到canvas画布上展示 或者以图片的形式展示在页面内 在canvas画布上展示 使用官方的uni getFileSy
  • 【手势识别】肤色静态手势识别【含Matlab源码 288期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • 视频剪辑软件哪个好用?这些软件值得收藏

    朋友 你有没有遇到过这样的情况 收到了一段精彩的视频 想要将其中的亮点剪切出来制作成短视频 或是想将长时间的录像文件分割成多个小段 以便更方便地进行编辑和管理 但是却不知道该选择哪款视频剪辑合成软件 别担心 今天我将会给大家介绍一些常见的视
  • 【手写数字识别】BP神经网络手写数字识别【含GUI Matlab源码 1118期】

    博主简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 Matlab项目合作可私信 个人主页 海神之光 代码获取方式 海神之光Matlab王者学习之路 代码获取方式 座右铭 行百里者 半于九十 更多Matlab仿真内容点击 Matl
  • 题解 | #输出某一年的各个月份的天数#

    三方寄过去了 告诉我停止24届招聘 全部毁约 牛的 he芯 毁约应届生 34316 广西北部湾银行2022年校园招聘 广西北部湾银行股份有限公司2022届校园招聘 看终端大把大把15级的 这个14级是不是终端bg的白菜了 程序员面试六战六捷
  • 配音工具哪个好?这里有你想知道的答案

    听说你还在为找不到合适的配音工具而烦恼 没关系 我这就来给你支招 其实配音不一定得找专业的录音室 现在许多在线工具也可以帮助你将文字转化为语音 而且 互联网上的配音工具可不少呢 有的可以提供多种语音风格和语调 有的则是可以快速生成语音内容
  • remote: Support for password authentication was removed on August 13, 2021.

    往 GitHub 上推送项目时 报如下错误 remote Support for password authentication was removed on August 13 2021 remote Please see https d
  • Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

    目录 0 专栏介绍 1 Q Learning算法原理 2 强化学习基本框架 3 机器人走迷宫算法 3 1 迷宫环境 3 2 状态 动作和奖励 3 3 Q Learning算法实现 3 4 完成训练