TensorFlow 中 sigmoid 后跟交叉熵和 sigmoid_cross_entropy_with_logits 有什么区别?

2023-12-01

当尝试使用 sigmoid 激活函数获取交叉熵时,两者之间存在差异

  1. loss1 = -tf.reduce_sum(p*tf.log(q), 1)
  2. loss2 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=p, logits=logit_q),1)

但使用 softmax 激活函数时它们是相同的。

以下是示例代码:

import tensorflow as tf

sess2 = tf.InteractiveSession()
p = tf.placeholder(tf.float32, shape=[None, 5])
logit_q = tf.placeholder(tf.float32, shape=[None, 5])
q = tf.nn.sigmoid(logit_q)
sess.run(tf.global_variables_initializer())

feed_dict = {p: [[0, 0, 0, 1, 0], [1,0,0,0,0]], logit_q: [[0.2, 0.2, 0.2, 0.2, 0.2], [0.3, 0.3, 0.2, 0.1, 0.1]]}
loss1 = -tf.reduce_sum(p*tf.log(q),1).eval(feed_dict)
loss2 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=p, logits=logit_q),1).eval(feed_dict)

print(p.eval(feed_dict), "\n", q.eval(feed_dict))
print("\n",loss1, "\n", loss2)

你混淆了交叉熵binary and 多类问题。

多类交叉熵

您使用的公式是正确的,它直接对应于tf.nn.softmax_cross_entropy_with_logits:

-tf.reduce_sum(p * tf.log(q), axis=1)

p and q预计是 N 个类别的概率分布。特别地,N 可以是 2,如下例所示:

p = tf.placeholder(tf.float32, shape=[None, 2])
logit_q = tf.placeholder(tf.float32, shape=[None, 2])
q = tf.nn.softmax(logit_q)

feed_dict = {
  p: [[0, 1],
      [1, 0],
      [1, 0]],
  logit_q: [[0.2, 0.8],
            [0.7, 0.3],
            [0.5, 0.5]]
}

prob1 = -tf.reduce_sum(p * tf.log(q), axis=1)
prob2 = tf.nn.softmax_cross_entropy_with_logits(labels=p, logits=logit_q)
print(prob1.eval(feed_dict))  # [ 0.43748799  0.51301527  0.69314718]
print(prob2.eval(feed_dict))  # [ 0.43748799  0.51301527  0.69314718]

注意q正在计算tf.nn.softmax,即输出概率分布。所以它仍然是多类交叉熵公式,仅适用于N = 2。

二元交叉熵

这次正确的公式是

p * -tf.log(q) + (1 - p) * -tf.log(1 - q)

虽然从数学上来说它是多类情况的部分情况,meaning of p and q是不同的。在最简单的情况下,每个p and q是一个数字,对应于 A 类的概率。

重要的: 不要被常见的东西迷惑了p * -tf.log(q)部分和总和。以前的p以前是一个单热向量,现在它是一个数字,零或一。同样适用于q- 这是一个概率分布,现在是一个数字(概率)。

If p是一个向量,每个单独的分量被认为是一个独立二元分类. See 这个答案概述了张量流中 softmax 函数和 sigmoid 函数之间的区别。所以定义p = [0, 0, 0, 1, 0]并不意味着一个单一的向量,而是5个不同的特征,其中4个是关闭的,1个是打开的。定义q = [0.2, 0.2, 0.2, 0.2, 0.2]意味着 5 个特征中的每一个都有 20% 的概率。

这解释了使用sigmoid交叉熵之前的函数:其目标是将 logit 压缩为[0, 1]间隔。

上面的公式对于多个独立特征仍然成立,这正是tf.nn.sigmoid_cross_entropy_with_logits计算:

p = tf.placeholder(tf.float32, shape=[None, 5])
logit_q = tf.placeholder(tf.float32, shape=[None, 5])
q = tf.nn.sigmoid(logit_q)

feed_dict = {
  p: [[0, 0, 0, 1, 0],
      [1, 0, 0, 0, 0]],
  logit_q: [[0.2, 0.2, 0.2, 0.2, 0.2],
            [0.3, 0.3, 0.2, 0.1, 0.1]]
}

prob1 = -p * tf.log(q)
prob2 = p * -tf.log(q) + (1 - p) * -tf.log(1 - q)
prob3 = p * -tf.log(tf.sigmoid(logit_q)) + (1-p) * -tf.log(1-tf.sigmoid(logit_q))
prob4 = tf.nn.sigmoid_cross_entropy_with_logits(labels=p, logits=logit_q)
print(prob1.eval(feed_dict))
print(prob2.eval(feed_dict))
print(prob3.eval(feed_dict))
print(prob4.eval(feed_dict))

您应该看到最后三个张量相等,而prob1只是交叉熵的一部分,因此只有当p is 1:

[[ 0.          0.          0.          0.59813893  0.        ]
 [ 0.55435514  0.          0.          0.          0.        ]]
[[ 0.79813886  0.79813886  0.79813886  0.59813887  0.79813886]
 [ 0.5543552   0.85435522  0.79813886  0.74439669  0.74439669]]
[[ 0.7981388   0.7981388   0.7981388   0.59813893  0.7981388 ]
 [ 0.55435514  0.85435534  0.7981388   0.74439663  0.74439663]]
[[ 0.7981388   0.7981388   0.7981388   0.59813893  0.7981388 ]
 [ 0.55435514  0.85435534  0.7981388   0.74439663  0.74439663]]

现在应该清楚了,取总和-p * tf.log(q) along axis=1在这种情况下没有意义,尽管它在多类情况下是有效的公式。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow 中 sigmoid 后跟交叉熵和 sigmoid_cross_entropy_with_logits 有什么区别? 的相关文章

随机推荐

  • 如何判断 net/http 的 ResponseWriter.Write() 是否已被调用?

    假设我有一系列 net http 处理程序 并且早期的处理程序响应 HTTP 错误 http StatusInternalServerError 例如 如何在以下处理程序中检测到这一点 并避免向客户端发送额外的数据 或者这完全是错误的解决问
  • Anagrams - C 中的链式哈希和探测

    我的标题被编辑了 所以我想确保每个人都知道这是作业 问题只是优化程序 散列是我的想法 我正在优化一个 C 程序 该程序将彼此不同的单词组合在一起 然后将它们打印出来 目前的程序基本上是一个链表的链表 外部列表中的每个链接都是一组彼此不同的单
  • 当发生错误 1224: ERROR_USER_MAPPED_FILE 时?

    我想真正了解什么时候ERROR USER MAPPED FILE发生 所以我写了一些片段 重现错误 但这没有用 请帮我修复我的代码 流程一 HANDLE hFile CreateFile C test full exe GENERIC RE
  • 如何获取返回集中每行空值列的计数?

    我正在寻找一个查询 该查询将在当前查询的末尾返回一个额外的列 该列是返回集中包含空列的所有列的计数 例如 Col 1 Col 2 Col 3 A B 0 A NULL 1 NULL NULL 2 是否有一种简单的方法可以根据行值获取此返回集
  • 如何让 Capture::Tiny 在失败时打印 stderr 和 stdout?

    我正在尝试通过Capture Tiny在失败时获取命令的输出 usr bin env perl use strict use warnings use feature say use Carp confess use Capture Tin
  • http响应文本获取不完整的html

    我在 excel vba 中有一个代码 如下所示 可以获取网页源 html 该代码工作正常 但它获取的 html 不完整 当线webpageSource oHttp ResponseText执行后 变量webpageSource包含 DOC
  • 如何创建具有延迟和尝试限制的 RXjs RetryWhen

    我正在尝试进行 API 调用 使用 angular4 当失败时使用 retryWhen 重试 我希望它延迟 500 毫秒 然后重试 这可以通过以下代码来实现 loadSomething Observable
  • ggplot2在标记刻度线之间显示未标记刻度线

    我在直方图上显示小刻度线时遇到问题 我尝试过绘制未标记的主要刻度线的想法 但刻度线不会显示 我的代码非常麻烦 可能有一些多余的行 任何帮助 将不胜感激 ggplot data Shrimp1 aes Shrimp1 Carapace Len
  • 如何在模态视图解除时传递对象

    我以模态方式呈现 VC 然后在选择单元格并从原始 VC 调用方法时将其关闭 现在的问题是nav and routineTableViewControllerNSLog 为空 我如何展示 VC 模型 NSString selectedRow
  • 如何使用expect通过ssh连接到系统并更改主机系统的密码?

    我正在自动化以下过程 通过 ssh 连接到名为 alpha 的系统 用户名 alpha 的密码为 alpha 连接后 我想设置 root 密码 kickass 我连接的系统默认没有 root 密码 我编写了这个期望脚本来完成这项工作 但它不
  • MKMapView 中显示用户位置的不稳定行为

    我有一个MKMapView与MKUserTrackingBarButtonItem 用户的当前位置只能显示在Follow or FollowWithHeading模式 实现如下所示 void mapView MKMapView mapVie
  • JavaScript 中的检查可返回我是否在智能手机上? [复制]

    这个问题在这里已经有答案了 我想对我的一个 JavaScript 函数进行检查 确定我是否在智能手机上 然后根据结果决定是否运行该函数 在 JavaScript 中检测 检查智能手机 或一般手持设备 的最佳方法是什么 e g if user
  • 显示:块内显示:内联

    我想了解当 CSS 为的元素时会发生什么display block是 CSS 为的元素的 DOM 子元素display inline 这样块元素是内联元素的子元素 这种情况在匿名块盒CSS 2 1 规范部分 示例包括以下规则 body di
  • iOS:可以在 Google Plus 中发送或发布消息

    在 google plus 中 是否有一个 API 可以在 iOS 中向 Google Plus 发送消息或提交帖子 我已经尝试阅读 google 文档 但还没有看到任何可以做到这一点的内容 好的 我明白了 在他们的文档上 https de
  • 结构成员的概念检查

    检查特定结构成员是否验证给定概念的简单 惯用的方法是什么 我尝试了以下方法 但它不起作用 因为 T f 产量类型float include
  • 如何在 Froyo 中检测设备的准确方向?

    我试图暂时锁定 Android 设备的方向 大多数时候它会随着传感器的变化而变化 所以我想做的是弄清楚当前的方向 横向 反向横向 纵向 反向纵向 是什么 将方向更改为该方向 然后将其改回原来的方向 我知道我可以使用诸如 int 方向 thi
  • Python 3 异常处理抛出错误

    我上周开始学习 python 但我无法弄清楚这里出了什么问题 def add x y Adds 2 numbers and returns the result return x y def sub x y Subtracts 2 numb
  • PHP 中的详细正则表达式?

    在 php net 上搜索我找不到任何支持详细的正则表达式在 PHP 中 这是我不知道如何搜索它的错 还是php没有实现它的错 如果php缺少这个功能 除了将正则表达式分成更小的段之外 还有其他方法来注释正则表达式吗 您还可以在正则表达式中
  • 将大型文本 (xyz) 数据库拆分为 x 个相等的部分

    我想拆分一个大型文本数据库 约 1000 万行 我可以使用类似的命令 sed i e 4 s dB e 4 s Best unit Best Unit e 1 3 d cygdrive c Radio Mobile Output TRC T
  • TensorFlow 中 sigmoid 后跟交叉熵和 sigmoid_cross_entropy_with_logits 有什么区别?

    当尝试使用 sigmoid 激活函数获取交叉熵时 两者之间存在差异 loss1 tf reduce sum p tf log q 1 loss2 tf reduce sum tf nn sigmoid cross entropy with