基于深度学习的花卉图像关键点检测

2023-10-30

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

在本文中,我们描述了我们如何使用卷积神经网络 (CNN) 来估计花卉图像中关键点的位置,并且在 3D 模型上渲染这些图像上茎和花的位置等关键点。

为了能够与真实花束的照片对比,所创建的图像必须尽可能逼真。这是通过使用从多个角度拍摄的真实花朵照片并将它们渲染在 3D 模型上来实现的。对于每一朵新花,他们都会从 7 个不同的角度拍摄照片。在照相亭中,花朵由电机自动旋转。

相比之下,图片的后期处理还没有完全自动化。目前数据库中有数千种鲜花,每天都会添加新的鲜花。将此乘以角度数,将获得大量要手动处理的图片。后处理步骤之一是定位 3D 模型所需的图像上的几个关键点,最重要的是茎位和花顶位置。

数据集

在数据集中,成千上万的图像已经手动标注了关键点,所以我们有大量的训练数据可以使用。

以上是训练数据集中的一些带注释的花,它从几个不同的角度展示了同一朵花。茎位置为蓝色,花顶部位置为绿色。在一些图片中,茎的起源被花本身隐藏了。在这种情况下,我们需要“有根据的猜测”最有可能在哪里。

网络模型

因为模型必须输出一个数字而不是一个类,所以我们实际上是在做回归。CNN 以分类任务而闻名,但在回归方面也表现良好。例如,DensePose使用基于 CNN 的方法进行人体姿势估计。

网络从几个标准卷积块开始。这些块由3个卷积层组成,然后是最大池、批量标准化层和退出层。

  • 所述卷积层含有多个滤波器。每个过滤器就像一个模式识别器。下一个卷积块有更多的过滤器,所以它可以在模式中找到模式。

  • 最大池化会降低图像的分辨率。这限制了模型中的参数数量。通常,对于图像分类,我们对某个对象在图像中的位置不感兴趣,只要它在那里即可。在我们的例子中,我们对位置感兴趣。尽管如此,拥有几个最大池化层并不会影响性能。

  • 批量标准化层有助于模型更快地训练(收敛)。在一些深度网络中,没有它们,训练完全失败。

  • 退出层将随机禁用节点,这将防止过度拟合模型。

在卷积块之后,我们将张量展平,使其与密集层兼容。全局最大池化或平均最大池化也将实现平坦张量,但会丢失所有空间信息。扁平化在我们的实验中效果更好,尽管它的(计算)成本是拥有更多模型参数导致更长的训练时间。

在两个带有Relu激活的密集隐藏层之后是输出层,我们想要预测2 个关键点的x和y坐标,所以我们需要在输出层有 4 个节点。图像可以有不同的分辨率,因此我们将坐标缩放到 0 到 1 之间,并在使用前将它们放大。输出层没有激活函数。即使目标变量在 0 和 1 之间,这对我们来说也比使用sigmoid效果更好。作为参考,以下是我们使用的 Python 深度学习库Keras的完整模型摘要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 126, 126, 64)      2368      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 124, 124, 64)      36928     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 122, 122, 64)      36928     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 61, 61, 64)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 61, 61, 64)        256       
_________________________________________________________________
dropout_1 (Dropout)          (None, 61, 61, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 59, 59, 128)       73856     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 57, 57, 128)       147584    
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 55, 55, 128)       147584    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 27, 27, 128)       0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 27, 27, 128)       512       
_________________________________________________________________
dropout_2 (Dropout)          (None, 27, 27, 128)       0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 93312)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               23888128  
_________________________________________________________________
batch_normalization_3 (Batch (None, 256)               1024      
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               65792     
_________________________________________________________________
batch_normalization_4 (Batch (None, 256)               1024      
_________________________________________________________________
dropout_4 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 4)                 1028      
=================================================================
Total params: 24,403,012
Trainable params: 24,401,604
Non-trainable params: 1,408
_________________________________________________________________

你们可能会问:为什么是 3 个卷积层?或者为什么是 2 个卷积块?我们在超参数搜索中将这些数字作为超参数包括在内。连同诸如密集层数、退出层、批量标准化和卷积滤波器数量之类的参数,我们进行了随机搜索以找到超参数的最佳组合。

对于训练,我们使用学习率为的Adam 优化器0.005。当验证损失在几个时期内没有改善时,学习率会自动降低。作为损失函数,我们使用均方误差 (MSE)。因此,大错误比小错误受到的惩罚相对更多。


训练和效果

这些是训练 50 个时期后的损失(误差)图:

大约 8 个 epoch 后,验证损失变得高于训练损失。直到训练结束,验证损失仍然减少,因此我们没有看到模型严重过度拟合的迹象。测试集上的最终损失 (MSE) 为0.0064. MSE 的解释可能非常不直观。

MAE 是——这意味着预测平均降低 1.7% 

白色圆圈包含目标关键点,实心圆圈包含我们的预测。在大多数情况下,它们非常接近(重叠)。

改进

我们有一些改进的想法,但我们还没有时间实施:

  1. 目前,单个模型正在估计两个关键点。为每个关键点训练一个特定的模型可能会更好。这还有一个额外的好处,可以稍后添加新的关键点,而无需重新训练完整的模型。

  2. 另一个想法是考虑照片的角度。例如,将其添加为密集层的输入,可能会争辩说,角度会改变任务的性质,因此提供此信息可能有助于网络。按照这种思路,为每个角度训练一个单独的网络也可能是有益的。

结论

通过这项研究,我们证明了使用 CNN 检测花卉图像中的关键点的可行性。所使用的方法也可能适用于其他领域的后处理任务,例如产品摄影。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于深度学习的花卉图像关键点检测 的相关文章

随机推荐

  • 怎么制作睡袋rust_unturned睡袋怎么做 unturned睡袋合成方法介绍

    自从上个月初unturned游戏发布以来 受到了大批网友的追逐 甚至成为了steam上排名第三的游戏 在游戏中玩家虽然能够无限重生 但是重生的地方却是随机的 有时甚至会随机重生到僵尸群中 而unturned睡袋却可以帮助玩家固定重生点 那u
  • springboot点餐微信小程序系统毕业设计源码221144

    springboot点餐微信小程序 摘 要 点餐微信小程序采用B S模式 采用JAVA语言 springboot框架 mysql数据库 小程序框架uniapp等开工具 促进了点餐微信小程序的业务发展 与传统线下点餐相比 点餐维信小程序不但节
  • Python将图像转成像素风,圆圈、线条、波浪、十字绣、乐高积木、我的世界积木、回形针、字母......

    Python将图像转成像素风 圆圈 线条 波浪 十字绣 乐高积木 我的世界积木 回形针 字母 1 效果图 2 原理 3 源码 参考 1 效果图 回形针效果图如下 十字绣效果图如下 水平线效果图如下 垂直线效果图如下
  • 云原生之使用docker部署NTP时间服务器

    云原生之使用docker部署NTP时间服务器 一 chrony介绍 二 容器镜像介绍 三 检查本地docker状态 四 下载ntp镜像 五 部署ntp服务器 1 创建ntp容器 2 查看ntp容器状态 六 检查ntp服务器的时间源 七 客户
  • 增量训练lightgbm模型,深度学习模型

    1 机器学习 增量训练方法 机器学习 增量训练方法 知乎 包含 sklearn lightgbm增量训练方法 2 深度学习模型增量训练 增量训练主要面临的问题 当增量训练时 主要解决的是新增加的训练样本中的新词问题 如果对新增加的新词不做i
  • Vue之父子组件通信(一)

    1 父组件向子组件传递数据 父组件向子组件传值 1 父组件调用子组件的时候 绑定动态属性
  • ST-Bluenrg-lp芯片编程因为地址重叠导致常量值被更改

    所遇问题 定义的结构体 用于限制范围大小 类似于 struct test SysParaMax test1 5000 test2 5000 test3 100 test4 600 struct test SysParaMin test1 0
  • intellij idea tomcat permGen space

    vmsettings options are Xms128m Xmx700m XX MaxPermSize 250m XX ReservedCodeCacheSize 64m tomcat are Xms64m Xmx256m
  • cin读取数字时遇到字符的情况

    cin读取数字时遇到字符 当定义一个int变量 用cin输入时 如果输入的是一个字符 会发生以下4中情况 1 n的值变成0 2 不匹配的输入被留在输入流中 3 cin对象的一个错误标记被设置 即cin fail 返回true 4 对cin的
  • SpringBoot项目用 jQuery webcam plugin实现调用摄像头拍照并保存图片

    参考博客 http www voidcn com article p oigngyvb kv html 自定义样式
  • TestNG测试用例

    使用TestNG的第一个测试用例 要遵循的步骤 1 按Ctrl N 在TestNG类别下选择 TestNG Class 然后单击Next 要么 右键单击Test Case文件夹 转到TestNG并选择 TestNG Class 2 如果您的
  • 考研/面试 数据结构大题必会代码(理解+记忆,实现使用C++,STL库)

    文章目录 一 线性表 1 逆置顺序表所有元素 2 删除线性链表中数据域为 item 的所有结点 3 逆转线性链表 递归 快速解题 非递归 4 复制线性链表 递归 5 将两个按值有序排列的非空线性链表合并为一个按值有序的线性链表 二 树 1
  • 门面模式

    门面模式是对象的结构模式 外部与一个子系统的通信必须通过一个统一的门面对象进行 门面模式提供一个高层次的接口 使得子系统更易于使用 门面模式有三个角色组成 1 门面角色 facade 这是门面模式的核心 它被客户角色调用 因此它熟悉子系统的
  • DVWA靶场--文件上传/包含(low-high).

    文件上传 low 没有做任何过滤直接上传即可 medium 源码 uploaded type image jpeg uploaded type image png 这段源码可以看出来他对上传到content type值做了过滤 只允许上传这
  • 分享如何建立一个完美的 Python 项目

    当开始一个新的 Python 项目时 大家很容易一头扎进去就开始编码 其实花一点时间选择优秀的库 将为以后的开发节省大量时间 并带来更快乐的编码体验 在理想世界中 所有开发人员的关系是相互依赖和关联的 协作开发 代码要有完美的格式 没有低级
  • 小程序`canvasToTempFilePath:fail:cearte bitmap failed?`

    这个方法的思路来源链接 微信开放社区 主要是通过延迟 重试 以及画质来解决手机性能等问题导致的canvasToImageFile故障 代码仅供参考 欢迎大家提供更多方法或思路 或指出代码异常 谢谢 下面是我用到项目中的代码片段 海报信息 P
  • PID算法的理论分析

    PID算法的理论和应用 PID算法基本原理 PID算法的离散化 PID算法伪代码 PID算法C 实现 pid cpp pid h pid example cpp Python代码 仿真结果 PID算法基本原理 PID算法是控制行业最经典 最
  • webbench剖析

    webbench 其为linux上一款web性能压力测试工具 它最多可以模拟3万个并发连接数来测试服务器压力 其原理为fork多个子进程 每个子进程都循环做web访问测试 子进程将访问的结果通过管道告诉父进程 父进程做最终结果统计 其主要原
  • javaweb 之 JDBC 详解 数据库连接池

    JDBC简介 JDBC 就是使用Java语言操作关系型数据库的一套API 全称 Java DataBase Connectivity Java 数据库连接 JDBC 本质 官方 sun公司 定义的一套操作所有关系型数据库的规则 即接口 各个
  • 基于深度学习的花卉图像关键点检测

    点击上方 小白学视觉 选择加 星标 或 置顶 重磅干货 第一时间送达 在本文中 我们描述了我们如何使用卷积神经网络 CNN 来估计花卉图像中关键点的位置 并且在 3D 模型上渲染这些图像上茎和花的位置等关键点 为了能够与真实花束的照片对比