MobileNetV2 的 Keras 和 TensorFlow Hub 版本之间的差异

2023-12-03

我正在研究一种迁移学习方法,并且在使用 MobileNetV2 时得到了非常不同的结果keras.applications以及 TensorFlow Hub 上提供的一个。这对我来说似乎很奇怪,因为两个版本都声称here and here从同一检查点提取它们的权重mobilenet_v2_1.0_224。 这是如何重现差异的,你可以找到 Colab Notebookhere:

!pip install tensorflow-gpu==2.1.0
import tensorflow as tf
import numpy as np
import tensorflow_hub as hub
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

def create_model_keras():
  image_input = tf.keras.Input(shape=(224, 224, 3))
  out = MobileNetV2(input_shape=(224, 224, 3),
                  include_top=True)(image_input)
  model = tf.keras.models.Model(inputs=image_input, outputs=out)
  model.compile(optimizer='adam', loss=["categorical_crossentropy"])
  return model

def create_model_tf():
  image_input = tf.keras.Input(shape=(224, 224 ,3))
  out = hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4",
                      input_shape=(224, 224, 3))(image_input)
  model = tf.keras.models.Model(inputs=image_input, outputs=out)
  model.compile(optimizer='adam', loss=["categorical_crossentropy"])
  return model

当我尝试对随机批次进行预测时,结果不相等:

keras_model = create_model_keras()
tf_model = create_model_tf()
np.random.seed(42)
data = np.random.rand(32,224,224,3)
out_keras = keras_model.predict_on_batch(data)
out_tf = tf_model.predict_on_batch(data)
np.array_equal(out_keras, out_tf)

版本的输出来自keras.applications总和为 1,但 TensorFlow Hub 的版本不是。而且两个版本的形状也不同:TensorFlow Hub 有 1001 个标签,keras.applications有 1000 个。

np.sum(out_keras[0]), np.sum(out_tf[0])

prints (1.0000001, -14.166359)

造成这些差异的原因是什么?我错过了什么吗?

编辑 2020年2月18日

正如 Szymon Maszke 指出的,TFHub 版本返回 logits。这就是为什么我添加了一个 Softmax 层create_model_tf如下:out = tf.keras.layers.Softmax()(x)

arnoegw提到TfHub版本需要将图像标准化为[0,1],而keras版本需要标准化为[-1,1]。当我对测试图像使用以下预处理时:

from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
img = tf.keras.preprocessing.image.load_img("/content/panda.jpeg", target_size=(224,224))
img = tf.keras.preprocessing.image.img_to_array(img)
img = preprocess_input(img)
img = tf.io.read_file("/content/panda.jpeg")
img = tf.image.decode_jpeg(img)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, (224,224))

两者都正确预测相同的标签,并且以下条件为真:np.allclose(out_keras, out_tf[:,1:], rtol=0.8)

编辑 2 2020 年 2 月 18 日在我写之前,不可能将格式相互转换。这是由一个错误引起的。


有几个已记录的差异:

  • 正如 Szymon 所说,TF Hub 版本返回 logits(在将其转换为概率的 softmax 函数之前),这是一种常见的做法,因为可以根据 logits 计算出具有更高数值稳定性的交叉熵损失。

  • TF Hub 模型假设 float32 输入在 [0,1] 范围内,这是您得到的tf.image.decode_jpeg(...)其次是tf.image.convert_image_dtype(..., tf.float32)。 Keras 代码使用特定于模型的范围(可能是 [-1,+1])。

  • TF Hub 模型在返回其所有 1001 个输出类时更完整地反映了原始 SLIM 检查点。正如文档中链接的 ImageNetLabels.txt 中所述,添加的类 0 是“背景”(又名“东西”)。这就是对象检测用来指示图像背景而不是任何已知类别的对象的方法。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MobileNetV2 的 Keras 和 TensorFlow Hub 版本之间的差异 的相关文章

随机推荐

  • 在 PHP 中提取字符串的特定部分

    我只是想知道在 PHP 中提取动态字符串的特定部分最简单 最有效的方法是什么 例如 在此字符串中 http www dailymotion com video xclep1 school gyrls something like a par
  • Android Firebase云功能通知

    我已成功设置 firebase 云功能来向主题发送通知 问题是它发送给包括发件人在内的所有用户 我如何设置我的云功能 以便它不向发件人显示通知 请帮忙 以下是我如何发送到主题 exports sendNotesNotification fu
  • 父级上的 CKEditor“溢出:滚动”导致工具栏冻结在初始位置

    当您使用以下命令将 CKEditor 添加到 div 内的 div 时 overflow scroll 滚动父 div 时工具栏不会移动 div div This is the ckedito div div 可以在这里找到一个例子 htt
  • 当我导入客户端库时,为什么会出现 ReferenceError: self is not Defined ?

    试图创建一个xterm反应组件Next js我陷入了困境 因为我无法克服以前从未收到过的错误消息 我正在尝试导入一个名为的 npm 客户端模块xterm 但是如果我添加导入行 应用程序就会崩溃 import Terminal from xt
  • 正则表达式按空格分割但不转义空格

    我想按标准空白进行分割 但没有转义空格 例如 使用字符串 my name is max 单引号所以 是字面意思 我想要得到 my name is max 我试过这个正则表达式 s 但结果是这样的 gt m name is max 这很接近
  • 如何在 Dartlang 中检索元数据?

    Dartlang教程介绍package metahttps www dartlang org docs dart up and running contents ch02 html ch02 metadata DartEditor 识别元数
  • 从字符串中提取电话号码

    我正在尝试从给定的字符串中提取java中的电话号码 即电话号码可以位于字符串中的任何位置 例如 bla bla TELEPHONE NUMBER bla bla 现在我想在另一个字符串中提取这个电话号码 在使用时 matcher match
  • 如何将保存的 localStorage Web 数据传递到 php 脚本?

    好吧 所以我在尝试找出如何将我保存在 localStorage 中的一些数据传递到我编写的 php 脚本时遇到了一些问题 这样我就可以将其发送到服务器上的数据库 我之前确实找到了一些代码 https developer mozilla or
  • 发送 Outlook 日历邀请 PHP

    该代码的目标是使用 PHP 发送约会和阻止人员日历 我这里有两页 测试 php
  • 通过缓存电子表格值提高脚本性能

    我正在尝试使用 Google Apps 脚本开发一个网络应用程序 将其嵌入到 Google 站点中 该站点仅显示 Google 表格的内容并使用一些简单的参数对其进行过滤 至少目前是这样 稍后我可能会添加更多功能 我得到了一个功能齐全的应用
  • 将密码重置发送到其他电子邮件 - Devise

    我正在使用 Ruby on Rails 5 和 devise 我需要将密码重置电子邮件发送到与我的用户表中存储的电子邮件不同的电子邮件 如何才能实现这一目标 请注意 这是非常不推荐的实现方式 它不在最佳实践的范围内 它又脏又脆弱 但如果你真
  • Apple 文件系统从照片库读取的权限

    我的 ios 应用程序中有一个 UIWebView 它将响应式网站加载到我的 webview 中 在 asp net 中开发 网站有一个按钮用于从设备照片库中选择视频 另一个按钮用于上传视频 在 ios 版本 10 2 之前 它可以成功地将
  • 在帆和水线中混合使用 AND 和 OR 子句

    如何在 Sailsjs 及其 ORM Waterline 中使用 OR 和 AND 子句 例如我有一张书表 book name author free public Book A Author 1 false true Book B Aut
  • 错误标记主机:等待条件超时 [kubernetes]

    我刚刚开始学习 Kubernetes 我已经通过 Kubernetes YUM 存储库安装了 CentOS 7 5 并禁用了 SELinux 的 kubectl kubeadm 和 kubelet 然而 当我想开始一个kubeadm ini
  • 撇号 cms - 自定义小部件中富文本的内联编辑?

    在某些情况下 我无法将富文本的内联编辑保存回数据库 请耐心等待 这里将粘贴一些代码 因为这是我描述我正在做的事情的唯一方式 我的项目中有两种自定义小部件 一种只有一个小部件实例 通常在lib modules目录 article widget
  • 依赖注入类型选择

    最近我遇到一个问题 我必须根据参数选择类型 例如 用于发送通知的类 应根据输入参数选择正确的渠道 电子邮件 短信等 我看起来像这样 public class NotificationManager IEmail email ISms sms
  • Google URLShortener API 返回 ipRefererBlocked

    我正在尝试将 Google URL 缩短 API 与 PHP 结合使用 apiKey ABC url http www stackoverflow com postData array longUrl gt url jsonData jso
  • 正则表达式匹配除空格之外的单个字符

    我需要匹配一个不是空格的单个字符 但我不知道如何使用正则表达式来做到这一点 以下应该足够了 如果您想将其扩展到除空白之外的任何内容 换行符 制表符 空格 硬空格 s or S Note this is a CAPITAL S
  • 将数据从操作传递到另一个操作

    如何通过 RedirectAction 方法将模型从 GetDate 操作传递到另一个 ProcessP 操作 这是源代码 HttpPost public ActionResult GetDate FormCollection values
  • MobileNetV2 的 Keras 和 TensorFlow Hub 版本之间的差异

    我正在研究一种迁移学习方法 并且在使用 MobileNetV2 时得到了非常不同的结果keras applications以及 TensorFlow Hub 上提供的一个 这对我来说似乎很奇怪 因为两个版本都声称here and here从