深度学习中的优化算法之RMSProp

2023-11-17

      之前在https://blog.csdn.net/fengbingchun/article/details/124766283 中介绍过深度学习中的优化算法AdaGrad,这里介绍下深度学习的另一种优化算法RMSProp。

      RMSProp全称为Root Mean Square Propagation,是一种未发表的自适应学习率方法,由Geoff Hinton提出,是梯度下降优化算法的扩展。如下图所示,截图来自:https://arxiv.org/pdf/1609.04747.pdf

     

       AdaGrad的一个限制是,它可能会在搜索结束时导致每个参数的步长(学习率)非常小,这可能会大大减慢搜索进度,并且可能意味着无法找到最优值。RMSProp和Adadelta都是在同一时间独立开发的,可认为是AdaGrad的扩展,都是为了解决AdaGrad急剧下降的学习率问题。

      RMSProp采用了指数加权移动平均(exponentially weighted moving average)。

      RMSProp比AdaGrad只多了一个超参数,其作用类似于动量(momentum),其值通常置为0.9

      RMSProp旨在加速优化过程,例如减少达到最优值所需的迭代次数,或提高优化算法的能力,例如获得更好的最终结果。

      以下是与AdaGrad不同的代码片段:

      1.在原有枚举类Optimizaiton的基础上新增RMSProp:

enum class Optimization {
	BGD, // Batch Gradient Descent
	SGD, // Stochastic Gradient Descent
	MBGD, // Mini-batch Gradient Descent
	SGD_Momentum, // SGD with Momentum
	AdaGrad, // Adaptive Gradient
	RMSProp // Root Mean Square Propagation
};

      2.calculate_gradient_descent函数:RMSProp与AdaGrad只有g[j]的计算不同

void LogisticRegression2::calculate_gradient_descent(int start, int end)
{
	switch (optim_) {
		case Optimization::RMSProp: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] = mu_ * g[j] + (1. - mu_) * (dw * dw);
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::AdaGrad: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] += dw * dw;
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD_Momentum: {
			int len = end - start;
			std::vector<float> change(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float new_change = mu_ * change[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
					w_[j] += new_change;
					change[j] = new_change;
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD:
		case Optimization::MBGD: {
			int len = end - start;
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					w_[j] = w_[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::BGD:
		default: // BGD
			std::vector<float> z(m_, 0), dz(m_, 0);
			float db = 0.;
			std::vector<float> dw(feature_length_, 0.);
			for (int i = 0; i < m_; ++i) {
				z[i] = calculate_z(data_->samples[i]);
				o_[i] = calculate_activation_function(z[i]);
				dz[i] = calculate_loss_function_derivative(o_[i], data_->labels[i]);

				for (int j = 0; j < feature_length_; ++j) {
					dw[j] += data_->samples[i][j] * dz[i]; // dw(i)+=x(i)(j)*dz(i)
				}
				db += dz[i]; // db+=dz(i)
			}

			for (int j = 0; j < feature_length_; ++j) {
				dw[j] /= m_;
				w_[j] -= alpha_ * dw[j];
			}

			b_ -= alpha_*(db/m_);
	}
}

      执行结果如下图所示:测试函数为test_logistic_regression2_gradient_descent,多次执行每种配置,最终结果都相同。图像集使用MNIST,其中训练图像总共10000张,0和1各5000张,均来自于训练集;预测图像总共1800张,0和1各900张,均来自于测试集。在它们学习率为0.01及其它配置参数相同的情况下,AdaGrad耗时为17秒,RMSProp耗时为33秒;它们的识别率均为100%。当学习率调整为0.001时,AdaGrad耗时为26秒,RMSProp耗时为19秒;它们的识别率均为100%。

      GitHub: https://github.com/fengbingchun/NN_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习中的优化算法之RMSProp 的相关文章

  • Bug解决:ModuleNotFoundError: No module named ‘taming‘

    from taming modules vqvae quantize import VectorQuantizer2 as VectorQuantizer ModuleNotFoundError No module named taming
  • 带你看懂CTC算法

    转自 https zhuanlan zhihu com p 161186907 在文本识别模型CRNN中 涉及到了CTC算法的使用 由于算法的原理涉及内容较多 所以特另开一篇文章对其原理进行叙述 自己在学习CTC过程中也是看了诸多资料才大概
  • 序列模型——自然语言处理与词嵌入(理论部分)

    1 词汇表征 深度学习已经给自然语言处理 Natural Language Process NLP 带来革命性的变革 其中一个很关键的概念是词嵌入 word embedding 这是语言表示的一种方式 可以让算法自动的了解一些类似的词 例如
  • Pytorch中计算自己模型的FLOPs

    转自 Pytorch中计算自己模型的FLOPs thop profile 方法 yolov5s 网络模型参数量 计算量统计 墨理学AI CSDN博客 Pytorch 用thop计算pytorch模型的FLOPs 简书 安装thop pip
  • 输入文本就可建模渲染了?!OpenAI祭出120亿参数魔法模型!

    转自 https new qq com omn 20210111 20210111A0CBRD00 html 2021刚刚开启 OpenAI又来放大招了 能写小说 哲学语录的GPT 3已经不足为奇 那就来一个多模态 图像版GPT 3 今天
  • tiny-cnn执行过程分析(MNIST)

    在http blog csdn net fengbingchun article details 50573841中以MNIST为例对tiny cnn的使用进行了介绍 下面对其执行过程进行分析 支持两种损失函数 1 mean squared
  • 图神经网络(1):图卷积神经网络GCN ICLR 2017

    图卷积神经网络GCN ICLR 2017 是曾经在美国加州大学UCI教授 现在荷兰阿姆斯特丹大学教授 Max Welling团队的大作 Max是图灵奖获得者Hinton的弟子 第一作者T N Kipf已经成为这个领域有名的学者和工程师 如果
  • 视觉注意力的循环神经网络模型

    我们观察PPT的时候 面对整个场景 不会一下子处理全部场景信息 而会有选择地分配注意力 每次关注不同的区域 然后将信息整合来得到整个的视觉印象 进而指导后面的眼球运动 将感兴趣的东西放在视野中心 每次只处理视野中的部分 忽略视野外区域 这样
  • Tensorflow错误InvalidArgumentError see above for traceback): No OpKernel was registered to support Op

    调用tensorflow gpu运行错误 错误信息如下 2023 06 21 15 36 14 007389 I tensorflow core platform cpu feature guard cc 141 Your CPU supp
  • PyTorch torch.optim.lr_scheduler 学习率设置 调参-- CyclicLR

    torch optim lr scheduler 学习率设置 CyclicLR 学习率的参数调整是深度学习中一个非常重要的一项 Andrew NG 吴恩达 认为一般如果想调参数 第一个一般就是学习率 作者初步学习者 有错误直接提出 热烈欢迎
  • Transformer——《Attention is all you need》

    本文是Google 机器翻译团队在2017 年发表 提出了一个新的简单的网络模型 Transformer 该模型基于纯注意力机制 Attention mechanisms 完全抛弃了RNN和CNN网络结构 在机器翻译任务上取得了很好的效果
  • 基于Lasagne实现限制玻尔兹曼机(RBM)

    RBM理论部分大家看懂这个图片就差不多了 Lasagne写代码首先要确定层与层 RBM 正向反向过程可以分别当作一个层 权值矩阵互为转置即可 代码 coding utf 8 data format is bc01 written by Ph
  • 2D和3D人体姿态数据集

    转自链接 https www jianshu com p c046db584a21 2D数据集 LSP 地址 http sam johnson io research lsp html 样本数 2k 关节点数 14 全身 单人 FLIC 地
  • cs231n: How to Train a Neuron Network 如何训练神经网络

    CS231N第六第七课时的一些笔记 如何训练神经网络是一个比较琐碎的事情 所以整理了一下 以后训练Neuron Network的时候可以看一下 Activation Functions ReLu good ELU leaky ReLu no
  • 深度学习系统为什么容易受到对抗样本的欺骗?

    转自 https zhuanlan zhihu com p 89665397 本文作者 kurffzhou 腾讯 TEG 安全工程师 最近 Nature发表了一篇关于深度学习系统被欺骗的新闻文章 该文指出了对抗样本存在的广泛性和深度学习的脆
  • 目标检测数据集分析

    原文链接 https ghlcode cn pages 250d97 目标检测数据集分析 新增支持数据集可视化 Ghlerrix DataAnalyze 平时我们经常需要对我们的数据集进行各种分析 以便我们找到更好的提高方式 所以我将我平时
  • Going Deeper with convolutions

    Going Deeper with convolutions 转载请注明 http blog csdn net stdcoutzyx article details 40759903 本篇论文是针对ImageNet2014的比赛 论文中的方
  • 损失函数和正则化

    参考 https www cnblogs com LXP Never p 10918704 html https blog csdn net Heitao5200 article details 83030465 https zhuanla
  • 【直观详解】什么是正则化

    转自 https charlesliuyx github io 2017 10 03 E3 80 90 E7 9B B4 E8 A7 82 E8 AF A6 E8 A7 A3 E3 80 91 E4 BB 80 E4 B9 88 E6 98
  • 经典网络ResNet介绍

    经典网络ResNet Residual Networks 由Kaiming He等人于2015年提出 论文名为 Deep Residual Learning for Image Recognition 论文见 https arxiv org

随机推荐

  • 三维旋转:旋转矩阵,欧拉角,四元数

    在介绍下面的文章前 大家如果接触到欧拉角的话 就一定要关注一个词 要顺规 在欧拉角体系里面 有12种顺规 这一点是好多文章没有让读书意识到 导致后面学习图形学里面的 heading pitch bank 时对不上号 一般百度百科里面说到的
  • 课程笔记2

    一 实现 1 区块链是去中心化的账本 比特币采用的是基于交易的账本模式 区块链的全节点需要维护一种名叫UTXO的数据结构 所有未花掉的交易的输出的集合 可以有效检测双花攻击 交易的总输入略微大于总输出 这是因为比特币的第二个激励机制 获得记
  • load data inpath出错原因及解决方法

    hive gt load data inpath hdfs Master hdp 9000 person txt into table Person1 FAILED SemanticException Error 10028 Line 1
  • java setcellvalue NA_java minioClient.setBucketPolicy 调用失败 折腾好几天了 求大佬解惑...

    方法调用后 提示 Request processing failed nested exception is java lang IllegalArgumentException unknown error code string Malf
  • 简要损益科目口诀,营业外收支和其他业务收支的区别

    一 损益科目口诀 三收三费所得税 两成三益外加减 三收 主营业务收入 其他业务收入 营业外收入 三费 管理费用 财务费用 销售费用 这是常用费用 某些企业可能还分有研究开发费用 两成 主营业务成本 其他业务成本 三益 投资收益 公允价值变动
  • java查看包的源代码

    把鼠标放在方法上 按Ctrl进去 打开的 class文件就是Java jdk1 7 0 src zip中的源码 但是在Java jdk1 7 0 src zip 中是以 java为扩展名
  • ios开发教程入门到精通

    第1集 初识macOS 点击观看 第2集 开发工具Xcode 点击观看 第3集 初识Objective C 点击观看 待续
  • 华为机试 牛客网 HJ1 字符串最后一个单词的长度

    华为机试 牛客网 HJ1 字符串最后一个单词的长度 描述 输入描述 输出描述 示例一 解法一 解法二 反思 描述 计算字符串最后一个单词的长度 单词以空格隔开 字符串长度小于5000 输入描述 输入一行 代表要计算的字符串 非空 长度小于5
  • shell简单脚本编写

    1 第一步 安装邮件服务 root server yum install s nail y 第二步 编辑配置文件 root server vim etc s nail rc set from 自己的qq邮箱地址 set smtp smtp
  • OpenCV - 基本知识

    1 读取并显示图片 namedWindow新建一个显示窗口 imshow输出图片 namedwindow可有可无 Mat image cv imread E 其他文档 图片 2 jpg 2 cv namedWindow 照片 CV WIND
  • window中gcc编译程序、编辑环境配置以及gcc编译程序的过程(含system函数以及CMD快捷键)

    1 system函数的使用 include
  • 关于rocketmq 中日志文件路径的配置

    前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住分享一下给大家 点击跳转到网站 rocketmq 中的数据和日志文件默认都是存储在user home路径下面的 往往我们都需要修改这些路径到指定文件夹以便管理 服务端日志 网
  • ML-朴素贝叶斯

    参考 西瓜书 P151 以前对贝叶斯参数的计算过程不是很清楚 在西瓜书里讲的很详细 原来可以把X属性分为离散型与连续型 离散型的话可以直接按照频率计算 连续型的话 要用极大似然估计 首先假设概率密度函数满足一个分布 比如正态分布 然后利用已
  • 动态控制ToolStrip上ToolStripButton的大小(包括图标的大小)

    一 设置固定大小的ToolStripButton 设置固定大小的ToolStripButton很简单 ToolStripButton gt AutoSize属性设置为false size调整为自己想要的大小即可 同时配合的是ToolStri
  • Flutter与android原生通信

    Flutter 与 Android iOS 之间信息交互通过 Platform Channel 进行桥接 Flutter 定义了三种不同的 Channel 但无论是传递方法还是传递事件 其本质上都是数据的传递 MethodChannel 用
  • 因Redis分布式锁造成的P0级重大事故,整个项目组被扣了绩效......,请慎用

    作者 浪漫先生 出处 juejin im post 5f159cd8f265da22e425f71d 背景 我们项目中的抢购订单采用的是分布式锁来解决的 有一次 运营做了一个飞天茅台的抢购活动 库存 100 瓶 但是却超卖了 要知道 这个地
  • 在C++11通过SFINAE机制实现静态检查类成员是否存在并分情况处理,以及一种通用宏的实现

    目录 引入 目的 代码 测试 TIPS 引入 c 模板中 我们无法知道参数类是否具有某个成员 例如下面代码 我们希望下面的代码中能够打印t的成员变量a的值 然而当类型T不包含成员a时 调用下面的代码就会报错 template
  • iOS Push详述,了解一下?

    欢迎大家前往腾讯云 社区 获取更多腾讯海量技术实践干货哦 本文由WeTest质量开放平台团队发表于云 社区专栏 作者 陈裕发 腾讯系统测试工程师 商业转载请联系腾讯WeTest获得授权 非商业转载请注明出处 原文链接 http wetest
  • 安装eli5库的踩雷

    报错方法 在Anaconda Prompt中输入pip install eli5 conda install eli5指令 分别显示安装失败和未找到包 解决方法 在Anaconda Powershell Prompt中输入conda ins
  • 深度学习中的优化算法之RMSProp

    之前在https blog csdn net fengbingchun article details 124766283 中介绍过深度学习中的优化算法AdaGrad 这里介绍下深度学习的另一种优化算法RMSProp RMSProp全称为R