Tensorflow unsorted_segment_sum 维度

2024-03-13

我正在使用tf.unsorted_segment_sumTensorFlow 的方法,当我作为数据给出的张量只有一行时,它工作得很好。例如:

tf.unsorted_segment_sum(tf.constant([0.2, 0.1, 0.5, 0.7, 0.8]),
                        tf.constant([0, 0, 1, 2, 2]), 3)

给出正确的结果:

array([ 0.3,  0.5 , 1.5 ], dtype=float32)

问题是,如果我使用多行张量,如何获得每行的结果?例如,如果我尝试使用两行张量:

tf.unsorted_segment_sum(tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                                     [0.2, 0.2, 0.5, 0.7, 0.8]]),
                        tf.constant([[0, 0, 1, 2, 2],
                                     [0, 0, 1, 2, 2]]), 3)

我期望的结果是:

array([ [ 0.3,  0.5 , 1.5 ], [ 0.4, 0.5, 1.5 ] ], dtype=float32)

但我得到的是:

array([ 0.7,  1. ,  3. ], dtype=float32)

我想知道是否有人知道如何在不使用 for 循环的情况下获取每一行的结果?

提前致谢


EDIT:

虽然下面的解决方案可能涵盖一些其他奇怪的用途,但只需转置数据就可以更轻松地解决此问题。事实证明,尽管tf.unsorted_segment_sum没有axis参数,它只能沿一个轴工作,只要它是第一个轴。所以你可以这样做:

import tensorflow as tf

with tf.Session() as sess:
    data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                        [0.2, 0.2, 0.5, 0.7, 0.8]])
    idx = tf.constant([0, 0, 1, 2, 2])
    result = tf.transpose(tf.unsorted_segment_sum(tf.transpose(data), idx, 3))
    print(sess.run(result))

Output:

[[ 0.30000001  0.5         1.5       ]
 [ 0.40000001  0.5         1.5       ]]

原帖:

tf.unsorted_segment_sum不支持在单轴上工作。最简单的解决方案是将操作应用于每一行,然后将它们连接回去:

data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                    [0.2, 0.2, 0.5, 0.7, 0.8]])
segment_ids = tf.constant([[0, 0, 1, 2, 2],
                           [0, 0, 1, 2, 2]])
num_segments = 3
rows = []
for data_i, ids_i in zip(data, segment_ids):
    rows.append(tf.unsorted_segment_sum(data_i, ids_i))
result = tf.stack(rows, axis=0)

但是,这有缺点:1)它仅适用于静态形状的张量(即,您需要有固定数量的行),2)它可能效率不高。第一个可以通过使用绕过tf.while_loop,但是,这会很复杂,而且需要你将行逐一连接起来,效率非常低。另外,您已经说过要避免循环。

更好的选择是为每一行使用不同的 id。例如,您可以添加到中的每个值segment_id就像是num_segments * row_index,因此您可以保证每一行都有自己的一组 id:

num_rows = tf.shape(segment_ids)[0]
rows_idx = tf.range(num_rows)
segment_ids_per_row = segment_ids + num_segments * tf.expand_dims(rows_idx, axis=1)

然后你可以应用操作和重塑来获得你想要的张量:

seg_sums = tf.unsorted_segment_sum(data, segment_ids_per_row,
                                   num_segments * num_rows)
result = tf.reshape(seg_sums, [-1, num_segments])

Output:

array([[ 0.3, 0.5, 1.5 ],
       [ 0.4, 0.5, 1.5 ]], dtype=float32)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow unsorted_segment_sum 维度 的相关文章

随机推荐

  • 如何在 Android 中使用文本视图显示颠倒的文本?

    如何在 Android 中使用文本视图显示颠倒的文本 就我而言 我有一个 2 人游戏 他们彼此面对面玩 我想向第二个面向他们的玩家展示测试 这是我在 AaronMs 建议后实施的解决方案 执行重写的类 bab foo UpsideDownT
  • Firebase 服务器时间戳将 iOS 翻倍

    ServerValue timestamp 回报 AnyHashable Any 如何将其转换为Double 这样我就可以创建一个带有时间戳的日期 这并不是 Firebase 时间戳的工作原理 它实际上所做的是将时间戳写入节点 但在写入之后
  • 如何验证 ZF2 中的复选框

    我已经阅读了许多针对 Zend Framework 缺乏默认复选框验证的解决方法 我最近开始使用 ZF2 但文档有点缺乏 有人可以演示如何使用 Zend 表单和验证机制验证复选框以确保其被选中吗 我正在为我的表单使用数组配置 使用 ZF 网
  • 安全组出口规则仅允许 ECR 请求

    当使用 ECR 存储用于 ECS 的容器映像时 EC2 实例 或 Fargate 服务 必须具有允许 通过公共互联网 访问特定于账户的存储库 URI 的安全组 许多组织都有严格的 IP 白名单规则 通常不允许为所有 IP 启用出站端口 44
  • 从命令行在 Hadoop 中检测压缩编解码器

    有没有简单的方法可以找出 Hadoop 中用于压缩文件的编解码器 我是否需要编写 Java 程序 或者将文件添加到 Hive 以便我可以使用describe formatted table 一种方法是在本地下载文件 使用hdfs dfs g
  • 具有接口的枚举类成员无法在内部找到方法

    我遇到了一个奇怪的问题 我不确定这是编译器问题还是我对接口枚举的理解 我正在使用 IntelliJ IDEA 12 构建一个 Android 项目 并且我有一个这样的类 public class ClassWithEnum private
  • Azure 服务总线序列化类型

    随着我们转向面向服务的体系结构 我们已开始研究使用 Windows Azure 服务总线来替代当前的队列 大部分文档都很清楚 但是我很难确定哪种类型的序列化BrokeredMessage当提供主体时使用 例如 假设我实例化了一个Broker
  • React:formik 表单,如何在回调函数内提交后使用状态

    我在用formik插件reactjs我想要useState表单提交后的变量 Both this and setState未定义 我无法实现它 有人可以帮我完成这件事吗 See screenshot below In JavaScript 默
  • android 延迟加载未在手机上显示图像或显示速度很慢

    我正在使用 JSON 来解析在线 xml 文档以及两种延迟图像加载的方法 以下是我的源代码 解释和我的问题 解释 方法一 使用AsyncTask和线imageLoader DisplayImage String jsonImageText
  • 安装chatterBot时出错

    每当我尝试使用命令安装 ChatterBot 时pip install ChatterBot它给出了这个错误 Retrying Retry total 0 connect None read None redirect None after
  • 扩展点或从 Liquid 模板访问 OpenApiDocument

    We have 规范扩展 https github com OAI OpenAPI Specification blob master versions 3 0 2 md specification extensions i e x isP
  • Git 子模块工作流程建议

    所以几天前我开始使用 Git 聚会已经很晚了 别骂 真正开始熟悉基本命令 想法和工作流程 然而 子模块确实让我大吃一惊 我正在尝试贡献代码FuelPHP http fuelphp com s GitHub https github com
  • Symfony2 中数据库测试的实践?如何隔离?

    目前测试与 Symfony2 数据库交互的最佳实践是什么 我有一个简单的 CRUD 设置 我想确保我的测试没问题 现在 我有 4 个测试 每一个测试都确保创建 更新 删除和列出操作正常发生 我的测试用例有两个神奇的方法 construct
  • 错误代码:1093。您无法在 FROM 子句中指定更新的目标表

    假设我有一个产品表 并且只有 2 个字段 id 和购买日期 我想删除 2019 年购买的最后一件产品 我尝试使用以下查询来做到这一点 DELETE FROM products WHERE id SELECT id FROM products
  • 加快 WMA(加权移动平均线)计算速度

    我正在尝试计算 15 天柱的指数移动平均线 但希望查看每个 结束 日 柱的 15 天柱 EMA 的 演变 所以 这意味着我有 15 天的柱线 当每天出现新数据时 我想使用新信息重新计算 EMA 实际上 我有 15 天的柱形图 然后 每天之后
  • 如何将图像复制到SD卡上的现有目录?

    我正在尝试使用以下代码复制图像文件 InputStream fileInputStream null OutputStream fileOutputStream null String inPath storage emulated 0 P
  • MFMessageComposeViewController 不显示相机图标

    当我手动调出 新消息 时 我会在文本编辑区域左侧看到一个相机图标 当我使用 MFMessageComposeViewController 时 它不会显示此图标 这意味着您无法插入图像 我知道这是可以做到的 因为那些创造了txtAgif ht
  • 动态尿路感染稳定吗?

    我的文件格式没有声明的 UTI 因此 Launch Services 已为其分配了动态 UTI dyn ah62d4rv4ge81g23wsmw1a5dbte 我无法控制这些文档的 UTI 我也想为该格式开发一个快速查看生成器 并且快速查看
  • SAML 2.0:如何配置断言消费者服务 URL

    我正在实现一个 SAML 2 0 服务提供商 它使用 Okta 作为身份提供商 我想配置断言消费者服务 ACS URL 以便我的服务提供商应用程序中的 SAML 2 0 反映回断言中 但是 我注意到 Okta 身份提供程序改为发送在 Okt
  • Tensorflow unsorted_segment_sum 维度

    我正在使用tf unsorted segment sumTensorFlow 的方法 当我作为数据给出的张量只有一行时 它工作得很好 例如 tf unsorted segment sum tf constant 0 2 0 1 0 5 0