adamax参数_5 Optimizer-庖丁解牛之pytorch

2023-11-02

优化器是机器学习的很重要部分,但是在很多机器学习和深度学习的应用中,我们发现用的最多的优化器是 Adam,为什么呢?pytorch有多少优化器,我什么时候使用其他优化器?本文将详细讲述:

在torch.optim 包中有如下优化器torch.optim.adam.Adamtorch.optim.adadelta.Adadeltatorch.optim.adagrad.Adagradtorch.optim.sparse_adam.SparseAdamtorch.optim.adamax.Adamaxtorch.optim.asgd.ASGDtorch.optim.sgd.SGDtorch.optim.rprop.Rproptorch.optim.rmsprop.RMSproptorch.optim.optimizer.Optimizertorch.optim.lbfgs.LBFGStorch.optim.lr_scheduler.ReduceLROnPlateau

这些优化器都派生自Optimizer,这是一个所有优化器的基类,我们来看看这个基类:class Optimizer(object):

def __init__(self, params, defaults):        self.defaults = defaults        self.state = defaultdict(dict)        self.param_groups =  list(params)        for param_group in param_groups:

self.add_param_group(param_group)params 代表网络的参数,是一个可以迭代的对象net.parameters()

第二个参数default是一个字典,存储学习率等变量的值。

构造函数最重要的工作就是把params加入到param_groups组中

zero_graddef zero_grad(self):

r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""

for group in self.param_groups:            for p in group['params']:                if p.grad is not None:

p.grad.detach_()

p.grad.zero_()

遍历param_groups,将每个组中参数值,有梯度的都解除链接,然后清零。

state_dictdef state_dict(self):

......

param_groups = [pack_group(g) for g in self.param_groups]        # Remap state to use ids as keys

packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v                        for k, v in self.state.items()}        return {            'state': packed_state,            'param_groups': param_groups,

}

state当前优化器状态,param_groups,整理格式,以字典方式返回

def load_state_dict(self, state_dict):

state = defaultdict(dict)        for k, v in state_dict['state'].items():            if k in id_map:

param = id_map[k]

state[param] = cast(param, v)            else:

state[k] = v        # Update parameter groups, setting their 'params' value

param_groups = [

update_group(g, ng) for g, ng in zip(groups, saved_groups)]        self.__setstate__({'state': state, 'param_groups': param_groups})

整理格式,更新state和param_groups

SGD

这个优化器是最基本的优化器,d_p = p.grad.data # 梯度值

...

p.data.add_(-group['lr'], d_p) # 更新值,只是一个lr和梯度

减去学习率和梯度值的乘积,果然够简单

我们给出计算公式whitle True:

wights_grad = evaluate_gradient(loss_fun, data, weights)

weights += -step_size * weights_grad

SGD就是计算随机梯度值,然后更新当前参数。

Adam 这个名字来源于 adaptive moment estimation,自适应矩估计。概率论中矩的含义是:如果一个随机变量 X 服从某个分布,X 的一阶矩是 E(X),也就是样本平均值,X 的二阶矩就是 E(X^2),也就是样本平方的平均值。Adam 算法根据损失函数对每个参数的梯度的一阶矩估计和二阶矩估计动态调整针对于每个参数的学习速率。Adam 也是基于梯度下降的方法,但是每次迭代参数的学习步长都有一个确定的范围,不会因为很大的梯度导致很大的学习步长,参数的值比较稳定。exp_avg.mul_(beta1).add_(1 - beta1, grad)

exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)                if amsgrad:

torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)

denom = max_exp_avg_sq.sqrt().add_(group['eps'])                else:

denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']

bias_correction2 = 1 - beta2 ** state['step']

step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 # 动态调整计算步长

p.data.addcdiv_(-step_size, exp_avg, denom) # 更新值

Adagradstate['sum'].addcmul_(1, grad, grad)

std = state['sum'].sqrt().add_(1e-10)

p.data.addcdiv_(-clr, grad, std)

据说这个梯度可变,先累加addcmul_平方,还带根号,防止除零还带平滑项1e-10,果然代码不骗人

Adadelta

其实Adagrad累加平方和梯度也会猛烈下降,如果限制把历史梯度累积窗口限制到固定的尺寸,学习的过程中自己变化,看看下面的代码能读出这个意思吗?square_avg.mul_(rho).addcmul_(1 - rho, grad, grad)

std = square_avg.add(eps).sqrt_()

delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)

p.data.add_(-group['lr'], delta)

acc_delta.mul_(rho).addcmul_(1 - rho, delta, delta)

SparseAdam

实现适用于稀疏张量的Adam算法的懒惰版本。在这个变体中,只有在渐变中出现的时刻才会更新,只有渐变的那些部分才会应用于参数。exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

beta1, beta2 = group['betas']                # Decay the first and second moment running average coefficient

#      old 

# <==> old += (1 - b) * (new - old)

old_exp_avg_values = exp_avg._sparse_mask(grad)._values()

exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)

exp_avg.add_(make_sparse(exp_avg_update_values))

old_exp_avg_sq_values = exp_avg_sq._sparse_mask(grad)._values()

exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)

exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))                # Dense addition again is intended, avoiding another _sparse_mask

numer = exp_avg_update_values.add_(old_exp_avg_values)

exp_avg_sq_update_values.add_(old_exp_avg_sq_values)

denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])                del exp_avg_update_values, exp_avg_sq_update_values

bias_correction1 = 1 - beta1 ** state['step']

bias_correction2 = 1 - beta2 ** state['step']

step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

p.data.add_(make_sparse(-step_size * numer.div_(denom)))

这么复杂的公式,只能看出通过一个矩阵计算,然后更新梯度

Adamaxtorch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))

bias_correction = 1 - beta1 ** state['step']

clr = group['lr'] / bias_correction

p.data.addcdiv_(-clr, exp_avg, exp_inf)

看到torch.max估计明白为甚叫Adamax了,给学习率的边界做个上限

ASGDstate['step'] += 1                if group['weight_decay'] != 0:

grad = grad.add(group['weight_decay'], p.data)                # decay term

p.data.mul_(1 - group['lambd'] * state['eta'])                # update parameter

p.data.add_(-state['eta'], grad)                # averaging

if state['mu'] != 1:

state['ax'].add_(p.data.sub(state['ax']).mul(state['mu']))                else:

state['ax'].copy_(p.data)                # update eta and mu

state['eta'] = (group['lr'] /

math.pow((1 + group['lambd'] * group['lr'] * state['step']), group['alpha']))

state['mu'] = 1 / max(1, state['step'] - group['t0'])

使劲看,唯一能看出平均的含义就是eta 和 mu要累加统计。

Rprop# update stepsizes with step size updates

step_size.mul_(sign).clamp_(step_size_min, step_size_max)                # for dir<0, dfdx=0

# for dir>=0 dfdx=dfdx

grad = grad.clone()

grad[sign.eq(etaminus)] = 0

# update parameters

p.data.addcmul_(-1, grad.sign(), step_size)

设定变化范围,根据符合调整

RMSpropsquare_avg = state['square_avg']

alpha = group['alpha']

state['step'] += 1                if group['weight_decay'] != 0:

grad = grad.add(group['weight_decay'], p.data)

square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)                if group['centered']:

grad_avg = state['grad_avg']

grad_avg.mul_(alpha).add_(1 - alpha, grad)

avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])                else:

avg = square_avg.sqrt().add_(group['eps'])                if group['momentum'] > 0:

buf = state['momentum_buffer']

buf.mul_(group['momentum']).addcdiv_(grad, avg)

p.data.add_(-group['lr'], buf)                else:

p.data.addcdiv_(-group['lr'], grad, avg)

记录每一次梯度变化,由梯度变化决定更新比例,根据符号调整步长

LBFGS

ReduceLROnPlateau

作者:readilen

链接:https://www.jianshu.com/p/18e9bef2b967

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

adamax参数_5 Optimizer-庖丁解牛之pytorch 的相关文章

  • CCSpriteFrameCache的用法

    转自 https www cnblogs com pengyingh articles 2436648 html 让我们首先创建一个工程骨架 使用cocos2d工程模板创建一个新的项目并取名为AnimBear 接下来 下载一些由我的老婆制作
  • qt day 5

    1 gt 实现闹钟功能 pro QT core gui texttospeech greaterThan QT MAJOR VERSION 4 QT widgets CONFIG c 11 The following define make
  • C# USB通讯

    关注 星标公众号 及时获取更多技术分享 作者 冰茶奥利奥 微信公众号 嵌入式电子创客街 项目工程文件下载 工程文件下载地址 看了很多网上的博客 讲述如何用C 进行USB设备操作 很多都是不对的 以至于南辕北辙 我们可以使用usb库 在c下有
  • 打开方式中无法添加程序,无法用指定程序打开

    用T32打开 ts2文件时 右击 打开方式 中 浏览 到t32start exe的安装目录 点击 确定 可是在 打开方式 中找不到t32start exe程序 可能是因为注册表中t32start exe程序的路径指定错误 原因 解压安装包
  • 通过自定义 Vue 指令实现前端曝光埋点

    前言 互联网发展至今 数据的重要性已经不言而喻 尤其是在电商公司 数据的统计分析尤为重要 通过数据分析可以提升用户的购买体验 方便运营和产品调整销售策略等等 埋点就是网站分析的一种常用的数据采集方法 埋点按照获取数据的方式一般可以分为以下
  • flink学习day05:checkpoint 原理与实践

    flink checkpoint checkpointe是什么 基于state出发 flink基于与state可以做非常多复杂的事情 但是state是存储在内存中 内存中的数据是不安全的易丢失的 所以flink为了解决这个问题就引入了che
  • 牛客网C语言编程初学者入门训练135题

    文章目录 1 实践出真知 2 我是大V 3 有容乃大 4 小飞机 5 反向输出四位数 6 大小写转换 7 缩短二进制 8 十六进制转十进制 9 printf的返回值 10 成绩输入输出 11 学生基本信息输入输出 12 字符金字塔 13 判
  • 线稿图视频制作--从此短视频平台不缺上传视频了

    博客首页 knighthood2001 欢迎点赞 评论 热爱python 期待与大家一同进步成长 给大家推荐一款很火爆的刷题 面试求职网站 跟我一起来巩固基础 开启刷题之旅吧 这年头还不来尝试线稿图视频 之前笔者也写过将视频转换为线稿图视频
  • 基于Docker的mysql主从复制

    目录 一 拉取mysql 二 启动两个mysql容器 2 1 主master 2 2 从slave 三 配置master 3 1 进入master内部配置 3 2 安装vim命令 3 3 重启mysql 3 4 创建数据同步用户 可不建 直
  • Linux重启nfs出现没有权限,Linux NFS搭建与错误提示解决

    Linux NFS搭建与错误提示解决 服务端设置 root server cat etc redhat release 查看操作系统版本信息 CentOS release 5 5 Final root server uname r 查看当前
  • 常见的错误-04

    引言 在公司配置新电脑环境时候 在安装和配置完所有VSCode软件以及C 环境后 ubuntun环境下 尝试使用debug进行代码调试 遇到了在debug过程中不输出结果的bug 如下图 未输出array以及zheli 解决方法 在ubun
  • vue3+ts中对getCurrentInstance的使用

    1 在main js中挂载一个全局属性 拿axios举例 import App from App vue import axios from http 封装的axios方法 const app createApp App 创建应用 app
  • 【100%通过率 】【华为OD机试 c++/java/python】对称字符串【 2023 Q1 A卷

    华为OD机试 题目列表 2023Q1 点这里 2023华为OD机试 刷题指南 点这里 题目描述 对称美学 对称就是最大的美学 现有一道关于对称字符串的美学 已知 第 1 个字符串 R 第 2 个字符串 BR 第 3 个字符串 RBBR 第
  • resetlog

    来自于itpub的一篇文章 http space itpub net 16628454 很多人说 resetlogs就是不完全恢复 这是不对的 做不完全恢复必须使用resetlogs 但resetlogs也可以做完全恢复 而noresetl
  • # 第四届蓝桥杯JavaB组省赛-马虎的算式

    第四届蓝桥杯JavaB组省赛 马虎的算式 题目描述 小明是个急性子 上小学的时候经常把老师写在黑板上的题目抄错了 有一次 老师出的题目是 36 x 495 他却给抄成了 396 x 45 但结果却很戏剧性 他的答案竟然是对的 因为 36 4
  • 解决idea文件properties中文乱码问题

    有时候将项目代码拉取至本地用idea打开时会出现中文乱码问题 遇到这种问题不要慌 重新设置一下编码为UTF 8即可 那么如何将idea的编码统一设置为UTF 8格式呢 接下来我们一一解决此类问题 1 打开idea编译器 有时候会看到打开的文
  • WebGIS工程师进阶训练营

    WebGIS工程师进阶训练营 1 WebGIS课程综述 2 多类情景部署SuperMap iServer 2 1 Linux环境部署SuperMap iServer 2 2 war包部署 2 3 常见问题排查 3 SuperMap iSer
  • word添加、更新目录

    1 显示导航窗口 视图 导航窗口 2 文档中的目录 2 1 插入目录 引用 目录 2 2 更新目录 方式一 点击下图 更新目录 方式二 引用 更新目录

随机推荐

  • WinForm使用鼠标裁剪图像

    之前做一个试卷识别的项目的时候需要预先将各个部分裁剪开然后进行识别 而网上的裁剪函数都是记录鼠标的位置然后进行裁剪 public static Bitmap PartDraw Image src Rectangle cutpart 切割图片
  • (休息几天)读米什金之货币银行学——货币与汇率

    1货币 当一国货币升值时 相对于其他货币价值上升 则该国商品在国外变得更贵 而外国商品唉本国则变得更便宜 相反 一国货币贬值 则该国商品在国外更便宜 而外国商品在本国则变得更贵 货币升值使得本国制造的商品在国外竞争力下降 而国外商品在本国竞
  • Koa2.js router 异步返回ctx.body失效的问题

    koa2 js 用router返回数据时 正常写法如下 我是将接口封装了 一个很普通的koa2 js get请求 router put getUserInfo ctx next gt const data ctx request body
  • PHP自己的框架2.0版本目录结构和命名空间自动加载类(重构篇一)

    目录 1 目录结构演示效果 2 搭建目录结构 以及入口public gt index php 3 引入core下面core gt base php 4 自动加载实现core gt fm gt autoload php 5 框架运行文件cor
  • Basic Level 1012 数字分类 (20分)

    题目 给定一系列正整数 请按要求对数字进行分类 并输出以下 5 个数字 A 1 A 1 A1 能被 5 整除的数字中所有偶数的和 A 2
  • matlab 取余(rem)和取模(mod)的区别

    取余 rem 和取模 mod 的区别 Matlab 生成机制 取余 采取fix 函数 向0方向取整 取模 采取floor 函数 向无穷小方向取整 当A B异号时 其实同号也是这个规律 取余 结果和A同号 取模 结果和B同号 PS 在js c
  • ASP .net core 整合 nacos 通过Spring Cloud Gateway 网关访问

    ASP net core 整合 nacos 通过Spring Cloud Gateway 网关访问 使用vs创建web项目 选择api 注意这里要取消掉Https配置否则使用网关转发也需要配置为https请求这里我们直接取消 添加nacos
  • WebRTC实现多人视频聊天

    写在前面 实现房间内人员的视频聊天 由于并未很完善 所以需要严格按照步骤来 当然基于此完善 就是时间的问题了 架构 整个设计架构如下 图片来自于参考博文 我使用的是第一种Mesh 架构 无需任何流媒体服务器 直接利用成熟的WebRTC 协议
  • windows10进程查询命令、端口占用查询命令、杀进程命令

    windows环境下编码开发经常遇到端口占用问题 解决时需要找到对应进程杀掉 释放占用 自己常用的几项操作命令如下 首先 打开Windows的命令窗口 键盘 win R 输入cmd 回车 1 查询端口被占用的进程 命令 netstat ao
  • 马虎的算式 有一次,老师出的题目是:36 x 495 = ?他却给抄成了:396 x 45 = ? 但结果却很戏剧性,他的答案竟然是对的!!

    马虎的算式 小明是个急性子 上小学的时候经常把老师写在黑板上的题目抄错了 有一次 老师出的题目是 36 x 495 他却给抄成了 396 x 45 但结果却很戏剧性 他的答案竟然是对的 因为 36 495 396 45 17820 类似这样
  • 信息传递【NOIP2015】【强连通分量 Tarjan】

    题目链接 题目描述 有 n 个同学 编号为 1 到 n 正在玩一个信息传递的游戏 在游戏里每人都有一个固定的信息传递对象 其中 编号为 i 的同学的信息传递对象是编号为Ti的同学 游戏开始时 每人都只知道自己的生日 之后每一轮中 所有人会同
  • python链家新房信息获取练习

    使用python对链家新房相关数据进行爬取 并进行持久化存储 文章目录 前言 一 页面分析 二 代码编写 1 数据库表的建立 2 代码编写 结果 前言 保持练习 以下是本篇文章正文内容 下面案例可供参考 一 页面分析 老样子进行页面分析 u
  • 解决在win10下DNS_PROBE_FINISHED_BAD_CONFIG问题

    解决在win10下DNS PROBE FINISHED BAD CONFIG问题 打开控制面板 进入 网络和 Internet 进入 网络和共享中心 进入 更改适配器设置 选择当前使用的网络链接适配器 点击 属性 选择 Internet协议
  • C++:带内嵌对象成员的派生类的构造函数,析构函数的声明方式与执行的先后顺序

    声明了某个带内嵌对象成员的派生类的对象并进行初始化时 我们要使用到派生类的构造函数 在这时 派生类的构造函数会调用内嵌对象 父类 基类 的构造函数 那么 这些构造函数的执行顺序是什么呢 我们知道 被继承的类可以被叫做父类或基类 因此它作为构
  • 黑马Python教程实战项目--美多商城(五)

    一 用户基本信息 首先需要为用户模型类 也就是用户数据表 补充一个邮箱验证状态字段 用来记录用户的邮箱是否验证成功 然后新建用户中心视图类 继承LoginRequiredMixin和View类 在子路由中添加路由 定义get方法 在requ
  • 虚拟机非正常关机,重启网卡

    在命令行运行以下命令即可重新连接上网络 sudo service network manager stop sudo rm var lib NetworkManager NetworkManager state sudo service n
  • Google云

    Google 云计算 Cloud Computing 是个新概念 但也不过是分布式处理 Distributed Computing 并行处理 Parallel Computing 和网格计算 Grid Computing 的发展 也许是一个
  • 余弦计算相似度度量

    目录 pytorch 余弦相似度 余弦计算相似度度量 pytorch 余弦相似度 余弦相似度1到 1之间 1代表正相关 0代表不相关 1代表负相关 def l2 norm input axis 1 norm torch norm input
  • [改善Java代码]适当设置阻塞队列长度

    阻塞队列BlockingQueue扩展了Queue Collection接口 对元素的插入和提取使用了 阻塞 处理 我们知道Collection下的实现类一般都采用了长度自行管理方式 也就是变长
  • adamax参数_5 Optimizer-庖丁解牛之pytorch

    优化器是机器学习的很重要部分 但是在很多机器学习和深度学习的应用中 我们发现用的最多的优化器是 Adam 为什么呢 pytorch有多少优化器 我什么时候使用其他优化器 本文将详细讲述 在torch optim 包中有如下优化器torch