深度学习_调参经验

2023-11-09

面对一个图像分类问题,可以有以下步骤:

1.建立一个简单的CNN模型,一方面能够快速地run一个模型,以了解这个任务的难度

卷积层1:卷积核大小3*3,卷积核移动步长1,卷积核个数64,池化大小2*2,池化步长2,池化类型为最大池化,激活函数ReLU。

卷积层2:卷积核大小3*3,卷积核移动步长1,卷积核个数128,池化大小2*2,池化步长2,池化类型为最大池化,激活函数ReLU。

卷积层3:卷积核大小3*3,卷积核移动步长1,卷积核个数256,池化大小2*2,池化步长2,池化类型为最大池化,激活函数ReLU。

全连接层:隐藏层单元数1024,激活函数ReLU。

分类层:隐藏层单元数10,激活函数softmax。

参数初始化,所有权重矩阵使用random_normal(0.0, 0.001),所有偏置向量使用constant(0.0)。使用cross entropy作为目标函数,使用Adam梯度下降法进行参数更新,学习率设为固定值0.001。

 

该网络是一个有三层卷积层的神经网络,能够快速地完成图像地特征提取。全连接层用于将图像特征整合成分类特征,分类层用于分类。cross entropy也是最常用的目标函数之一,分类任务使用cross entropy作为目标函数非常适合。Adam梯度下降法也是现在非常流行的梯度下降法的改进方法之一,学习率过大会导致模型难以找到较优解,设置过小则会降低模型训练效率,因此选择适中的0.001。这样,我们最基础版本的CNN模型就已经搭建好了,接下来进行训练和测试以观察结果。

训练5000轮,观察到loss变化曲线、训练集准确率变化曲线和验证集准确率变化曲线如下图。测试集准确率为69.36%

 

2.使用数据增强技术(data augmentation),主要是在训练数据上增加微小的扰动或者变化,一方面可以增加训练数据,从而提升模型的泛化能力,另一方面可以增加噪声数据,从而增强模型的鲁棒性。主要的数据增强方法有:翻转变换 flip、随机修剪(random crop)、色彩抖动(color jittering)、平移变换(shift)、尺度变换(scale)、对比度变换(contrast)、噪声扰动(noise)、旋转变换/反射变换 (rotation/reflection)等,可以参考Keras的官方文档 [2] 。获取一个batch的训练数据,进行数据增强步骤之后再送入网络进行训练。

我主要做的数据增强操作有如下方面:

图像切割:生成比图像尺寸小一些的矩形框,对图像进行随机的切割,最终以矩形框内的图像作为训练数据。

图像翻转:对图像进行左右翻转。

图像白化:对图像进行白化操作,即将图像本身归一化成Gaussian(0,1)分布。

结果分析:我们观察训练曲线和验证曲线,很明显地发现图像白化的效果好,其次是图像切割,再次是图像翻转,而如果同时使用这三种数据增强技术,不仅能使训练过程的loss更稳定,而且能使验证集的准确率提升至82%左右,提升效果十分明显。而对于测试集,准确率也提升至80.42%。说明图像增强确实通过增加训练集数据量达到了提升模型泛化能力以及鲁棒性的效果,从准确率上看也带来了将近10%左右的提升,因此,数据增强确实有很大的作用。但是对于80%左右的识别准确率我们还是不够满意,接下来继续调参。

 

3.从模型入手,使用一些改进方法

接下来的步骤是从模型角度进行一些改进,这方面的改进是诞生论文的重要区域,由于某一个特定问题对某一个模型的改进千变万化,没有办法全部去尝试,因此一般会实验一些general的方法,比如批正则化(batch normalization)、权重衰减(weight decay)。我这里实验了4种改进方法,接下来依次介绍。

权重衰减(weight decay):对于目标函数加入正则化项,限制权重参数的个数,这是一种防止过拟合的方法,这个方法其实就是机器学习中的l2正则化方法,只不过在神经网络中旧瓶装新酒改名为weight decay [3]。

dropout:在每次训练的时候,让某些的特征检测器停过工作,即让神经元以一定的概率不被激活,这样可以防止过拟合,提高泛化能力 [4]。

批正则化(batch normalization):batch normalization对神经网络的每一层的输入数据都进行正则化处理,这样有利于让数据的分布更加均匀,不会出现所有数据都会导致神经元的激活,或者所有数据都不会导致神经元的激活,这是一种数据标准化方法,能够提升模型的拟合能力 [5]。

LRN:LRN层模仿生物神经系统的侧抑制机制,对局部神经元的活动创建竞争机制,使得响应比较大的值相对更大,提高模型泛化能力。

 

结果分析:我们观察训练曲线和验证曲线,随着每一个模型提升的方法,都会使训练集误差和验证集准确率有所提升,其中,批正则化技术和dropout技术带来的提升非常明显,而如果同时使用这些模型提升技术,会使验证集的准确率从82%左右提升至88%左右,提升效果十分明显。而对于测试集,准确率也提升至85.72%。我们再注意看左图,使用batch normalization之后,loss曲线不再像之前会出现先下降后上升的情况,而是一直下降,这说明batch normalization技术可以加强模型训练的稳定性,并且能够很大程度地提升模型泛化能力。所以,如果能提出一种模型改进技术并且从原理上解释同时也使其适用于各种模型,那么就是非常好的创新点,也是我想挑战的方向。现在测试集准确率提升至85%左右,接下来我们从其他的角度进行调参。

 

4.变化的学习率,进一步提升模型性能

在很多关于神经网络的论文中,都采用了变化学习率的技术来提升模型性能,大致的想法是这样的:

首先使用较大的学习率进行训练,观察目标函数值和验证集准确率的收敛曲线。

如果目标函数值下降速度和验证集准确率上升速度出现减缓时,减小学习率。

循环步骤2,直到减小学习率也不会影响目标函数下降或验证集准确率上升为止。

为了进行对比实验,实验1只使用0.01的学习率训练,实验2前10000个batch使用0.01的学习率,10000个batch之后学习率降到0.001,实验3前10000个batch使用0.01的学习率,10000~20000个batch使用0.001的学习率,20000个batch之后学习率降到0.0005。同样都训练5000轮,观察到loss变化曲线、训练集准确率变化曲线和验证集准确率变化曲线对比如下图。

 

 
 

结果分析:我们观察到,当10000个batch时,学习率从0.01降到0.001时,目标函数值有明显的下降,验证集准确率有明显的提升,而当20000个batch时,学习率从0.001降到0.0005时,目标函数值没有明显的下降,但是验证集准确率有一定的提升,而对于测试集,准确率也提升至86.24%。这说明,学习率的变化确实能够提升模型的拟合能力,从而提升准确率。学习率在什么时候进行衰减、率减多少也需要进行多次尝试。一般在模型基本成型之后,使用这种变化的学习率的方法,以获取一定的改进,精益求精。

5.加深网络层数,会发生什么事情?

现在深度学习大热,所以,在计算资源足够的情况下,想要获得模型性能的提升,大家最常见打的想法就是增加网络的深度,让深度神经网络来解决问题,但是简单的网络堆叠不一定就能达到很好地效果,抱着深度学习的想法,我按照plain-cnn模型 [6],我做了接下来的实验。

卷积层1:卷积核大小3*3,卷积核移动步长1,卷积核个数16,激活函数ReLU,使用batch_normal和weight_decay,接下来的n层,卷积核大小3*3,卷积核移动步长1,卷积核个数16,激活函数ReLU,使用batch_normal和weight_decay。

卷积层2:卷积核大小3*3,卷积核移动步长2,卷积核个数32,激活函数ReLU,使用batch_normal和weight_decay,接下来的n层,卷积核大小3*3,卷积核移动步长1,卷积核个数32,激活函数ReLU,使用batch_normal和weight_decay。

卷积层3:卷积核大小3*3,卷积核移动步长2,卷积核个数64,激活函数ReLU,使用batch_normal和weight_decay,接下来的n层,卷积核大小3*3,卷积核移动步长1,卷积核个数64,激活函数ReLU,使用batch_normal和weight_decay。

池化层:使用全局池化,对64个隐藏单元分别进行全局池化。

全连接层:隐藏层单元数10,激活函数softmax,使用batch_normal和weight_decay。

为了进行对比实验,进行4组实验,每组的网络层数分别设置8,14,20和32。同样都训练5000轮,观察到loss变化曲线、训练集准确率变化曲线和验证集准确率变化曲线对比如下图。

 
 

结果分析:我们惊讶的发现,加深了网络层数之后,性能反而下降了,达不到原来的验证集准确率,网络层数从8层增加到14层,准确率有所上升,但从14层增加到20层再增加到32层,准确率不升反降,这说明如果网络层数过大,由于梯度衰减的原因,导致网络性能下降,因此,需要使用其他方法解决梯度衰减问题,使得深度神经网络能够正常work。

 

6.终极武器,残差网络

2015年,Microsoft用残差网络 [7] 拿下了当年的ImageNet,这个残差网络就很好地解决了梯度衰减的问题,使得深度神经网络能够正常work。由于网络层数加深,误差反传的过程中会使梯度不断地衰减,而通过跨层的直连边,可以使误差在反传的过程中减少衰减,使得深层次的网络可以成功训练,具体的过程可以参见其论文[7]。

通过设置对比实验,观察残差网络的性能,进行4组实验,每组的网络层数分别设置20,32,44和56。观察到loss变化曲线和验证集准确率变化曲线对比如下图。

 
 

结果分析:我们观察到,网络从20层增加到56层,训练loss在稳步降低,验证集准确率在稳步提升,并且当网络层数是56层时能够在验证集上达到91.55%的准确率。这说明,使用了残差网络的技术,可以解决梯度衰减问题,发挥深层网络的特征提取能力,使模型获得很强的拟合能力和泛化能力。当我们训练深度网络的时候,残差网络很有可能作为终极武器发挥至关重要的作用。



作者:骄傲的少年233
链接:https://www.jianshu.com/p/96acc5e5deb1
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

转载于:https://www.cnblogs.com/mfryf/p/11393653.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习_调参经验 的相关文章

随机推荐

  • Python爬虫入门案例6:scrapy的基本语法+使用scrapy进行网站数据爬取

    几天前在本地终端使用pip下载scrapy遇到了很多麻烦 总是报错 花了很长时间都没有解决 最后发现pycharm里面自带终端 狂喜 于是直接在pycharm终端里面写scrapy了 这样的好处就是每次不用切换路径了 pycharm会直接把
  • 网络层协议------IP协议

    这里写目录标题 IP协议 基本概念 协议头格式 网段划分 特殊的ip地址 私网ip地址和公网ip地址 ip地址的数量限制 路由 IP协议 IP协议 其实就是TCP IP协议中对于网络层的一个协议 注意IP协议是TCP IP协议族中最为核心的
  • 查看localstorage容量

    1 function if window localStorage console log 浏览器不支持localStorage var size 0 for let item in window localStorage if windo
  • 电路实验---全桥整流电路

    全桥整流电路作用 采用四个二极管将交流电转换成直流电 全桥整流电路图 全桥整流电路原理 220V交流电经过变压器T1降压输出电压U2 当U1正半周从L1经过T1 到达L2 极性表现为上正下负 此时电流流过方向 L2上正 gt VD1 gt
  • uniapp的onPullDownRefresh失效 不执行

    需要在 pages json 里 找到的当前页面的pages节点 并在 style 选项中开启 enablePullDownRefresh path pages install uploadImg style navigationBarTi
  • 数据分析-数据集划分-交叉验证

    目录 交叉验证 k折交叉验证 k fold cross validation 分层k折交叉验证 stratified cross validation Sklearn的实现 k折交叉分类器 分层k折交叉分类器 打乱数据集后再划分 模型验证
  • angular 代理http到https

    api target https www XXXX com changeOrigin true public target https www XXXX com changeOrigin true
  • uniapp switch按钮的使用

    switch使用官方文档 https uniapp dcloud io component switch 想要改变switch按钮的大小
  • Cloudera CDH 5.1版本的Hive与LDAP-2.4.44集成

    文章目录 0 没集成之前测试 1 安装LDAP 2 4 44 2 增加组织 3 添加用户 4 CDH配置LDAP 5 beeline测试1 5 beeline测试2 0 没集成之前测试 可以看到没有输入用户密码可以登录 1 安装LDAP 2
  • OpenGL学习笔记(十)-几何着色器-实例化

    参考网址 LearnOpenGL 中文版 4 7 几何着色器 4 7 1 基本概念 1 顶点和片段着色器之间有一个可选的几何着色器 几何着色器的输入是一个图元 如点或三角形 的一组顶点 顶点发送到下一着色器之前可对它们随意变换 将顶点变换为
  • 【Web Crawler】Python 的 urllib.request 用于 HTTP 请求

    如果您需要使用 Python 发出 HTTP 请求 那么您可能会发现自己被引导至 brilliantrequests库 尽管它是一个很棒的库 但您可能已经注意到它并不是 Python 的内置部分 如果您出于某种原因更喜欢限制依赖项并坚持使用
  • qt 中lineEdit->setText()输出double

    在qt中需要将获取到的double 值在ui界面上显示出来 便于观察 但是lineEdit控件的setText 要求的参数是string 所以我们先要进行转化 将double 转化为string QString QString number
  • 计算方法实验(二):龙贝格积分法

    Romberg积分法数学原理 利用复化梯形求积公式 复化辛普生求积公式 复化柯特斯求积公式的误差估计式计算积分 a b f x
  • 有效的数独

    LeetCode 之 有效的数独 判断一个 9x9 的数独是否有效 只需要根据以下规则 验证已经填入的数字是否有效即可 数字 1 9 在每一行只能出现一次 数字 1 9 在每一列只能出现一次 数字 1 9 在每一个以粗实线分隔的 3x3 宫
  • python机器学习-乳腺癌细胞挖掘(基于真实美国临床数据)

    随着人们生活水平提高 大家不仅关注如何生活 而且关注如何生活得更好 在这个背景下 精准治疗和预测诊断成为当今热门话题 据权威医学资料统计 全球大约每13分钟就有一人死于乳腺癌 乳腺癌已成为威胁当代人健康的主要疾病之一 并且随着发病率的增加
  • Error mounting /dev/sr0 at /media/ VBox

    重新安装Linux映像 sudo apt get install reinstall linux image uname r
  • IBM WAS简介

    IBM WAS简介 IBM WAS 的全称是 IBM WebSphere Application Server 和 Weblogic 一样 是 当前主流的 App Server 应用服务器 之一 App Server 是运行 Java 企
  • DataWhale集成学习(下)——Task14 案例分析1幸福感预测

    目 录 背景介绍 数据信息 评价指标 案例 背景介绍 数据来源于国家官方的 中国综合社会调查 CGSS 文件中的调查结果中的数据 数据来源可靠可依赖 用139维的信息来预测其对幸福感的影响 数据信息 139维 8000余组 幸福感预测值为1
  • 【indemind双目惯性相机调试】sudo后找不到命令问题,环境变量问题

    问题1 报错 kanhao100 ubuntu x86 roslaunch imsee ros wrapper start launch RLException start launch is neither a launch file i
  • 深度学习_调参经验

    面对一个图像分类问题 可以有以下步骤 1 建立一个简单的CNN模型 一方面能够快速地run一个模型 以了解这个任务的难度 卷积层1 卷积核大小3 3 卷积核移动步长1 卷积核个数64 池化大小2 2 池化步长2 池化类型为最大池化 激活函数