Keras 深度学习之猫狗大战

2023-10-26

项目地址和代码:Project_Dogs_vs_Cats
项目详细报告:Report_dogs_vs_cats.pdf
keras 版本:2.1.5

使用滴滴云AI大师码【0212】消费GPU有9折优惠哦!


1.问题定义和数据集获取:
        项目属于计算机视觉领域中的图像分类问题。图像分类的过程非常明确:给定已经标记的数据集,提取特征,训练得到分类器。项目使用Kaggle竞赛提供的Dogs vs. Cats数据集,任务是对给定的猫和狗的图片进行分类,因此是二分类问题(Binary Classification)。
        项目要解决的问题是使用12500张猫和12500张狗的图片作为测试集,训练一个合适的模型,能够在给定的12500张未见过的图像中分辨猫和狗。在机器学习领域,用于分类的策略,包括K均值聚类、支持向量机等,均能够用于处理该二分类问题。但在图像分类领域,神经网络技术具有更加明显的优势,特别是深度卷积神经网络,已成功地应用于图像识别领域。

2.选择度量成功的标准(损失函数):
        
分类准确度(accuracy)和代价函数(Cost Function)是常用的分类评估指标。为了对模型进行更细致的评价,代价函数更加合理。通过代价函数计算结果的值越小,就代表模型拟合的越好。神经网络的代价函数是用于logistic回归的一个泛化。Kaggle官网在此次竞赛中对预测结果的评估采用的是对数损失函数log loss:
                                                  \textrm{LogLoss} = - \frac{1}{n} \sum_{i=1}^n \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)\right],                                                           
        式中, n为测试集图像总数量; \hat{y}_{i}指图像类别为狗的预测概率;如果图像类别为狗, \hat{y}_{i}为1,如果为猫, \hat{y}_{i}为0。对数损失值越小,预测结果越优。

3.确定评估方法:
        
三种常见的评估方法(evaluation protocols):(1)维持一个验证集不变,这通常在数据充足的情况下使用;(2)k-折交叉验证,数据量较少情况下;(2)迭代的k-折交叉验证,数据量较少而要求高精度模型评价时使用。这里选择第一种方法。
        
4.准备数据(数据预处理):
        训练集包括25000张被标记的猫和狗的图片(各占一半)。测试集为12500张未被标记的图片,同样是猫狗图片各占一半。测试集将作为输入来训练分类器。在数据用于训练前须进行预处理,使数据格式与模型相匹配。另外,数据预处理还可以应用数据增强技术,包括图片裁剪或填充、归一化。如果计算资源有限,还可以将图像转化为灰度图像存储。数据集的部分图片如图所示:
 


        训练集图像的尺寸分布如下图所示。通过该散点图,我们可以非常直观地了解数据的构成:绝大部分图像的尺寸都分布在500×500内。猫和狗的图片各自都只有一张图像的尺寸偏离正常范围内,属于异常值。于是下一步,我们首先找出这两幅图像。
 


      这两幅图像如下方左图所示。但是,我们的图像在输入模型训练前,是需要进行数据预处理的,例如我们可能对图像进行尺寸缩放或裁剪以使输入数据格式统一。图像裁剪可能会影响模型的训练,因为它可能将图像的有用信息丢弃。而在这个项目中,我们将只使用尺寸缩放。缩放后的猫狗图像如下方右图所示,我们还是可以很确定地辨别出猫和狗,这说明了图像缩放这种预处理方式并不会对模型分类带来坏影响,并不影响其学习,所以这对异常值是没有必要处理的!
                                                                    

       既然图像的尺寸对模型的训练不会带来影响或者影响甚微,那么便接着探索数据集是否存在其他异常值。由于读取图像时三个通道的值都限制在[0,255]内,分布范围较小,因此图像通道值的最大值和最小值和均值都不会成为异常值。因而,将目光投向通道值的标准差,原因是通过标准差更可能发现异常值。设想一下,假如一幅图像中通道值的标准差很小,这意味着图像越接近于纯色,则其包含的信息量越少,越不容易用于提取可用于分类的特征。为了简化问题,将图像读取为灰度图,这样就只需要分析单通道了。
       下面两幅图像分别是所有猫狗图像中,灰度图模型下,通道值标准差最小的情况。可以看到,虽然图像颜色比较单一,但是我们还是能够通过人眼清晰地识别出猫和狗,因此这部分不存在异常值!
 

       训练数据的异常值还有另一种情况,那就是类别标签标注上的异常值。寻找这些异常值最直接但同时也是最繁琐的方法,就是耐心地将训练集里面的图片都手动检查一遍!但事实上我们可以用其他方式解决,那就是使用预处理模型来进行排查,这将大大减少我们的工作量。基本思路是,使用表现最佳的预处理模型(这里选择在Imagenet上表现很好的Xception)对训练集的图片进行预测,给出预测正确分类概率最高的n个分类。检验猫、狗标签是否在这n个类别之内,若不在则将该图片视为可能的异常值。
       利用Xception预测后,使用 decode_predictions ,设置top参数为n(一开始均设置为30),将预测结果转化,获得预测正确分类概率最高的n个分类。检验猫、狗标签是否在这n个类别之内,若不在则将图像视为可能的异常值。此时,猫狗图片中的可能异常值数量分别为80和23,如下图所示为部分查找得到的可能异常值。
 

                                                                     图a. 猫图片可能异常值(Top参数均为30
 

                                                                  图b. 狗图片可能异常值(Top参数均为30
       可以看到,当top参数均设置为30时,猫图片的异常值统计数量较狗图片多,同时错误统计也多。对于狗图片的异常值统计情况,则几乎没有失误。这可能是由于imagenet在构建数据集时对狗的归类做的比较详细(imagenet上狗的类别有118种,而猫的类别只有7种)。因此,我认为设置top参数时,应该将猫和狗分开讨论,不能一概而论。经过一番尝试,我将猫的top参数设置为35,狗的top参数设置为10,最后获得的部分可能的异常值如下图所示,其中猫有62幅,狗有36幅。相较于完全手动检查一遍,这已经大大降低了工作量。
 


                                                                 图a. 猫图片可能异常值(Top参数均为35
 

                                                                 图b. 狗图片可能异常值(Top参数均为10

       从上面的分析可以看到,猫狗大战数据集的异常值主要是标注上的异常值。在将训练数据输入模型训练之前,需要将这些异常数据进行预处理。我将考虑一下处理方式:(1)将与主题完全无关的图片删除;(2)对于分类错误的图片,修改类别标注(例如本来是狗的图片却标注为猫,这时候就需要将标注修改);(3)将背景复杂的图片进行裁剪。
        这些异常值很难通过数据科学的方法(均值、平均值等)进行描述和发现,在处理的时候选择也只能手动处理。


5.设计一个优于随机预测的简单模型:
       对于图片分类问题,存在一些分类效果很好的深度卷积网络,例如AlexNet、Inception,但在一开始,我们会先使用简单的分类模型进行训练并观察其表现。在该项目中,我首先建立了如下结构的简单模型,它的分类准确性确实要明显优于随机分类:


       可以看到,该简单CNN模型由若干个层“堆叠”而成,且可以分成这三部分:(1)输入层;(2)卷积层和池化层的组合;(3)全连接的多层感知机分类器。在需要处理过拟合问题时通常还会引入dropout层。

6.设计过拟合模型:
       
自己从头设计CNN有时并不是一个好的决定,它既费时而且又不一定能获得优秀的表现。所以,在该项目将采用迁移学习的方法,其中使用的用于fine tune的预训练模型包括 VGG16,ResNet50,Inception v3和Xception。
       本项目中,在使用各个预训练模型时,都采取了同一个流程。首先,尝试使用特征提取方法。第一步,加载去掉顶层分类器的模型(留下卷积层)之结构和权重,并将训练数据输入模型,提取特征向量;第二步,为卷积层添加全连接层和drop out层,输入上一步提取的特征向量进行分类,观察分类结果并评价。第三步,为卷积层添加一个包含卷积核的顶层分类器,将加载的卷积基础层冻结(不冻结顶层分类器的卷积核),重新编译模型。输入图片数据进行训练,得到一个训练过的顶层分类器。之后,尝试使用参数微调方法。在使用特征提取手段时,我们已经得到了一个包含卷积层的顶层分类器。现在,可以将冻结的卷积层解开部分层,与顶层分类器一起训练。参数微调就是微调加载的模型的部分卷积层以及新加入的顶层分类器的权重。
       这里需要注意的是:为了顺利地进行参数微调,模型的各个层都需要以预先训练过的权重为初始值。所以,不能将随机初始化权重的全连接层放在预训练过的卷积层之上,否则会破坏卷积层预训练获得的权重。体现在前面谈到的流程中,就是先使用特征提取方法训练顶层分类器,再基于这个顶层分类器进行参数微调。
       另外,参数微调时不会选择训练整个网络的权重,而只微调位于模型中较深的部分卷积层,这在一定程度上可以防止过拟合。在前面已经提到,因为由底层卷积模块学习到的特征更具一般性,而不是抽象性。
       对于VGG16模型,我首先尝试选择将其卷积模块的末尾3个卷积层冻结然后再进行参数微调,但效果并不理想。于是我又将冻结层扩大至末尾4个卷积层,进行参数微调;对于InceptionV3,选择将249层之前的层冻结;对于Resnet50,选择将168层之前的层冻结;对于Xception,选择将126层之前的层冻结。下图所示表示了VGG16参数微调的设置方法,其他模型的设置原理与此类似。
 


7.调整模型、调整超参数:
       应用参数微调技术时应该在比较低的学习率下训练,这里使用的学习率为0.00001,优化器为RMSprop。较低的学习率可以使训练过程中保持较低的更新幅度,以防止破坏模型卷积层预训练的特征。
       在一开始,我将Batch size 设置为50,但是在训练时发现模型不容易收敛,而且在测试集和验证集上的损失函数值波动较大,引入Keras的回调函数EarlyStopping(EarlyStopping能够让模型在训练时根据设定的条件停止)时,需将patience参数设置为较大的值。故后面将Batch size调整为较大值。但Batch size太大对GPU显存资源也是一个很大的负荷,最终将训练时的Batch size调整为100。
       为了应对模型过拟合问题,我使用了数据增强技术来训练一个新的网络,所以所有epoch中是不会有两次相同的输入的。但是,这些输入仍然是相互关联的,因为它们来自有限的原始图像,数据增强并不能产生新的信息,而只能重新混合现有的信息。因此,这可能不足以完全摆脱过度拟合。为了进一步克服过度拟合,我还将在模型中全连接层分类器的前面添加一个dropout层这是一种正则化手段,不过跟Regularization不同的是,它是通过将训练的层中的部分神经元的输出置零来实现的。在这里,使用的Dropout参数为0.5。
       但在参数微调过程中,我发现Xception很快就停止训练了,因为使用了EarlyStopping回调函数,而且val_acc和val_loss都呈现出模型的性能在下降的趋势,最后参数微调效果也不理想。于是我增大了数据增强的幅度,重新跑了一遍程序,这时Xception模型依然很快就停止训练,但是val_acc和val_loss却是往好的趋势变化。在这之后,我果断将EarlyStopping中的stop参数调整为较大的整数,从而增加Xception训练的epoch数量。
       所有模型的训练都在Jupyter Notebook中完成。程序批量读取图片,并根据具体模型进行相应的数据预处理,包括数据增加技术。各模型训练过程如图中所示,由于参数微调时的模型的顶层分类器之权重是由特征提取过程中训练过的,因此在训练的一开始,模型就已经表现得比较好了。后三个模型在整个参数微调过程中,模型的改进幅度并不明显。最终训练得到的结果准确率已经很高了。
 

       可以看到模型表现相对较优的是Xception、Inception V3,它们在验证集上的分类准确率均高于99%,损失值均低于0.03。而Inception V3在验证集上的分类准确率已经达到了0.993%,logloss值为0.025左右,参数微调效果非常理想,和Xception十分接近!ResNet50虽然也达到了不错的分类准确性,但结果却和Inception V3想比却处于劣势。初步猜测是模型训练时参数设置不够合理导致的,例如数据增强的设置,在往后的学习中会进行验证。VGG16的表现相对于前面三种模型则处于明显的劣势,毕竟VGG16相对较老,网络结构也较浅。

8.特征融合:
       
使用预训练的模型进行参数微调确实比使用自己搭建的模型能够获得更小的损失函数值,但是训练过程还是十分缓慢的。在对每一个模型进行参数微调前,我都有都尝试直接使用特征向量来进行分类,而且个别模型的分类结果十分理想,包括ResNet50、InceptionV3和Xception。因此,尝试将提取的多个较优模型的特征进行组合,是一个非常可行的方案!
       整个构建过程很简单。首先,我们选择在前面的训练中表现最好的三个模型ResNet50、InceptionV3和Xception作为特征融合的对象模型。接着,先将数据输入各个模型中,提取特征向量并保存下来。然后,将对应同一个训练样本的来自这三个模型的三个特征向量在长度上进行堆叠,最终得到共12500个维度为3×2048=6144的特征向量。
       最后,构建一个简单的神经网络,输入上一步特征融合得到的特征向量集合,进行训练。应用特征融合技术构建的模型如图所示(这部分代码和思路借鉴自杨培文老师,由衷感谢!)。
 

       运用特征融合方法的训练曲线如下图所示,最终在验证集上的准确率很轻松地就达到了99.6%,val_loss也只有0.012。
 




 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 深度学习之猫狗大战 的相关文章

  • 【ES6】Iterator迭代器

    文章目录 一 Iterator 二 用法详解 1 基本用法 2 遍历不可迭代对象 2 1 原生具备Iterator接口的数据结构 2 2 遍历不可迭代的对象 2 3 Generator 遍历不可迭代对象 总结 一 Iterator 遍历器
  • 怎么设置html代码中的编码格式,html怎么设置编码

    在html中 可以使用meta标签来设置编码 语法格式 meta标签提供了HTML文档的元数据 元数据不会显示在客户端 但是会被浏览器解析 而charset属性用于定义文档的字符编码 本教程操作环境 windows7系统 HTML5版 De
  • Linux RTC 驱动实验

    目录 Linux 内核RTC 驱动简介 I MX6U 内部RTC 驱动分析 RTC 时间查看与设置 RTC 也就是实时时钟 用于记录当前系统时间 对于Linux 系统而言时间是非常重要的 就和我们使用Windows 电脑或手机查看时间一样
  • SpringMvc-json处理

    SpringMvc json处理 在 JSON 中 使用以下两种方式来表示数据 Object 对象 键 值对 名称 值 的集合 使用花括号 定义 在每个键 值对中 以键开头 后跟一个冒号 最后是值 多个键 值对之间使用逗号 分隔 例如 na
  • CTFShow Web12

    先打开靶机 看到下面的网站 发现啥都点不了 所有按钮都没有实际的动作 根据没啥思路就抓个包 扫描个路径的原则 可以看到有robots txt 访问之 得到关键提示路径 admin 访问之后出现提示框 要求输入账号和密码 账号显然是admin
  • 使用cefsharp在winform中嵌套浏览器,解决程序闪退问题,你也可以做一个红芯浏览器^v^

    使用cefsharp在winform中嵌套浏览器 简单使用cefsharp在winform中嵌套浏览器 在上一节 我们学习了如何简单地在winform中嵌入chromium浏览器 我在使用这个开发项目时 需要点击一个按钮 弹出嵌入浏览器的窗
  • 测试bug 类型及原因分类

    空间管理 测试bug 类型及原因分类 Bug类型 QA设置 代码错误 界面优化 设计缺陷 配置相关 安装部署 安全相关 性能问题 标准规范 测试脚本 其他 bug状态更新备注 DE更新 设计如此 重复bug 外部原因 已解决 无法重现 延期
  • 怎么将英文网页整篇翻译成中文

    作为一个实打实的英语渣渣 这个技能还是需要必备的 英语大神勿笑 当然英语遛的大神是不会知道我们英语渣渣的苦的 话不对说 今天我就跟大家分享一下将一个整篇的英文网页翻译成中文的小技巧 大神跳过 工具 这么牛逼的操作当然要用到Google的Ch
  • Gson实现接口自定义反序列化

    在项目中同样遇到了对json字符串进行反序列化时 遇到了多态情况下 无法找到对应类 所以写这篇文章来mark一下 首先抛出原始代码 再给上解决方案 原始代码 原始json串 type int specs min 1 max 12 unit
  • iOS开发设置状态栏字体颜色

    状态栏的字体为黑色 UIStatusBarStyleDefault 状态栏的字体为白色 UIStatusBarStyleLightContent 一 在info plist中 将View controller based status ba
  • 蓝桥杯C/C++省赛:剪格子

    目录 题目描述 思路分析 AC代码 题目描述 如图p1 jpg所示 3 x 3 的格子中填写了一些整数 我们沿着图中的红色线剪开 得到两个部分 每个部分的数字和都是60 本题的要求就是请你编程判定 对给定的m x n 的格子中的整数 是否可
  • Tensorflow——端到端车牌识别(数据制作、训练、评估、预测)

    利用周末时间断断续续实现端到端车牌识别项目 具备完整的数据集 数据制作 训练 评估 预测业务 项目特点 采用tensorflow中的keras库 训练时数据生成器data generator 对学习keras API有一些参考意义 项目地址
  • TCP窗口字段理解

    TCP窗口字段理解 转载自 https blog 51cto com shjrouting 1612855 TCP数据传输过程中 序列号增长的单元是包的个数 解释 这是初学者最常犯的一个错误 原因是绝大多数老师为了方便学生理解 刚开始举例子
  • C++14几种计时方法的对比

    1 C 14 版本 程序如下 include
  • Mysql中索引的最左前缀原则图文剖析(全)

    目录 前言 1 定义 2 全索引顺序 3 部分索引顺序 3 1 正序 3 2 乱序 4 模糊索引 5 范围索引 前言 之所以有这个最左前缀索引 归根结底是mysql的数据库结构 B 树 在实际问题中 比如 索引index a b c 有三个
  • 多个空格的正则表达式

    一 借鉴别人 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
  • webpack5

    为什么有webpack web1 0阶段 还没有明确前端岗位 主要职责是编写静态页面 用Js来进行表单验证或动画效果 为了在页面上动态填充数据 后面也出现了php jsp这种开发模式 web2 0阶段 伴随ajax的诞生 不止负责展示界面

随机推荐

  • C# this.InvokeRequired

    C 为控件单独开辟了一个线程 当另外一个线程的方法需要修改控件或者调用控件的方法时 需要通过控件的InvokeRequired方法来进行 InvokeRequired
  • tp5如何跨数据库查询

    gt 当然前提是 这两个的数据库都在同一个服务器上才可以的 如果没有在同一个服务器上 gt 解决办法 mysql设置一下权限 a 可以对b进行select的操作权限 tp5使用原生查询 代码如下 admin Db query SELECT
  • orangepi5使用sata ssd启动系统

    使用sata ssd启动香橙派官方的Ubuntu系统 以Orangepi5 1 1 0 ubuntu jammy server linux5 10 110为例 因为烧录系统到外接的ssd需要另一个系统辅助所以我们还要烧录一个带桌面的系统到T
  • 解决 LINK : fatal error LNK1104: 无法打开文件“freeglutd.lib”问题

    最近跑程序 发现总有这样的错误 如下图 开始我以为是缺少了freeglutd lib这个文件 之后才发现压根没有这个文件 于是找到了解决办法 鼠标右键单击项目 选择属性 出现如下图 找到C C 预处理器 点开预处理器定义 点编辑 添加 ND
  • libevent源码学习(0):libevent库安装与简单使用

    目录 1 下载并解压libevent库 2 安装libevent库 3 简单使用libevent库 1 下载并解压libevent库 这里下载的是libevent 2 0 21 stable版本的 使用wget命令如下所示 下载地址可通过h
  • .git文件夹_Git入门细致讲解

    什么是 git 分布式的版本管理与协作系统 安装 Git 下载安装就不详说了 安装之后 右键会出现 Git bush here 在当前文件夹打开 bash 是一个小型的 linux shell 可以在上面进行关于 git 的操作 他自带 m
  • 妙用mov edi,edi和5个nop实现inline hook

    妙用mov edi edi和5个nop实现inline hook 2008年2月22日 分类 其它技术 标签 inline hook nop 这方法MJ很早时就说过了 简单重复下 大家应该发现大部分API的第一条指令都是mov edi ed
  • MySQL 5.7版本简介

    MySQL的优势 MySQL的主要优势如下 1 速度 运行速度快 2 价格 MySQL对多数个人来说是免费的 3 容易使用 与其他大型数据库的设置和管理相比 其复杂程度较低 易于学习 4 可移植性 能够工作在众多不同的系统平台上 例如 Wi
  • MySQL—SQL优化详解(上)

    作者 小刘在C站 个人主页 小刘主页 努力不一定有回报 但一定会有收获加油 一起努力 共赴美好人生 学习两年总结出的运维经验 以及思科模拟器全套网络实验教程 专栏 云计算技术 小刘私信可以随便问 只要会绝不吝啬 感谢CSDN让你我相遇 前言
  • 【vue】npm install -g @vue/cli出现错误

    进行到npm install g vue cli这一步出现错误 操作步骤如下 1 先下载node js 不知道有没有下载 可以在cmd输入 node v 出现版本号则电脑已经有了node js 没有的话去官网下一个 csdn有其他小伙伴给了
  • Spring Boot全后端实现验证码

    验证码通常是利用前端技术实现的 前端的验证码需要先在后端进行保存 再传到前端 再于前端传输的数据对比校验 一些前后端分离项目的工作量大大增加 而如果完全是由后端独立实现的 那么在代码量和复杂程度上就大大降低了 框架 Spring Boot
  • JavaWeb使用ajax实现定时自动保存草稿功能

    在Web程序开发中 难免有时候会遇到一些定时业务 如考试系统中的自动提交试卷 还有平时写博客时定时自动保存草稿的功能 在JavaWeb中也可以利用ajax技术来实现这定时自动保存草稿这一功能 index jsp关键代码 html代码 加载时
  • 编写QT程序时发现内存泄漏的解决方法

    最近项目结尾进行测试的时候 发现项目持续运行产生大量数据后内存的消耗会无休止的增加 当关闭该窗口时 内存却并没有如期释放 理论上QT中所有的子对象在窗口被销毁时都会一同销毁 最终发现是我在这篇博客上写的添加长按的效果导致的 https bl
  • ipynb 格式文件

    最近碰到文件名后缀为 ipynb文件 起初没太在意这种文件格式 用Notepad 打开之后看到也是类似于JSON格式的信息 以为也是为其他的一些文件服务的 类似于配置一些HTML文件的配置文件 但是后来才发现这也是一种文本表示形式 只不过需
  • Lua:调试篇

    1 Lua代码编辑工具 辣博推荐 ZeroBrane Studio编写Lua脚本还是不错滴 基本的代码补全和提示都具有 按照从下往上的代码逻辑 还可以自动对齐格式 实话讲 还不是很完美 毕竟 作为一个使用习惯Qt如此完美的IDE工具的 Qt
  • [线性dp] aw897. 最长公共子序列(重要模板题+最长公共子序列模型)

    文章目录 0 前言 1 LCS 模板题 0 前言 LCS longest common sub sequences 最长公共子序列 子串 按原顺序依次出现 禁止跳过某元素的序列 具有连续性 子序列 在保持元素前后关系的前提下 可以跳过某些元
  • C语言文件指针设置偏移量--fseek

    一 fseek fseek是设置文件指针偏移量的函数 具体传参格式为 int fseek FILE stream long int offset int whence 返回一个整数 其中 1 stream是指向文件的指针 2 offset是
  • (亲测可用)html5 file调用手机摄像头

    在切图网一个客户的webapp项目中需要用到 html5调用手机摄像头 找了很多资料 大都是 js调用api 然后怎样怎样 做了几个demo测试发现根本不行 后来恍然大悟 用html5自带的 input file 纯html5 并且不涉及到
  • composer.json和composer.lock到底是什么以及区别?

    composer方文档 https docs phpcomposer com 04 schema html我们在做项目的时候 总是要安装一些依赖 composer给我们提供了很多方便 直接运行composer install 当我们运行co
  • Keras 深度学习之猫狗大战

    项目地址和代码 Project Dogs vs Cats 项目详细报告 Report dogs vs cats pdf keras 版本 2 1 5使用滴滴云AI大师码 0212 消费GPU有9折优惠哦 1 问题定义和数据集获取 项目属于计