深度学习中的优化算法之NAG

2023-11-03

      之前在https://blog.csdn.net/fengbingchun/article/details/124648766 介绍过Momentum SGD,这里介绍下深度学习的另一种优化算法NAG。

      NAG:Nesterov Accelerated Gradient或Nesterov momentum,是梯度优化算法的扩展,在基于Momentum SGD的基础上作了改动。如下图所示,截图来自:https://arxiv.org/pdf/1609.04747.pdf

       基于动量的SGD在最小点附近会震荡,为了减少这些震荡,我们可以使用NAG。NAG与基于动量的SGD的区别在于更新梯度的方式不同。

       以下是与Momentum SGD不同的代码片段:

       1. 在原有枚举类Optimization的基础上新增NAG:

enum class Optimization {
	BGD, // Batch Gradient Descent
	SGD, // Stochastic Gradient Descent
	MBGD, // Mini-batch Gradient Descent
	SGD_Momentum, // SGD with Momentum
	AdaGrad, // Adaptive Gradient
	RMSProp, // Root Mean Square Propagation
	Adadelta, // an adaptive learning rate method
	Adam, // Adaptive Moment Estimation
	AdaMax, // a variant of Adam based on the infinity norm
	NAG // Nesterov Accelerated Gradient
};

       2. 计算z的方式不同:NAG使用z2

float LogisticRegression2::calculate_z(const std::vector<float>& feature) const
{
	float z{0.};
	for (int i = 0; i < feature_length_; ++i) {
		z += w_[i] * feature[i];
	}
	z += b_;

	return z;
}

float LogisticRegression2::calculate_z2(const std::vector<float>& feature, const std::vector<float>& vw) const
{
	float z{0.};
	for (int i = 0; i < feature_length_; ++i) {
		z += (w_[i] - mu_ * vw[i]) * feature[i];
	}
	z += b_;

	return z;
}

       3. calculate_gradient_descent函数:

void LogisticRegression2::calculate_gradient_descent(int start, int end)
{
	switch (optim_) {
		case Optimization::NAG: {
			int len = end - start;
			std::vector<float> v(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z2(data_->samples[random_shuffle_[i]], v);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					v[j] = mu_ * v[j] + alpha_ * dw; // formula 5
					w_[j] = w_[j] - v[j];
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::AdaMax: {
			int len = end - start;
			std::vector<float> m(feature_length_, 0.), u(feature_length_, 1e-8), mhat(feature_length_, 0.);
			std::vector<float> z(len, 0.), dz(len, 0.);
			float beta1t = 1.;
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				beta1t *= beta1_;

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					m[j] = beta1_ * m[j] + (1. - beta1_) * dw; // formula 19
					u[j] = std::max(beta2_ * u[j], std::fabs(dw)); // formula 24

					mhat[j] = m[j] / (1. - beta1t); // formula 20

					// Note: need to ensure than u[j] cannot be 0.
					// (1). u[j] is initialized to 1e-8, or
					// (2). if u[j] is initialized to 0., then u[j] adjusts to (u[j] + 1e-8)
					w_[j] = w_[j] - alpha_ * mhat[j] / u[j]; // formula 25
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::Adam: {
			int len = end - start;
			std::vector<float> m(feature_length_, 0.), v(feature_length_, 0.), mhat(feature_length_, 0.), vhat(feature_length_, 0.);
			std::vector<float> z(len, 0.), dz(len, 0.);
			float beta1t = 1., beta2t = 1.;
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				beta1t *= beta1_;
				beta2t *= beta2_;

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					m[j] = beta1_ * m[j] + (1. - beta1_) * dw; // formula 19
					v[j] = beta2_ * v[j] + (1. - beta2_) * (dw * dw); // formula 19

					mhat[j] = m[j] / (1. - beta1t); // formula 20
					vhat[j] = v[j] / (1. - beta2t); // formula 20

					w_[j] = w_[j] - alpha_ * mhat[j] / (std::sqrt(vhat[j]) + eps_); // formula 21
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::Adadelta: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.), p(feature_length_, 0.);
			std::vector<float> z(len, 0.), dz(len, 0.);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] = mu_ * g[j] + (1. - mu_) * (dw * dw); // formula 10

					//float alpha = std::sqrt(p[j] + eps_) / std::sqrt(g[j] + eps_);
					float change = -std::sqrt(p[j] + eps_) / std::sqrt(g[j] + eps_) * dw; // formula 17
					w_[j] = w_[j] + change;

					p[j] = mu_ * p[j] +  (1. - mu_) * (change * change); // formula 15
				}

				b_ -= (eps_ * dz[x]);
			}
		}
			break;
		case Optimization::RMSProp: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] = mu_ * g[j] + (1. - mu_) * (dw * dw); // formula 18
					w_[j] = w_[j] - alpha_ * dw / std::sqrt(g[j] + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::AdaGrad: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] += dw * dw;
					w_[j] = w_[j] - alpha_ * dw / std::sqrt(g[j] + eps_); // formula 8
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD_Momentum: {
			int len = end - start;
			std::vector<float> v(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					v[j] = mu_ * v[j] + alpha_ * dw; // formula 4
					w_[j] = w_[j] - v[j];
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD:
		case Optimization::MBGD: {
			int len = end - start;
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					w_[j] = w_[j] - alpha_ * dw;
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::BGD:
		default: // BGD
			std::vector<float> z(m_, 0), dz(m_, 0);
			float db = 0.;
			std::vector<float> dw(feature_length_, 0.);
			for (int i = 0; i < m_; ++i) {
				z[i] = calculate_z(data_->samples[i]);
				o_[i] = calculate_activation_function(z[i]);
				dz[i] = calculate_loss_function_derivative(o_[i], data_->labels[i]);

				for (int j = 0; j < feature_length_; ++j) {
					dw[j] += data_->samples[i][j] * dz[i]; // dw(i)+=x(i)(j)*dz(i)
				}
				db += dz[i]; // db+=dz(i)
			}

			for (int j = 0; j < feature_length_; ++j) {
				dw[j] /= m_;
				w_[j] -= alpha_ * dw[j];
			}

			b_ -= alpha_*(db/m_);
	}
}

       执行结果如下图所示:测试函数为test_logistic_regression2_gradient_descent,多次执行每种配置,最终结果都相同。图像集使用MNIST,其中训练图像总共10000张,0和1各5000张,均来自于训练集;预测图像总共1800张,0和1各900张,均来自于测试集。NAG和Momentum SGD配置参数相同的情况下,即学习率为0.01,动量设为0.7,它们的耗时均为6秒,识别率均为100%

       GitHubhttps://github.com/fengbingchun/NN_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习中的优化算法之NAG 的相关文章

  • Bug解决:ModuleNotFoundError: No module named ‘taming‘

    from taming modules vqvae quantize import VectorQuantizer2 as VectorQuantizer ModuleNotFoundError No module named taming
  • 带你看懂CTC算法

    转自 https zhuanlan zhihu com p 161186907 在文本识别模型CRNN中 涉及到了CTC算法的使用 由于算法的原理涉及内容较多 所以特另开一篇文章对其原理进行叙述 自己在学习CTC过程中也是看了诸多资料才大概
  • 【阅读论文方法总结】

    1 快速浏览摘要 看是否有自己需要的东西 2 如果需要 github上查找相关论文代码 对照着论文进行阅读 这样效率高 能够快速理解
  • 基于Pytorch的模型推理

    训练部分说明 假设我们现在有两个文件 first file train py 用于训练模型 second file inference py 用于推理检测 在train py文件中我们使用了定义了一个类 里面声明了我的网络模型 例如 cla
  • DOTA: A Large-scale Dataset for Object Detection in Aerial Images 翻译

    DOTA 用于航空图像中目标检测的大规模数据集 原文 https arxiv org pdf 1711 10398 pdf 官网 https captain whu github io DOTA dataset https captain
  • 3D人体重建方法漫谈

    转自 https blog csdn net Asimov Liu article details 96442990 1 概述 2 模型匹配的方法 2 1SMPL Skinned Multi Person Linear model 模型 2
  • 深度学习之图像分类(一)--分类模型的混淆矩阵

    深度学习之图像分类 一 分类模型的混淆矩阵 深度学习之图像分类 一 分类模型的混淆矩阵 1 混淆矩阵 1 1 二分类混淆矩阵 1 2 混淆矩阵计算实例 2 混淆矩阵代码 3 混淆矩阵用途 深度学习之图像分类 一 分类模型的混淆矩阵 今天开始
  • 朴素贝叶斯分类器简介及C++实现(性别分类)

    贝叶斯分类器是一种基于贝叶斯定理的简单概率分类器 在机器学习中 朴素贝叶斯分类器是一系列以假设特征之间强 朴素 独立下运用贝叶斯定理为基础的简单概率分类器 朴素贝叶斯是文本分类的一种热门 基准 方法 文本分类是以词频为特征判断文件所属类别或
  • 目标检测基础

    什么是目标检测 简单来说就是 检测图片中物体所在的位置 本文只介绍用深度学习的方法进行目标检测 同过举出几个特性来帮助各位理解目标检测任务 同时建议学习目标检测应先具备物体人工智能算法基础和物体分类现实基础 特性1 Bounding Box
  • Could not load dynamic library ‘libcupti.so.10.0‘; dlerror: libcupti.so.10.0...

    环境 Ubuntu 16 04 CUDA 10 0 CUDNN 7 6 5 nvcc NVIDIA R Cuda compiler driver Copyright c 2005 2018 NVIDIA Corporation Built
  • libsvm库简介及使用

    libsvm是基于支持向量机 support vector machine SVM 实现的开源库 由台湾大学林智仁 Chih Jen Lin 教授等开发 它主要用于分类 支持二分类和多分类 和回归 它的License是BSD 3 Claus
  • pytorch 入门 DenseNet

    知识点0 dense block的结构 知识点1 定义dense block 知识点2 定义DenseNet的主体 知识点3 add module 知识点 densenet是由 多个这种结构串联而成的 import torch import
  • 几乎最全的中文NLP资源库

    NLP民工的乐园 The Most Powerful NLP Weapon Arsenal NLP民工的乐园 几乎最全的中文NLP资源库 词库 工具包 学习资料 在入门到熟悉NLP的过程中 用到了很多github上的包 遂整理了一下 分享在
  • 深度学习中的验证集和超参数简介

    大多数机器学习算法都有超参数 可以设置来控制算法行为 超参数的值不是通过学习算法本身学习出来的 尽管我们可以设计一个嵌套的学习过程 一个学习算法为另一个学习算法学出最优超参数 在多项式回归示例中 有一个超参数 多项式的次数 作为容量超参数
  • 16个车辆信息检测数据集收集汇总(简介及链接)

    16个车辆信息检测数据集收集汇总 简介及链接 目录 1 UA DETRAC 2 BDD100K 自动驾驶数据集 3 综合汽车 CompCars 数据集 4 Stanford Cars Dataset 5 OpenData V11 0 车辆重
  • PyTorch训练简单的全连接神经网络:手写数字识别

    文章目录 pytorch 神经网络训练demo 输出结果 来源 pytorch 神经网络训练demo 数据集 MNIST 该数据集的内容是手写数字识别 其分为两部分 分别含有60000张训练图片和10000张测试图片 神经网络 全连接网络
  • 损失函数和正则化

    参考 https www cnblogs com LXP Never p 10918704 html https blog csdn net Heitao5200 article details 83030465 https zhuanla
  • SqueezeNet运用到Faster RCNN进行目标检测+OHEM

    目录 目录 一SqueezeNet介绍 MOTIVATION FIRE MODULE ARCHITECTURE EVALUATION 二SqueezeNet与Faster RCNN结合 三SqueezeNetFaster RCNNOHEM
  • nvidia深度学习加速库apex简单介绍

    介绍地址 https docs nvidia com deeplearning sdk mixed precision training index html 本人英文水平有限 有误请指正 使用理由 使用精度低于32位浮点的数值格式有许多好
  • 当我们谈人工智能 我们在谈论什么

    我们对一个事物的认识模糊往往是因为宣传过剩冲淡了理论的真实 我们陷在狂欢里 暂时忘记为什么要狂欢 如何踏上这趟飞速发展的列车成为越来越多人心心念念的事情 人工智能的浪潮更像是新闻舆论炒起来的话题 城外的人想进去 城内的人也不想出来 当我们谈

随机推荐