Keras 模型训练良好,但预测的值相同

2023-12-12

让我们尝试制作MobileNet V. 2在嘈杂的图像上找到一条亮带。是的,使用深度卷积网络来实现这样的策略有点过分了,但最初它的目的就像烟雾测试一样,以确保模型有效。我们将使用合成数据对其进行训练:

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

SHAPE = (32, 320, 1)
def gen_sample():
    while True:
        data = np.random.normal(0, 1, SHAPE)
        i = np.random.randint(0, SHAPE[1]-8)
        data[:,i:i+8,:] += 4
        yield data.astype(np.float32), np.float32(i)

ds = tf.data.Dataset.from_generator(gen_sample, output_signature=(
    tf.TensorSpec(shape=SHAPE, dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.float32))).batch(100)

d, i = next(gen_sample())
plt.figure()
plt.imshow(d)
plt.show()

A sample image

现在我们构建并训练一个模型:

model = tf.keras.models.Sequential([
    tf.keras.applications.MobileNetV2(
        input_shape=SHAPE, include_top=False, weights=None, alpha=0.5),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.01, decay_steps=1000, decay_rate=0.9)),
    loss='mean_squared_error')
history = model.fit(ds, steps_per_epoch=10, epochs=40)

我们使用生成的数据,所以我们不需要验证集,不是吗?所以我们可以观察损失如何减少。而且它确实减少得很好:

Epoch 1/40
10/10 [==============================] - 27s 2s/step - loss: 15054.8417
Epoch 2/40
10/10 [==============================] - 23s 2s/step - loss: 193.9126
Epoch 3/40
10/10 [==============================] - 24s 2s/step - loss: 76.9586
Epoch 4/40
10/10 [==============================] - 25s 2s/step - loss: 68.8521
...
Epoch 37/40
10/10 [==============================] - 20s 2s/step - loss: 4.5258
Epoch 38/40
10/10 [==============================] - 20s 2s/step - loss: 22.1212
Epoch 39/40
10/10 [==============================] - 20s 2s/step - loss: 28.4854
Epoch 40/40
10/10 [==============================] - 20s 2s/step - loss: 18.0123

训练碰巧没有达到最佳结果,但它仍然应该是合理的:答案应该在真实值±8左右。我们来测试一下:

d, i = list(ds.take(1))[0]
model.evaluate(d, i)
np.stack((model.predict(d).ravel(), i.numpy()), 1)[:10,]
4/4 [==============================] - 0s 32ms/step - loss: 16955.7871
array([[ 66.84666 , 222.      ],
       [ 66.846664,  46.      ],
       [ 66.846664,  71.      ],
       [ 66.84668 , 268.      ],
       [ 66.846664,  86.      ],
       [ 66.84668 , 121.      ],
       [ 66.846664, 301.      ],
       [ 66.84667 , 106.      ],
       [ 66.84665 , 138.      ],
       [ 66.84667 ,  95.      ]], dtype=float32)

哇!如此巨大的评价损失从何而来?为什么模型不断预测相同的愚蠢值?训练期间一切都那么美好!

事实上,在一天左右的时间里,我意识到发生了什么事,但我向其他人提供了解决这个谜题并赚取一些积分的可能性。


问题在于,在训练模式下合理运行的网络无法在推理模式下运行。可能是什么原因?基本上有两种层类型在两种模式下以不同的方式工作:丢失和批量归一化。在MobileNet V. 2,我们只有批量归一化,所以让我们考虑一下它是如何工作的。

在训练模式中,BN 层计算批次均值和方差,并使用这些批次值对数据进行归一化。同时,它会记住均值和方差作为移动平均值,并用一个系数进行加权momentum.

moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
moving_var = moving_var * momentum + var(batch) * (1 - momentum)

确实,这momentum是一个重要的超参数,特别是当真实的批量统计数据远离初始值时。假设初始方差值为1.0,动量为0.99(这是默认值),真实的数据方差是0.1。比 10% 的误差 (var < 0.11)可以在447批次后实现。

现在问题的根本原因是:MobileNet所有众多的 BN 层都有momentum=0.999,这意味着需要 4497 个批处理步骤才能达到相同的 10% 误差!当您小批量训练 ImageNet 等非常大的异构数据集时,这是 100% 合理的超参数选择。但在这个玩具示例中,结果是 BN 层无法记住 400 个批次期间的真实数据统计信息,并且在推理过程中使用完全错误的值!

修复方法非常简单:只需更改之前的动量即可model.compile:

for layer in model.layers[0].layers:
    if type(layer) is tf.keras.layers.BatchNormalization:
        layer.momentum = 0.9
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 模型训练良好,但预测的值相同 的相关文章

  • Python、Tkinter、更改标签颜色

    有没有一种简单的方法来更改按钮中文本的颜色 I use button text input text here 更改按下后按钮文本的内容 是否存在类似的颜色变化 button color red Use the foreground设置按钮
  • InterfaceError:连接已关闭(使用 django + celery + Scrapy)

    当我在 Celery 任务中使用 Scrapy 解析函数 有时可能需要 10 分钟 时 我得到了这个信息 我用 姜戈 1 6 5 django celery 3 1 16 芹菜 3 1 16 psycopg2 2 5 5 我也使用了psyc
  • 将字符串转换为带有毫秒和时区的日期时间 - Python

    我有以下 python 片段 from datetime import datetime timestamp 05 Jan 2015 17 47 59 000 0800 datetime object datetime strptime t
  • 如何生成给定范围内的回文数列表?

    假设范围是 1 X 120 这是我尝试过的 gt gt gt def isPalindrome s check if a number is a Palindrome s str s return s s 1 gt gt gt def ge
  • 导入错误:没有名为 _ssl 的模块

    带 Python 2 7 的 Ubuntu Maverick 我不知道如何解决以下导入错误 gt gt gt import ssl Traceback most recent call last File
  • 更改自动插入 tkinter 小部件的文本颜色

    我有一个文本框小部件 其中插入了三条消息 一条是开始消息 一条是结束消息 一条是在 单位 被摧毁时发出警报的消息 我希望开始和结束消息是黑色的 但被毁坏的消息 参见我在代码中评论的位置 插入小部件时颜色为红色 我不太确定如何去做这件事 我看
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • 如何等到 Excel 计算公式后再继续 win32com

    我有一个 win32com Python 脚本 它将多个 Excel 文件合并到电子表格中并将其另存为 PDF 现在的工作原理是输出几乎都是 NAME 因为文件是在计算 Excel 文件内容之前输出的 这可能需要一分钟 如何强制工作簿计算值
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • python pandas 中的双端队列

    我正在使用Python的deque 实现一个简单的循环缓冲区 from collections import deque import numpy as np test sequence np array range 100 2 resha
  • HTTPS 代理不适用于 Python 的 requests 模块

    我对 Python 还很陌生 我一直在使用他们的 requests 模块作为 PHP 的 cURL 库的替代品 我的代码如下 import requests import json import os import urllib impor
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • Python 3 中“map”类型的对象没有 len()

    我在使用 Python 3 时遇到问题 我得到了 Python 2 7 代码 目前我正在尝试更新它 我收到错误 类型错误 map 类型的对象没有 len 在这部分 str len seed candidates 在我像这样初始化它之前 se
  • Nuitka 未使用 nuitka --recurse-all hello.py [错误] 编译 exe

    我正在尝试通过 nuitka 创建一个简单的 exe 这样我就可以在我的笔记本电脑上运行它 而无需安装 Python 我在 Windows 10 上并使用 Anaconda Python 3 我输入 nuitka recurse all h
  • glpk.LPX 向后兼容性?

    较新版本的glpk没有LPXapi 旧包需要它 我如何使用旧包 例如COBRA http opencobra sourceforge net openCOBRA Welcome html 与较新版本的glpk 注意COBRA适用于 MATL
  • 在 Python 类中动态定义实例字段

    我是 Python 新手 主要从事 Java 编程 我目前正在思考Python中的类是如何实例化的 我明白那个 init 就像Java中的构造函数 然而 有时 python 类没有 init 方法 在这种情况下我假设有一个默认构造函数 就像
  • 您可以在 Python 类型注释中指定方差吗?

    你能发现下面代码中的错误吗 米皮不能 from typing import Dict Any def add items d Dict str Any gt None d foo 5 d Dict str str add items d f
  • Spark.read 在 Databricks 中给出 KrbException

    我正在尝试从 databricks 笔记本连接到 SQL 数据库 以下是我的代码 jdbcDF spark read format com microsoft sqlserver jdbc spark option url jdbc sql
  • Python 分析:“‘select.poll’对象的‘poll’方法”是什么?

    我已经使用 python 分析了我的 python 代码cProfile模块并得到以下结果 ncalls tottime percall cumtime percall filename lineno function 13937860 9
  • Pandas 与 Numpy 数据帧

    看这几行代码 df2 df copy df2 1 df 1 df 1 values 1 df2 ix 0 0 我们的教练说我们需要使用 values属性来访问底层的 numpy 数组 否则我们的代码将无法工作 我知道 pandas Data

随机推荐

  • 如何隐藏从 Python 调度的 COM 对象

    我在 Python 中使用 COM 并且希望该对象在后台隐藏运行 使用 Excel 我会 Import win32com client Excel win32com client Dispatch Excel Application Exc
  • &货币更改为¤cy=GBP

    我有一个非常奇怪的问题 我生成的要在电子邮件中发送的查询字符串正在以某种方式更改 我编写的一个旧应用程序根据数据库中的各种参数创建一个 URL dim wpret as string a target blank href a instId
  • 根据javascript中的选择选项显示/隐藏div

    上网搜了一下 学会了如何做到这一点 实施了它 但这不起作用 我想在选择学生时显示 div 学生 在选择教师时显示 div 老师 这是 jsp 文件的一部分 HTML 代码 table tr td td tr table
  • strip_tags 足以从字符串中删除 HTML 吗?

    站点用户可以在站点上注册 并且在注册期间他可以提供名称 我希望这个名称是一个有效的名称 并且不含任何 HTML 和其他时髦字符 strip tags 足够吗 我发现没有单一的功能可以防止用户输入白痴 最好将几种混合在一起 val trim
  • 用 Python 读取 PowerPoint 表格?

    我正在使用 python pptx 模块自动更新 powerpoint 文件中的值 我可以使用以下代码提取文件中的所有文本 from pptx import Presentation prs Presentation path to pre
  • 如何很好地将qint64“转换”为QProgressBar的int

    我正在使用 QFtp 是的 我知道 并且一切正常 使用他们自己的示例中的代码作为指导 http doc qt io archives qt 4 7 network qftp ftpwindow cpp html 我遇到的唯一问题是发送 或接
  • 如何执行 SQL 表中列出的 SQL Server 代理作业

    我试图将所有 SQL Server 代理作业存储在表名称中 并希望根据它们的加载频率来执行它们 CREATE TABLE Maintainance SQLJobName varchar 100 SQL Job Name which need
  • 从 Firebase 函数将数据返回到 Android [重复]

    这个问题在这里已经有答案了 我正在尝试做的事情 只需从 Firebase Cloud Function 返回数据即可 该函数用于在支付网关的服务器中创建支付订单 我所需的有关订单详细信息的数据位于function err data 见下文
  • 像松弛评论框反应原生的动画

    我正在开发一个评论框 在向上滑动操作时将其扩展到设备的高度 并在向下滑动操作时返回到其原始高度 但我无法向其中添加动画 因为该功能无法按照我想要的方式工作 作为参考 我们可以讨论松弛评论框动画 我的代码如下 code 小吃链接 https
  • split(" +") 和 split(" ") 不同

    我想消除字符串中的真空 String input java example java aaa bbb String temp input trim split 结果是 java 示例 javaaaabbb 但我想要的结果是 java示例 j
  • Javascript非对称加密和认证

    这里的一些人正在开发一个应用程序 其中包含一些可通过登录访问的 安全区域 在过去 登录表单和后续的 安全 页面都是通过 http 传输的纯文本 因为它是一个用于访问的应用程序 在几乎不可能使用 SSL 的共享服务器上使用 例如 WordPr
  • setState() 位于 componentDidUpdate() 内部

    我正在编写一个脚本 该脚本根据下拉菜单的高度和屏幕上输入的位置将下拉菜单移动到输入下方或上方 我还想根据其方向将修饰符设置为下拉菜单 但使用setState里面的componentDidUpdate创建无限循环 这是显而易见的 我找到了使用
  • Java中如何转义引号

    我被这个问题困住了 我有一个ResultSet写入 html 报告 这ResultSetis writer write td a href rsevidencia getString Evidencia a a a td 但该链接不起作用
  • React Native 要求图像与变量中断

    为什么这条线工作时没有错误 var gicon species ii color 0 require assets gLight jpg require assets nLight png 而这一行会抛出错误 which light gLi
  • 如何设置PHP下载文件到特定目录?

    我正在寻找有关此问题的一些一般指导 我创建了一个使用 cURL 下载 csv 文件的 PHP 脚本 目前 当我运行脚本时 它会将文件下载到我的计算机 我想修改下载路径以将其路由到我的网络主机上的目录 有没有什么简单的方法可以用 PHP 来做
  • UIImage 的高质量缩放

    我需要缩放来自 iPhone 应用程序中视图层的图像的分辨率 显而易见的方法是在 UIGraphicsBeginImageContextWithOptions 中指定比例因子 但只要比例因子不是 1 0 图像质量就会受到影响 远远超出像素损
  • 将div变成链接

    我有一个 div 阻止一些我不想改变的奇特的视觉内容 我想让它成为一个可点击的链接 我正在寻找类似的东西 a href div div a 但这是有效的 XHTML 1 1 来到这里是希望找到一个更好的解决方案 但我不喜欢这里提供的任何解决
  • 使用存储在一个固定(流)文档中的 VisualBrush 进行及时控制的快照

    我需要及时拍摄 Control 的快照并将它们存储在一个固定文档中 问题是 VisualBrush 在某种程度上是 懒惰的 并且不会通过将其添加到文档来评估自身 当我最终创建文档时 所有页面都包含相同 最后 的控制状态 虽然 VisualB
  • Jasper iReport 自定义日期和自定义时间

    在 Excel 中 我有一个日期格式 yyyy MM dd hh mm 和一个时间格式 hh mm 我将其设置为 iReport 数据库的源 对于日期 我在 iReport 中设置了自定义日期格式 与 Excel 中相同 然后我设置类 ja
  • Keras 模型训练良好,但预测的值相同

    让我们尝试制作MobileNet V 2在嘈杂的图像上找到一条亮带 是的 使用深度卷积网络来实现这样的策略有点过分了 但最初它的目的就像烟雾测试一样 以确保模型有效 我们将使用合成数据对其进行训练 import numpy as np im