强化学习算法复现(六):DoubleDQN_gym倒立摆

2023-05-16

在这里插入图片描述

建立RL_brain.py

import torch
import torch.nn as nn
import torch.nn.functional as F                 # 导入torch.nn.functional (激活函数)
import numpy as np


class Net(nn.Module):  # 建立网络
    def __init__(self, N_STATES, N_ACTIONS):
        nn.Module.__init__(self)
        self.input_num = N_STATES
        self.output_num = N_ACTIONS

        self.fc1 = nn.Linear(self.input_num, 50)             # 输入层——————>隐藏层: 状态数————>50个神经元
        self.fc1.weight.data.normal_(0, 0.1)           # 权重初始化 (均值为0,方差为0.1的正态分布)

        self.out = nn.Linear(50, self.output_num)            # 隐藏层——————>输出层: 50个神经元——————>各动作的价值
        self.out.weight.data.normal_(0, 0.1)           # 权重初始化 (均值为0,方差为0.1的正态分布)

    def forward(self, state):  # 前向传播
        x = self.fc1(state)                                # 经过第一层
        x = F.relu(x)                                  # 使用激励函数ReLU
        actions_value = self.out(x)                    # 隐藏层————————>输出层

        return actions_value                           # 输出动作的价值


class DQN(object):  # 定义DQN类 (定义两个网络)
    def __init__(self, N_STATES, N_ACTIONS, MEMORY_CAPACITY = 2000, EPSILON = 0.9, LR = 0.1, BATCH_SIZE = 32 , TARGET_REPLACE_ITER = 100, GAMMA = 0.8):
        # 超参数介绍
        self.n_states = N_STATES  # 状态个数
        self.n_actions = N_ACTIONS  # 动作个数
        self.memory_capacity = MEMORY_CAPACITY  # 记忆库容量
        self.epsilon = EPSILON   # e_贪心
        self.batch_size = BATCH_SIZE  #
        self.taget_replace_iter = TARGET_REPLACE_ITER
        self.gamma = GAMMA  # 折扣系数

        net = Net(N_STATES, N_ACTIONS)
        self.eval_net, self.target_net = net, net       # 利用Net创建评估网络和目标网络

        self.learn_step_counter = 0

        self.memory_counter = 0
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2), dtype=float)   # 存储训练数据

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)    # 使用Adam优化器
        self.loss_func = nn.MSELoss()                                           # 使用均方误差为损失函数

    def choose_action(self, state):
        state = torch.FloatTensor(state)    # 转换成tensor类型
        state = torch.unsqueeze(state, 0)   # 增加一个维度

        if np.random.uniform() < self.epsilon:                   # greedy
            actions_value = self.eval_net.forward(state)         # 对评估网络前向传播,得到action_value(tensor)
            action = torch.max(actions_value, 1)[1].numpy()      # 输出每一行最大值的索引,并转化为数组形式
            action = action[0]                                   # 从数组中提取数字

        else:                                                    # explore 随机选择
            action = np.random.choice(self.n_actions)

        return action

    def store_transition(self, s, a, r, s_):    # 将一个transition储存进记忆库中
        transition = np.hstack((s, a, r, s_))   # 在水平方向上向右边拼接
        index = self.memory_counter % self.memory_capacity   # 取余数,使得记忆库内容中新的数据自动覆盖
        self.memory[index, :] = transition
        self.memory_counter += 1                                           # memory_counter自加1

    def learn(self):
        # 目标网络参数更新
        if self.learn_step_counter % self.taget_replace_iter == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())        # 将评估网络的参数赋给目标网络
        self.learn_step_counter += 1                                            # 学习步数自加1

        # 从记忆池中随机抽取一批数据batch
        sample_index = np.random.choice(self.memory_capacity, self.batch_size)
        batch_memory = self.memory[sample_index, :]

        # 将其代表的储存为tensor形式,float类型
        b_s = torch.FloatTensor(batch_memory[:, :self.n_states])
        b_a = torch.LongTensor(batch_memory[:, self.n_states:self.n_states+1].astype(int))
        b_r = torch.FloatTensor(batch_memory[:, self.n_states+1:self.n_states+2])
        b_s_ = torch.FloatTensor(batch_memory[:, -self.n_states:])

        q_eval = self.eval_net(b_s).gather(1, b_a)  # Q(s)估计   提取b_s,b_a对应的值

        q_next = self.target_net(b_s_).detach()
        q_target = b_r + self.gamma * q_next.max(1)[0].view(self.batch_size, 1)   # Q(s')现实 = b_r+gamma*max【】

        loss = self.loss_func(q_eval, q_target)

        # 反向传播,使用优化器更新参数
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
import time
import gym
from DQN_brain import *

start = time.time()
env = gym.make('CartPole-v1')

N_ACTIONS = env.action_space.n  # 杆子动作空间(2个动作)  (0左1右)
N_STATES = env.observation_space.shape[0]  # 杆子状态空间 (4维)
MEMORY_CAPACITY = 1000

dqn = DQN(N_STATES, N_ACTIONS, MEMORY_CAPACITY)

for i_episode in range(400):  # 400个episode循环
    s = env.reset()
    ep_r = 0
    for t in range(1000):
        env.render()
        a = dqn.choose_action(s)  # 根据网络选择动作
        s_, r, done, info = env.step(a)

        # 修改reward , 为了更好的收敛
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2

        dqn.store_transition(s, a, r, s_)  # 存储进记忆池
        ep_r += r

        if dqn.memory_counter == MEMORY_CAPACITY:
            print("记忆库已经满,开始学习")

        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode, '| Ep_r: ', ep_r)  # round()方法返回ep_r的小数点四舍五入到2个数字

        if done:
            break  # 该episode结束
        s = s_  # 更新状态

env.close()
interval = time.time() - start
print("Time: %.4f" % interval)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

强化学习算法复现(六):DoubleDQN_gym倒立摆 的相关文章

  • Springboot整合RabbitMQ实现广播模型

    生产者代码 package com example newrabbitmq span class token punctuation span span class token function import span org junit
  • opencv 视频自动截图-小工具

    采用 xff1a python多进程 实测 cpu i7 9700 92秒 13456张 xff0c 平均约每秒146张 span class token keyword import span multiprocessing span c
  • halcon基础语法

    判断 xff1a if else 与 switch 与 and 或 or 非 not 条件 其中1个成立 xff0c 则为真 xff0c 其他情况均为假 xor dev open window span class token punctu
  • C# 字符串各种操作

    C 字符串 xff1a 字符串 转 char类型的数组字符串 批量合成字符串大小写转换字符串分割字符串替换字符串 是否包含字符串 比较字符串截取字符串 是否 以什么开头 结尾 字符串 第一个 最后一个 字符串 去除空格字符串 空与null
  • C# 里氏转换

    子类可以赋值给父类 子类可以转换为父类 protected 修饰符 是让这个字段 xff0c 在子类中也可以访问 设定访问权限 span class token keyword using span span class token nam
  • Node.js — 内置API模块

    文章目录 0 JS 与 Node js的理解 内置API模块1 导入 fs 模块 xff0c 导入文件系统模块2 导入 path 模块 xff0c 读取文件 xff0c 路径处理模块3 http xff08 创建web服务器的 xff09
  • C# 列表:ArrayList、字典:Hashtable、增删改查

    添加单个 对象或多个 删除指定单个 范围 清空 是否 包含 不包含 索引直接修改 列表 xff1a ArrayList span class token comment 集合 xff08 我称之为列表 xff09 span span cla
  • C# path类:操作路径、File类:操作文件、文件流读写

    路径操作 span class token class name span class token keyword string span span str span class token operator 61 span span cl
  • C# 列表:list 字典:dict

    列表 list 增删改查 与数组转换 span class token comment 创建泛型集合 span span class token class name List span class token punctuation lt
  • C# 虚方法多态、抽象类多态、接口

    C 虚方法多态 抽象类多态 虚方法 xff1a 希望重新父类中的某个方法时 xff0c 使用虚方法 抽象类 xff1a 有多个规定的处理方式 xff0c 但实际实现的方式不同 xff0c 使用抽象类 抽象类就是为了设立规范 xff0c 为了
  • halcon 学习:图像读取与保存、查看类型、图像大小、转为灰度、分割与合并通道、获取图像指针、

    halcon 学习 xff1a 图像读取与保存 查看类型 图像大小 转为灰度 分割与合并通道 获取图像指针 read image span class token punctuation span Image span class toke
  • halcon 获取XLD亚像素的测量距离

    read image span class token punctuation span Image span class token string 39 F xue xi 1 png 39 span span class token pu
  • halcon 基础语法:数组、向量、字典、

    数组 数组 Tuple 1 span class token operator 61 span span class token punctuation span span class token number 1 span span cl
  • halcon 电气柜绿灯位置安装是否正确

    思路 xff1a 找到电气柜的区域根据绿色通道对绿色敏感 xff0c 找到绿色按钮特征过滤 xff0c 与空对象相比 xff0c 确定绿色按钮是否存在填充后截取出来 xff0c 转为亚像素 xff0c 筛选和计算中心点中心点是否在规定区域内
  • C# 基础语法示例

    基础语法 span class token keyword using span span class token namespace System span span class token punctuation span span c
  • C# 委托的基础使用

    委托 我认为委托实际上是一种设计模式的封装 委托的本质就是 xff0c 将大致的流程定下来 xff08 包含输入与输出格式 xff09 xff0c 将其中计算的细节由一个被委托的函数进行具体实现 委托 xff1a 将函数当作形参进行传递 要
  • Node.js — 模块化

    文章目录 1 模块化1 模块化与作用域2 module 模块3 npm与包 1 模块化 1 模块化与作用域 编程领域中的模块化 xff1a 遵守固定的规则 xff0c 把一个大文件拆成独立并相互依赖的多个小模块 模块化好处 xff1a 提高
  • C# 多线程示例

    百度网盘原代码连接 xff1a 链接 xff1a https pan baidu com s 19W3RFOarQtaUQDv L4tmkw 提取码 xff1a q47x span class token keyword using spa
  • C# 读取和写入json文件

    方式一 首先添加引用 导入 using System Web Script Serialization using System IO 需要封装一个类 演示 xff1a 为了方便演示 xff0c 本次使用的是控制台 span class t
  • flask框架学习 git、post请求

    templates 文件夹是放置html文件的 xff0c 否则路径不对会报错 demo1 py 文件内容 span class token keyword from span flask span class token keyword

随机推荐

  • falsk框架 使用post请求,发送与接收json格式的内容(最小的示例)

    falsk框架 使用post请求 xff0c 发送与接收json格式的内容 先运行接收 xff0c 再运行发送 发送 span class token keyword import span requests url span class
  • C# 打包项目

    在线安装 xff1a 打包 xff1a 拓展 在线搜索 xff1a installer 离线安装 xff1a 网址 xff1a https marketplace visualstudio com 搜索输入 xff1a setup inst
  • Windows磁盘管理右键无法删除卷,右键只有帮助选项按钮(转发)

    转发自 xff1a https blog csdn net github 39581355 article details 107670379 Windows磁盘管理右键无法删除卷 xff0c 右键只有帮助选项按钮 问题 xff1a 电脑更
  • SQL语句学习笔记(对库、表、字段、的操作)

    查看mysql的状态 status 启动 停止 mySQL服务 图像界面方法 xff1a dos窗口执行 xff1a services msc 控制面板 gt 管理工具 gt 服务 命令行方法 xff1a 启动 xff1a net star
  • 请求的转发和重定向

    请求的转发和重定向实际上是用来派遣视图页面的 xff0c 不同于ajax xff0c ajax用于页面中数据的的取出和保存 请求的转发 xff1a 用户在页面发送请求到后台 xff0c 如果没有在web xml配置中拦截请求的话 xff0c
  • springmvc-文字解释(无图)

    1 前端控制器 xff08 DisatcherServlet xff09 我们的请求发送到后台 xff0c 通过web xml截取请求 xff0c 通过前端控制器分配该请求 2 处理器映射 xff08 HandlerMapping xff0
  • 在线考试系统

    在线考试系统源码 前端开发语言有 xff1a html xff0c css xff0c javascript xff0c jsp xff0c jstl等 xff0c 前端框架 xff1a jQuery xff0c easyui layui
  • Node.js-- Express

    文章目录 0 学习目标 1 理解express2 基本使用3 路由4 Express中间件1 调用流程 xff1a 2 格式 xff1a 3 next 函数作用 xff1a 4 定义中间件函数5 全局生效的中间件 6 中间件作用 xff1a
  • 数据结构笔记

    一 数据结构是什么 xff1f 数据结构就是已某种特定方式存储数据 xff0c 按某种结构把数据结构化然后存储到内存容器当中 二 我们为什么需要数据结构 xff1f 结构化存储可以让数据有不同的形态 xff0c 我们通过构造多种结构来解决数
  • ES6新特性-含代码-通俗易懂

    一 新增const let变量 const用来定义常量 xff0c 它保存的值是不能再次改变的 这里说的是基本类型 xff0c 如果是对象类型则不可改变其内存地址 可以改变对象中的内容 xff0c 同时也不能多次定义同名变量 const v
  • 考试系统-新版

    最新考试系统 通过JSP xff08 Java Server Page xff09 技术和Tomcat服务器搭建的一个在线考试系统的设计与实现 针对目前的教学考核都普遍存在有选择题 xff0c 题型都是有固定的答案形式 本在线考试系统设计成
  • 金融系统-基金管理

    金融系统 基金管理 本项目为携投基金系统 xff0c 在客户端浏览器输入网址 xff0c 即可载入该系统 xff0c 本系统采用当前主流前端开发语言有 xff1a layui js等前端主流技术 采用的后端开发语言框架等有 xff1a SS
  • 学校教材管理系统-毕设、课设(最佳参考)

    下载地址 高校教材管理系统 项目介绍 基于springboot 43 mybatis 43 jwt 43 layui 43 mybatis 43 html 43 javaScript的用于高校管理教材的系统 项目主要功能 教材信息管理 教材
  • android-校园拍卖管理系统-毕设课设-含源码

    校园拍卖系统 android 源码私信 xff0c 有回必应 xff0c 三连关注 xff0c 免费 xff01 xff01 xff01 android实现校园拍卖系统 xff0c 使用语言为java xff0c 工具idea或者andro
  • ChatGPT 未来的前景以及发展趋势

    当谈到ChatGPT的未来和发展趋势时 xff0c 需要考虑人工智能技术以及文本生成和交互的迅速发展 在这方面 xff0c ChatGPT的前景非常有希望 xff0c 因为它是一种迄今为止最先进的人工智能技术之一 ChatGPT是一种基于机
  • 同步FIFO 两种方法

    RAM 43 空满信号判断 xff0c 两种方法 一 空满标志用指针位置得到 二 空满标志用fifo的中数据的计数得到 一 当写指针超过读指针一圈 xff0c 写满 xff1b 写指针等于读指针 xff0c 读空 96 timescale
  • linux内核串口日志抓取-minicom工具使用方法

    linux抓串口日志 抓串口日志方式minicom保存串口日志log抓取主板串口日志minicom man手册 抓串口日志方式 1 xff09 问题机上 xff0c 找到串口设备 xff0c 比如 dev ttyAMA 0 1 2 3 st
  • 二叉树(七):二叉树的高度与平衡二叉树

    一 二叉树的深度与高度 1 二叉树的深度 对于二叉树中的某个节点 xff0c 其深度是从根节点到该节点的最长简单路径所包含的节点个数 xff0c 是从上面向下面数的 因此访问某个节点的深度要使用先序遍历 2 二叉树的高度 对于二叉树中的某个
  • Python --语法自纠

    文章目录 1 输入2 数据类型转换 xff0c 字符串3 字典 xff0c 列表 xff0c 元组4 语法0 错题 1 输入 输入eval作用一次输入一个或多个 map print format m n format输出 2 数据类型转换
  • 强化学习算法复现(六):DoubleDQN_gym倒立摆

    建立RL brain py span class token keyword import span torch span class token keyword import span torch span class token pun