Deep learning:四十一(Dropout简单理解)

2023-10-29

 

  前言

  训练神经网络模型时,如果训练样本较少,为了防止模型过拟合,Dropout可以作为一种trikc供选择。Dropout是hintion最近2年提出的,源于其文章Improving neural networks by preventing co-adaptation of feature detectors.中文大意为:通过阻止特征检测器的共同作用来提高神经网络的性能。本篇博文就是按照这篇论文简单介绍下Dropout的思想,以及从用一个简单的例子来说明该如何使用dropout。

 

  基础知识:

  Dropout是指在模型训练时随机让网络某些隐含层节点的权重不工作,不工作的那些节点可以暂时认为不是网络结构的一部分,但是它的权重得保留下来(只是暂时不更新而已),因为下次样本输入时它可能又得工作了(有点抽象,具体实现看后面的实验部分)。

  按照hinton的文章,他使用Dropout时训练阶段和测试阶段做了如下操作:

  在样本的训练阶段,在没有采用pre-training的网络时(Dropout当然可以结合pre-training一起使用),hintion并不是像通常那样对权值采用L2范数惩罚,而是对每个隐含节点的权值L2范数设置一个上限bound,当训练过程中如果该节点不满足bound约束,则用该bound值对权值进行一个规范化操作(即同时除以该L2范数值),说是这样可以让权值更新初始的时候有个大的学习率供衰减,并且可以搜索更多的权值空间(没理解)。

  在模型的测试阶段,使用”mean network(均值网络)”来得到隐含层的输出,其实就是在网络前向传播到输出层前时隐含层节点的输出值都要减半(如果dropout的比例为50%),其理由文章说了一些,可以去查看(没理解)。

  关于Dropout,文章中没有给出任何数学解释,Hintion的直观解释和理由如下:

  1. 由于每次用输入网络的样本进行权值更新时,隐含节点都是以一定概率随机出现,因此不能保证每2个隐含节点每次都同时出现,这样权值的更新不再依赖于有固定关系隐含节点的共同作用,阻止了某些特征仅仅在其它特定特征下才有效果的情况。

  2. 可以将dropout看作是模型平均的一种。对于每次输入到网络中的样本(可能是一个样本,也可能是一个batch的样本),其对应的网络结构都是不同的,但所有的这些不同的网络结构又同时share隐含节点的权值。这样不同的样本就对应不同的模型,是bagging的一种极端情况。个人感觉这个解释稍微靠谱些,和bagging,boosting理论有点像,但又不完全相同。

  3. native bayes是dropout的一个特例。Native bayes有个错误的前提,即假设各个特征之间相互独立,这样在训练样本比较少的情况下,单独对每个特征进行学习,测试时将所有的特征都相乘,且在实际应用时效果还不错。而Droput每次不是训练一个特征,而是一部分隐含层特征。

  4. 还有一个比较有意思的解释是,Dropout类似于性别在生物进化中的角色,物种为了使适应不断变化的环境,性别的出现有效的阻止了过拟合,即避免环境改变时物种可能面临的灭亡。

  文章最后当然是show了一大把的实验来说明dropout可以阻止过拟合。这些实验都是些常见的benchmark,比如Mnist, Timit, Reuters, CIFAR-10, ImageNet.

                           

  实验过程:

  本文实验时用mnist库进行手写数字识别,训练样本2000个,测试样本1000个,用的是matlab的https://github.com/rasmusbergpalm/DeepLearnToolbox,代码在test_example_NN.m上修改得到。关于该toolbox的介绍可以参考网友的博文【面向代码】学习 Deep Learning(一)Neural Network。这里我只用了个简单的单个隐含层神经网络,隐含层节点的个数为100,所以输入层-隐含层-输出层节点依次为784-100-10. 为了使本例子简单话,没用对权值w进行规则化,采用mini-batch训练,每个mini-batch样本大小为100,迭代20次。权值采用随机初始化。

 

  实验结果:

  没用Dropout时:

  训练样本错误率(均方误差):0.032355

  测试样本错误率:15.500%

  使用Dropout时:

  训练样本错误率(均方误差):0.075819

  测试样本错误率:13.000%

  可以看出使用Dropout后,虽然训练样本的错误率较高,但是训练样本的错误率降低了,说明Dropout的泛化能力不错,可以防止过拟合。

 

  实验主要代码及注释:

  test_dropout.m:  

%% //导入minst数据并归一化
load mnist_uint8;
train_x = double(train_x(1:2000,:)) / 255;
test_x  = double(test_x(1:1000,:))  / 255;
train_y = double(train_y(1:2000,:));
test_y  = double(test_y(1:1000,:));
% //normalize
[train_x, mu, sigma] = zscore(train_x);% //归一化train_x,其中mu是个行向量,mu是个列向量
test_x = normalize(test_x, mu, sigma);% //在线测试时,归一化用的是训练样本的均值和方差,需要特别注意

%% //without dropout
rng(0);
nn = nnsetup([784 100 10]);% //初步构造了一个输入-隐含-输出层网络,其中包括了
                           % //权值的初始化,学习率,momentum,激发函数类型,
                           % //惩罚系数,dropout等
opts.numepochs =  20;   %  //Number of full sweeps through data
opts.batchsize = 100;  %  //Take a mean gradient step over this many samples
[nn, L] = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
str = sprintf('testing error rate is: %f',er);
disp(str)

%% //with dropout
rng(0);
nn = nnsetup([784 100 10]);
nn.dropoutFraction = 0.5;   %  //Dropout fraction,每一次mini-batch样本输入训练时,随机扔掉50%的隐含层节点
opts.numepochs =  20;        %  //Number of full sweeps through data
opts.batchsize = 100;       %  //Take a mean gradient step over this many samples
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
str = sprintf('testing error rate is: %f',er);
disp(str)

  下面来分析与dropout相关的代码,集中在上面test.m代码的后面with drop部分。首先在训练过程中需要将神经网络结构nn的dropoutFraction设置为一定比例,这里设置为50%:nn.dropoutFraction = 0.5;

  然后进入test_dropout.m中的nntrain()函数,没有发现与dropoutFraction相关的代码,继续进入网络前向传播函数nnff()函数中,在网络的隐含层节点激发函数值被计算出来后,有下面的代码:

    if(nn.dropoutFraction > 0)

            if(nn.testing)

                nn.a{i} = nn.a{i}.*(1 - nn.dropoutFraction);

            else

                nn.dropOutMask{i} = (rand(size(nn.a{i}))>nn.dropoutFraction);

                nn.a{i} = nn.a{i}.*nn.dropOutMask{i};

            end

        end

      由上面的代码可知,隐含层节点的输出值以dropoutFraction百分比的几率被随机清0(注意此时是在训练阶段,所以是else那部分的代码),既然前向传播时有些隐含节点值被清0了,那么在误差方向传播时也应该有相应的处理,果然,在反向传播函数nnbp()中,有下面的代码:

    if(nn.dropoutFraction>0)

            d{i} = d{i} .* [ones(size(d{i},1),1) nn.dropOutMask{i}];

        end

  也就是说计算节点误差那一项时,其误差项也应该清0。从上面可以看出,使用dropout时,其训练部分的代码更改很少。

  (有网友发私信说,反向传播计算误差项时可以不用乘以dropOutMask{i}矩阵,后面我仔细看了下bp的公式,一开始也感觉不用乘有道理。因为源码中有为:

for i = 1 : (n - 1)
    if i+1==n
        nn.dW{i} = (d{i + 1}' * nn.a{i}) / size(d{i + 1}, 1);
    else
    nn.dW{i} = (d{i + 1}(:,2:end)' * nn.a{i}) / size(d{i + 1}, 1); 
    end
end

  代码进行权重更新时,由于需要乘以nn.a{i},而nn.a{i}在前向过程中如果被mask清掉的话(使用了dropout前提下),则已经为0了。但其实这时错误的,因为对误差

敏感值作用的是与它相连接的前一层权值,并不是本层的权值,而本层的输出a只对它的下一层权值更新有效。)        

  再来看看测试部分,测试部分如hintion论文所说的,采用mean network,也就是说前向传播时隐含层所有节点的输出同时减小dropoutFraction百分比,即保留(1- dropoutFraction)百分比,代码依旧是上面贴出的nnff()函数里满足if(nn.testing)的部分:

    if(nn.dropoutFraction > 0)

            if(nn.testing)

                nn.a{i} = nn.a{i}.*(1 - nn.dropoutFraction);

            else

                nn.dropOutMask{i} = (rand(size(nn.a{i}))>nn.dropoutFraction);

                nn.a{i} = nn.a{i}.*nn.dropOutMask{i};

            end

        end

   上面只是个简单的droput实验,可以用来帮助大家理解dropout的思想和使用步骤。其中网络的参数都是采用toolbox默认的,并没有去调整它,如果该实验将训练样本增大,比如6w张,则参数不变的情况下使用了dropout的识别率还有可能会降低(当然这很有可能是其它参数没调到最优,另一方面也说明在样本比较少的情况下,droput确实可以防止过拟合),为了体现droput的优势,这里我只用了2000张训练样本。

 

  参考资料:

  Hinton, G. E., et al. (2012). "Improving neural networks by preventing co-adaptation of feature detectors." arXiv preprint arXiv:1207.0580.

      https://github.com/rasmusbergpalm/DeepLearnToolbox

     【面向代码】学习 Deep Learning(一)Neural Network

 

 

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Deep learning:四十一(Dropout简单理解) 的相关文章

  • 一次分配多个字段的聪明方法?

    由于遗留函数调用 我有时被迫编写像这样的丑陋的包装器 function return someWrapper someField a someField a b someField b and so on realistically it
  • 继续在 Matlab 中一遍又一遍地播放声音?

    我正在尝试创建一个 MATLAB 程序来每隔几分钟一遍又一遍地播放声音 现在我将其设置为每隔几秒播放一次 只是为了消除系统中的一些错误 但是 当我的程序尝试重播声音时 我收到此错误 Error using gt audioplayer au
  • 如何将条形图的 XtickLabels 向左移动?

    我目前正在尝试创建频率直方图 为此 我必须创建一个条形图 条形图之间没有空格 然而 这集中于XTickLabels在酒吧的中间 由于它是一个直方图 我希望数值位于每个条形之间的线上 以便它可以直观地指示间隔 本质上 我需要将所有刻度标签移至
  • 对多个属性使用一种设置方法 MATLAB

    我有几个属性基本上使用相同的属性set method classdef MyClass properties A B end methods function mc MyClass a b Constructor mc A a mc B b
  • 如何在Matlab中自定义轮廓线?

    我正在准备一个等高线图 我应该在其中突出显示特定级别的等高线 例如 我的轮廓线值位于 1 和 1 之间 我想突出显示与值 0 相对应的线 我尝试使用以下过程来执行此操作 M c contourf longitude latitude del
  • 如何让MCR启动时间快

    我将 matlab 程序转换为 net 程序集 即 dll 文件 我制作了一个控制台 C 应用程序 添加了 dll 文件并从 php 调用它 每次调用 exe 时都会调用 MCR 如何使 MCR 在服务器启动时初始化 并且即使在一段时间后调
  • 在 MATLAB 中使用 FFT 的频率响应

    这是场景 使用频谱分析仪 我有输入值和输出值 样本数是32000采样率为2000样本 秒 输入是正弦波50 hz 输入为电流 输出为压力 单位 psi 我如何使用 MATLAB 根据这些数据计算频率响应 使用 MATLAB 中的 FFT 函
  • Matlab:如何显示数组的“真实”值?

    我有一个在脚本中计算的向量 计算后 我将值显示到命令窗口 显示如下 finalResults 1 0e 05 0 0001 0 0 0005 0 0002 0 0001 0 0027 0 0033 0 0001 0 0000 0 0000
  • 按元素出现的频率对数组元素进行排序

    是否可以在 matlab octave 中使用sort函数根据元素的相对频率对数组进行排序 例如数组 m 4 4 4 10 10 10 4 4 5 应该产生这个数组 5 10 10 10 4 4 4 4 4 5是出现频率较低的元素 位于顶部
  • 两个 y 轴与相同的 x 轴[重复]

    这个问题在这里已经有答案了 可能的重复 在单个图中绘制 4 条曲线 具有 3 个 y 轴 https stackoverflow com questions 1719048 plotting 4 curves in a single plo
  • Matlab - 如果值包含xxx,则删除元胞数组中的行

    在 Matlab 中 如何删除包含变量字符串的元胞数组中的元胞 假设我的元胞数组是 C svnTrunk RadarLib radarlb utilities scatteredInterpolant m C svnTrunk RadarL
  • 如何使用matlab生成不同频率的正弦波?

    对于我的项目 我需要使用 matlab 生成一个正弦波 它有 100 000 个样本 并且频率在每 10 000 个样本后随机变化 采样率和频率可以根据方便而定 matlab中有没有函数可以生成这个 好的另一个例子 生成 5 个随机频率 r
  • 为什么旋转 3D 点云后顶点法线会翻转?

    我有两个人脸 3D 点云样本 蓝色点云表示目标面 红色点云表示模板 下图显示目标面和模板面在不同方向上对齐 目标面大致沿 x 轴 模板面大致沿 y 轴 Figure 1 The region around the nose is displ
  • 如何找到在matlab中重复的矩阵的每一行的索引?

    我想找到矩阵中所有有重复项的行的索引 例如 A 1 2 3 4 1 2 3 4 2 3 4 5 1 2 3 4 6 5 4 3 要返回的向量将是 1 2 4 很多类似的问题建议使用unique函数 我已经尝试过 但我能得到的最接近我想要的功
  • 使用 MATLAB 进行线路跟踪

    我有一个图像 我想将其转换为逻辑图像 包括线条为黑色 背景为白色 当然 可以使用阈值方法来实现这一点 但我不想使用这种方式来做到这一点 我想通过使用线路跟踪方法或类似的方法来检测它 这是关于视网膜血管检测的 我找到了一个article ht
  • MATLAB 特征函数

    我很好奇哪里可以找到完整的描述FEATURE功能 它接受哪些论点 没有找到文档 我只听说过memstats and getpid 还要别的吗 gt gt which feature built in undocumented 注意 更完整的
  • 黑白随机着色的六角格子

    我正在尝试绘制一个 10 000 x 10 000 随机半黑半白的六边形格子 我不知道如何将该格子的六边形随机填充为黑色和白色 这是我真正想要从这段代码中得到的示例 但我无法做到 https i stack imgur com RkdCw
  • 理解高斯混合模型的概念

    我试图通过阅读在线资源来理解 GMM 我已经使用 K 均值实现了聚类 并且正在了解 GMM 与 K 均值的比较 以下是我的理解 如有错误请指出 GMM 类似于 KNN 在这两种情况下都实现了聚类 但在 GMM 中 每个簇都有自己独立的均值和
  • 了解 fminunc 参数和匿名函数、函数处理程序

    请多多包涵 问题在最后 我试图找出 fminunc 调用方式的差异 这个问题源于 Andrew Ng 在他的 Coursera 机器学习课程中的第 3 周材料 我正在回答这个问题 Matlab Andrew Ng 机器学习课程中 t cos
  • 在matlab中绘制给定区域内(两个圆之间)的向量场

    我想在 Matlab 中绘制下面的向量场 u cos x x 0 y y 0 v sin x x 0 y y 0 我可以在网格中轻松完成 例如 x 和 y 方向从 2 到 2 x 0 2 y 0 1 x y meshgrid 2 0 2 2

随机推荐

  • Sqlserver中如何快速写入千万级测试数据

    数据库结构 id int username nvarchar 50 password nvarchar 50 addtime datetime token nvarchar 50 roleid int 一 程序中写for循环 实测一分钟写入
  • STM32_3(GPIO)

    一 GPIO简介 GPIO General Purpose Input Output 通用输入输出口 8种输入输出模式 输出模式可控制端口输出高电平 驱动LED 蜂鸣器 模拟通信协议输出时许等 输入模式可读取端口的高低电平或电压 用于读取按
  • Qt扩展-KDDockWidgets 简介及配置

    Qt扩展 KDDockWidgets 简介及配置 一 概述 二 编译 KDDockWidgets 库 1 Cmake Gui 中选择源文件和编译后的路径 2 点击Config 配置好编译器 3 点击Generate 4 在存放编译的文件夹输
  • Win10+OpenCV2.4.13+VS2013+CUDA7.5配置教程

    首先说明一下 OpenCV2 3 13之前的版本不支持CUDA7 5 因此配置总是会出问题 在OpenCV官网下载OpenCV2 4 13版本 此版本支持CUDA7 5 另外OpenCV2 4 13是支持VS2013的 但不清楚支不支持VS
  • 力扣:旋转数组(Java)

    给你一个数组 将数组中的元素向右轮转 k 个位置 其中 k 是非负数 class Solution public void rotate int nums int k int n nums length k n rotate 2 nums
  • MySQL脏读、不可重复读、幻读

    MySQL脏读 不可重复读 幻读 事务的特性 ACID 原子性 Atomicity 指处于同一个事务中的多条语句是不可分割的 即一个事务内的所有语句 要么全部成功要么全部失败 一致性 Consistency 事务必须使数据库从一个一致性状态
  • gpio上拉下拉区别

    gpio上拉下拉区别 GPIO是一颗芯片 MCU 必须具备的最基本外设功能 GPIO通常有三种状态 高电平 低电平和高阻态 高阻态换句话说就是断开状态或浮空态 因此上拉和下拉其中一个强大的理由就是为了防止输入端悬空 使其有确定的状态 减弱外
  • 【经典】修改SpringBoot的默认服务器Tomcat,替换Tomcat

    以下将介绍如何替换掉SpringBoot默认服务器Tomcat 我们将从两个案例 替换为Jetty和替换为UnderTow Tomcat是目前较流行的web容器 但过于臃肿 Jetty是个内嵌WEB容器 支持长连接 如聊天等长时间保持连接
  • 图论----同构图(详解)

    图论当中的术语 假设G V E 和G1 V1 E1 是两个图 如果存在一个双射m V V1 使得对所有的x y V均有xy E等价于m x m y E1 则称G和G1是同构的 这样的一个映射m称之为一个同构 如果G G1 则称他为一个自同构
  • JS:各种遍历方式总结

    js的遍历方式真的是有很多 有用于遍历数组的 也有用于遍历对象的 各种方式有什么样的应用场景 如何选择恰当的遍历方式 很容易就让人迷糊 所以做一下总结吧 第一种 普通for循环 直接遍历出的是索引 注意每次遍历都需要获取一次arr的长度 f
  • c#运算符

    一运算符 运算符是一种告诉编译器执行特定的数学或逻辑操作的符号 C 有丰富的内置运算符 分类如下 1 算术运算符 下表显示了 C 支持的所有算术运算符 假设变量 A 的值为 10 变量 B 的值为 20 则 例如 假如A 21 B 10 i
  • 网址与域名的区别

    目录 一 网址与域名的区别 二 主域名与子域名 一 网址与域名的区别 以网址https www baidu com为例 网址由协议加域名组成所以协议是https 域名 www baidu com 区别 1 包含与被包含的关系 网址包含域名
  • 如何判断某个值更改就让按钮可用_【教程】 如何创建自己的 NFT? 这里有份教程, 请收下!...

    AtomicHub 提供了 NFT 创建工具 让任何人都可以创建自己的 NFT 非同质代币 喜欢 NFT 的小伙伴们 一起搞起来吧 除了 WAX 之外 目前 AtomicAssets 也支持了 EOS 区块链 所以 两条链上的朋友都可以参考
  • pandas的定义以及pandas的DataFrame的初步使用(二)

    补充 Series自动对齐 当多个series对象之间进行运算的时候 如果不同series之间具有不同的索引值 那么运算会自动对齐不同索引值的数据 如果某个series没有某个索引值 那么最终结果会赋值为NaN 示例 DataFrame对象
  • 计算机网络c类网络划分子网介绍,IP地址的子网划分详解

    原标题 IP地址的子网划分详解 来源 今日头条北京炫亿时代 一 子网划分基础 1 子网划分的若干个好处 减少网络流量 提高网络性能 简化管理 可以更为灵活的形成大覆盖范围的网络 2 你最好遵循以下步骤来进行子网划分 确认所需要的网络ID数
  • 文件分片上传demo

    知识点 File File 接口也继承了 Blob 接口的属性 File 接口没有定义任何方法 但是它从 Blob 接口继承了以下方法 Blob slice start end contentType new File 字符串数组 file
  • 【C语言】错题本(4)

    一 题目及选项 答案解析 知识点 字符型在内存中的数据存储 char类型数据在内存中的图示 unsigned char类型数据在内存中的图示 二 题目及选项 答案解析 A B C D 三 题目及选项 答案解析 数据在计算机中是先转换成补码
  • @ApiImplicitParam注解使用说明

    ApiImplicitParam注解使用说明 ApiImplicitParam是Swagger框架中的一个注解 用于描述请求参数的详细信息 它可以帮助开发人员生成API文档 并提供给用户更清晰的接口信息 以下是对 ApiImplicitPa
  • 别再乱写了,Controller 层代码这样写才足够规范!

    前言 本篇主要要介绍的就是controller层的处理 一个完整的后端请求由4部分组成 接口地址 也就是URL地址 请求方式 一般就是get set 当然还有put delete 请求数据 request 有head跟body 响应数据 r
  • Deep learning:四十一(Dropout简单理解)

    前言 训练神经网络模型时 如果训练样本较少 为了防止模型过拟合 Dropout可以作为一种trikc供选择 Dropout是hintion最近2年提出的 源于其文章Improving neural networks by preventin