策略梯度算法(Policy Gradient)逐行代码详解

2023-05-16

理论部分以及完整代码参看之前的博客:https://blog.csdn.net/qq_47997583/article/details/124506650
本文章介绍的是策略梯度算法中的REINFORCE实现
在这里插入图片描述
上图为算法流程图,总体来说代码实现中,先实现一个episode然后从后往前计算回报,损失函数是负的回报乘于log的该状态下采取该动作的概率。每个状态动作对对应算一次loss,然后反向传播计算梯度。最后整个episode完之后进行梯度下降。
在代码实现中我们需要是实现两个类PolicyNetREINFORCE以及主函数部分。
主函数部分

for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            episode_return = 0
            transition_dict = {
                'states': [],
                'actions': [],
                'next_states': [],
                'rewards': [],
                'dones': []
            }
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                next_state, reward, done, _ = env.step(action)
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
            return_list.append(episode_return)
            agent.update(transition_dict)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

首先从主函数部分看,一共是训练1000个episode,分为10个iteration,每个100个episode。每个episode我们需要先初始化一个回合总奖励和一个字典记录整个回合每个时间步的五元组信息;之后初始化状态state和done,然后进行采样直到回合结束,采样的动作通过REINFORCE类实例化的agent的take_action方法获得,然后将动作传入step函数获得四元组,将五元组传入字典中;回合结束后将累计回合奖励传入结果列表,之后通过agent的update方法更新参数。

PolicyNet

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1) 
         # 0是对列做归一化,1是对行做归一化

REINFORCE

class REINFORCE:
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,
                 device):
        self.policy_net = PolicyNet(state_dim, hidden_dim,
                                    action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(),
                                          lr=learning_rate)  # 使用Adam优化器
        self.gamma = gamma  # 折扣因子
        self.device = device

    def take_action(self, state): 
     # 根据动作概率分布随机采样
        state = torch.tensor([state], dtype=torch.float).to(self.device) 
        # 1*4
        probs = self.policy_net(state) 
         # 1*2
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        reward_list = transition_dict['rewards']
        state_list = transition_dict['states']
        action_list = transition_dict['actions']

        G = 0
        self.optimizer.zero_grad()
        for i in reversed(range(len(reward_list))):  
        # 从最后一步算起
            reward = reward_list[i]
            state = torch.tensor([state_list[i]], 
             # 1*4
                                 dtype=torch.float).to(self.device)
            action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device) 
            # 1*1
            log_prob = torch.log(self.policy_net(state).gather(1, action))
             # 1*1
            G = self.gamma * G + reward
            loss = -log_prob * G  # 每一步的损失函数
            loss.backward()  # 反向传播计算梯度
        self.optimizer.step()  # 梯度下降

在实现REINFORCE类时我们需要实现两个方法,一个是take_action,一个是updatetake_action作用是传入state到神经网络获得两个动作的概率分布,然后依据概率分布进行动作的抽样,update则是本算法的核心部分,我们需要传入包含回合的五元组数据的字典transition_dict,拿出来其中的s,a,r列表。之后首先初始化累计奖励G为0并且将梯度清零,然后执行循环,循环的次数为列表的长度,每次循环从列表末尾往前遍历:获得reward,state,action,log_prob通过将state传入神经网络获得两个动作的概率,然后根据action索引得到在本次时间步的π(a|s),然后计算其log值。由于REINFORCE算法是计算当前时间之后的累计奖励作为回报,因此 G 的更新方式为 G = self.gamma * G + reward,将G乘负的对数log_prob即为损失函数,然后进行反向传播进行梯度累计。最后循环结束也就是episode结束,再进行梯度下降更新神经网络参数。

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

策略梯度算法(Policy Gradient)逐行代码详解 的相关文章

随机推荐

  • CSS之选择器(一)普通选择器

    1 CSS选择器概述 通过选择器可以选定页面中的指定元素 xff0c 对HTML页面中的元素实现一对一 一对多或者多对一的控制 HTML页面中的元素都是通过CSS选择器进行控制的 2 CSS选择器 1 一般选择器 一般选择器 选择器示例示例
  • vs code写python代码时遇到蓝色波浪线“word“: Unknown word.cSpell[1,1]解决方法

    vs code写python代码时遇到蓝色波浪线 34 word 34 Unknown word cSpell 1 1 解决方法 从上面的两张图片都出现有cSpell xff0c 这个单词其实是code spell checker扩展检查p
  • ubuntu 18.0.4以上版本系统内网双网口设置方法

    需求 xff1a 18以上版本系统与老版本有很大区别 xff0c 目前有线网卡用于SSH及本地内网连接 xff0c 无线用于外网连接 xff0c 用笔记本SecureCRT通过有线操作ubuntu设备 xff0c 而且还要保证Ubuntu设
  • 操作系统的概念、功能和目标

    大家都熟悉的操作系统 windowsAndroidiosmacoslinux 本节框架 xff1a 定义 xff1a 操作系统是指控制和管理整个计算机系统的硬件和软件资源 xff0c 并合理地组织调度计算机的工作和资源的分配 xff0c 以
  • SpringBoot整合FreeMarker

    一 FreeMarker简述 在线文档 xff1a http freemarker foofun cn FreeMarker 也是一款模板引擎技术 xff0c 它是一种基于模板和要改变的数据 xff0c 并用来生成输出文本 HTML网页 x
  • SpringMVC的执行流程

    前言 当你知道springMVC的执行流程的时候 xff0c 会达到是事半功倍的学习效果 SpringMVC执行流程 首先明确 xff1a SpringMVC的执行过程就是 xff1a 客户端或者浏览器发送请求到后端服务器 xff0c 后端
  • archLinux安装记录

    archLinux安装记录 基于wsl的arch 启用wsl 首先 xff0c 按Win 43 S搜索启用或关闭Windows功能 xff08 Turn Windows features on or off xff09 打开虚拟机平台和WS
  • mac风格的windows11

    结果 工具下载 链接 xff1a https pan baidu com s 1bVkGI2FZ1Y6tziRMFdP3fw 提取码 xff1a MACC windows11微软官网纯镜像 链接 xff1a https pan baidu
  • AD学习问题记录(四):AD21布线时如何更改线宽

    目录 问题 xff1a 布线时发现线比需要的细解决 xff1a 更改规则结果总结 目前使用的版本是AD21 问题 xff1a 布线时发现线比需要的细 在PCB布线的时候 xff0c 发现线宽比较细 xff0c 于是在右侧的Propertie
  • FAILURE: Build failed with an exception.* What went wrong:Execution failed for task ‘:app:compile...

    1 错误原因 笔记 在运行android的项目时报错 咱就是说代码不知道检查多少遍了 反正代码可以肯定的是没错的 于是就去网上搜索啊 按照提示在build gradle Module app 加了如下代码 android compileOp
  • Java实现二分搜索

    二分查找 xff1a 是一种算法 xff0c 其输入是一个有序的元素列表 xff08 必须是有序的 xff09 xff0c 如果查找的元素包含在列表中 xff0c 返回其索引 xff0c 否则返回负数 比如说有一个1 100的数字 xff0
  • Python if else条件语句你懂了吗?

    在 Python 中 xff0c 可以使用 if else 语句对条件进行判断 xff0c 然后根据不同的结果执行不同的代码 xff0c 这称为选择结构或者分支结构 Python 中的 if else 语句可以细分为三种形式 xff0c 分
  • 嵌入式学习系统里的ROM和RAM(转载)

    一个嵌入式项目在立项时 xff0c 其中有个重要的环节就是对系统所需的RAM和ROM用量进行评估 xff0c 在满足系统需求的前提下 xff0c 尽量降低硬件成本 xff0c 据说同等大小的RAM价格大概是ROM的6倍 大部分的资料都宣称程
  • 关于Mysql8.0.22服务无法启动问题

    关于Mysql8 0 22服务无法启动问题 1 官网下载 解压完成后 不存在data文件夹 也不要自己创建 后面会用命令生成 请往后看 2 创建my ini文件 xff08 一定要放在bin目录下 xff0c 不要放在mysql8 0 22
  • 查找Ubuntu中安装软件的位置

    查找Ubuntu中安装软件的位置 下面仅自我学习记录只做参考 xff0c 不可全信 通常使用ps e 找到软件的具体名字 xff0c 然后进行位置查找 自我记录 1 执行程序查看 对于有的程序没有效果 type 软件名 2 通过进程查看 p
  • Python爬虫:第三章 数据解析 xpath解析(12)

    第三章 数据解析 xpath 解析xpath 解析基础example1 爬取58二手房中的房源信息example2 解析下载图片数据example3 全国城市名称爬取 xpath 解析 xpath 解析基础 span class token
  • java获取项目文件绝对路径

    该方法是先根据指定目录创建文件目录后 xff0c 再获取起绝对路径 xff0c 可先在指定目录中放入指定文件 xff0c 这样就可以直接获取起绝对路径 span class token keyword public span span cl
  • 三分钟带你了解最成熟最流行的LAMP网站应用架构

    三分钟带你了解最成熟最流行的LAMP网站应用架构 一 LAMP概述1 各组件的主要作用2 各组件安装顺序 二 编译安装Apache httpd服务准备工作1 关闭防火墙 xff0c 将安装Apache所需软件包传到 opt目录下2 安装环境
  • IDEA通过maven配置Spring保姆级教程

    写在前面 xff1a 此篇文章主要是记录IDEA利用maven配置Spring的全过程 由于本人也是慢慢探索出来的 xff0c 所以有不全或者遗漏的地方 xff0c 还请大家斧正 请耐心看完文章 xff0c 前期工作做完后IDEA才可以配置
  • 策略梯度算法(Policy Gradient)逐行代码详解

    理论部分以及完整代码参看之前的博客 xff1a https blog csdn net qq 47997583 article details 124506650 本文章介绍的是策略梯度算法中的REINFORCE实现 上图为算法流程图 xf