合并张量流数据集批次

2024-05-20

请考虑下面的代码:

import tensorflow as tf
import numpy as np
 
simple_features = np.array([
         [1, 1, 1],
         [2, 2, 2],
         [3, 3, 3],
         [4, 4, 4],
         [5, 5, 5],

])
 
simple_labels = np.array([
         [-1, -1],
         [-2, -2],
         [-3, -3],
         [-4, -4],
         [-5, -5],

])
 

simple_features1 = np.array([
         [1, 4, 1],
         [2, 2, 2],
         [3, 3, 3],
         [6, 4, 4],
         [5, 4, 5],

])
 
simple_labels1 = np.array([
         [8, -7],
         [-2, -2],
         [-3, 7],
         [-4, 9],
         [-5, -5],

])

def print_dataset(ds):
    for inputs, targets in ds:
        print("---Batch---")
        print("Feature:", inputs.numpy())
        print("Label:", targets.numpy())
        print("")
        
ds1 = tf.keras.preprocessing.timeseries_dataset_from_array(simple_features, simple_labels, sequence_length=4, batch_size=1)
print_dataset(ds1)

ds2 = tf.keras.preprocessing.timeseries_dataset_from_array(simple_features1, simple_labels1, sequence_length=4, batch_size=1)
print_dataset(ds2)

上面的代码将创建特征和标签。我想按以下方式合并两个相应的批次。例如第一批ds1的显示方式如下:

---Batch---
Feature: [[[1 1 1]
  [2 2 2]
  [3 3 3]
  [4 4 4]]]
Label: [[-1 -1]]

...第一批 ds2 看起来像这样。

---Batch---
Feature: [[[1 4 1]
  [2 2 2]
  [3 3 3]
  [6 4 4]]]
Label: [[ 8 -7]]

第一批 ds1 和第一批 ds2 应该以这种方式合并,得到以下输出:

---Batch---
Feature: [[[1 1 1 1 4 1]
  [2 2 2 2 2 2]
  [3 3 3 3 3 3]
  [4 4 4 6 4 4 ]]]
Label: [[-1 -1 8 -7]]

您可以使用tf.concat连接两个数据集:

import tensorflow as tf
import numpy as np
 
simple_features = np.array([
         [1, 1, 1],
         [2, 2, 2],
         [3, 3, 3],
         [4, 4, 4],
         [5, 5, 5],
])
simple_labels = np.array([
         [-1, -1],
         [-2, -2],
         [-3, -3],
         [-4, -4],
         [-5, -5],
])
simple_features1 = np.array([
         [1, 4, 1],
         [2, 2, 2],
         [3, 3, 3],
         [6, 4, 4],
         [5, 4, 5],
])
simple_labels1 = np.array([
         [8, -7],
         [-2, -2],
         [-3, 7],
         [-4, 9],
         [-5, -5],
])

def print_dataset(ds):
    for inputs, targets in ds:
        print("---Batch---")
        print("Feature:", inputs.numpy())
        print("Label:", targets.numpy())
        print("")
        
ds1 = tf.keras.preprocessing.timeseries_dataset_from_array(simple_features, simple_labels, sequence_length=4, batch_size=1)
ds2 = tf.keras.preprocessing.timeseries_dataset_from_array(simple_features1, simple_labels1, sequence_length=4, batch_size=1)

def merge(data1, data2):
  x1, y1 = data1
  x2, y2 = data2
  return tf.concat([x1, x2], axis=-1), tf.concat([y1, y2], axis=-1)

dataset = tf.data.Dataset.zip((ds1, ds2)).map(merge)
print_dataset(dataset)
---Batch---
Feature: [[[1 1 1 1 4 1]
  [2 2 2 2 2 2]
  [3 3 3 3 3 3]
  [4 4 4 6 4 4]]]
Label: [[-1 -1  8 -7]]

---Batch---
Feature: [[[2 2 2 2 2 2]
  [3 3 3 3 3 3]
  [4 4 4 6 4 4]
  [5 5 5 5 4 5]]]
Label: [[-2 -2 -2 -2]]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

合并张量流数据集批次 的相关文章

随机推荐

  • CSS:水平滚动时背景不存在

    好的 我的背景设置如下 HTML div div CONTENT HERE div div CSS container background url image gif content width 800px margin auto 因此
  • 如何检测字符串中的非 ASCII 字符?

    如果我有一个 PHP 字符串 如何以有效的方式确定它是否至少包含一个非 ASCII 字符 我所说的非 ASCII 字符是指不属于该表的任何字符 http www asciitable com http www asciitable com
  • Flask 环境变量被忽略(FLASK_ENV 和 FLASK_APP)WINDOWS 10

    After setting the environment variables FLASK ENV and FLASK APP running flask run will give me this error 该代码片段显示了命令提示符
  • Pytest报告摘要显示错误信息

    我对 pytest 挂钩和插件相对较新 我无法弄清楚如何让我的 pytest 代码为我提供测试执行摘要以及失败原因 考虑代码 class Foo def init self val self val val def test compare
  • 如何模拟MyBatis映射器接口?

    我正在为我的 Jersey Rest API 编写单元测试 该 API 在后台使用 MyBatis 这是我的课程结构 休息服务 Path api public class HelloRestService Inject HelloBean
  • 确定 NSView 是否绘制的正确方法

    有没有正确的方法来确定是否NSView实际上是否在当前视图层次结构中绘制 考虑以下情况 视图完全在屏幕外 不是强制性的 该视图不在视图层次结构的顶部 The isHidden and isHiddenOrHasHiddenAncestor不
  • CameraX Image拍照速度慢

    我正在使用 CameraX 这是我的图像捕获 mImageCapture ImageCapture Builder setCaptureMode ImageCapture CAPTURE MODE MINIMIZE LATENCY setT
  • MySQL NOT IN 来自同一个表中的另一列

    我想运行 mysql 查询来选择表中的所有行films其中的值title该列不存在于另一列的所有值中的任何位置 collection 这是我的表格的简化版本 其中包含内容 mysql gt select from films id titl
  • SSE:跨页边界的未对齐加载和存储

    我在页面边界旁边执行未对齐加载或存储之前读过某处 例如使用 mm loadu si128 mm storeu si128内在函数 代码应首先检查整个向量 在本例中为 16 个字节 是否属于同一页 如果不属于同一页 则切换到非向量指令 我知道
  • 是否可以在 Java 8 中调试 Lambda

    我刚刚开始使用 Java 8 Lambda 我注意到我无法在 NetBeans IDE 中调试它们 如果我尝试将断点附加到以下代码 我会得到一个变量断点 这绝对不是我想要的 private EventListener myListener
  • 如何在 C# + XNA 中将音调更改为超过 1 或 -1?

    我需要拥有比 2 个八度音阶更多的自由来创建我想要的东西 但 XNA 却做不到 我确实意识到可能没有办法让程序接受更大 更小的值 但是有没有办法解决它 就像以最低音调发出声音 然后创建一个新的声音 这样我就可以降低它更多 None
  • ZK中如何在特定位置添加多个组件

    我正在 ZK 应用程序中工作 我需要添加
  • C# asp.net中WebForm中Winform的Textbox.KeyDown的交替事件是什么?

    在 WinForms 应用程序中 我可以有一个textbox1 keydown事件 但我想在 WebForm ASP NET 中实现同样的事情 那么我该怎么做呢 我需要从数据库中检索有关此事件的数据 您可以使用 onkeydown 事件 然
  • 消息 203,级别 16,状态 2,不是有效标识符

    我收到以下错误 消息 203 级别 16 状态 2 过程 getQuestion 第 18 行名称 select top 1 from tlb Question inner join tlb options on tlb options q
  • 使用 Visual C++ 在桌面上绘图

    我正在编写一个 opencv 应用程序 使用 Visual Studio VC 控制台应用程序使用激光束进行绘图 我想在桌面上画线 我知道绘图功能在 GDI32 dll 中可用 但对如何将 GDI32 dll 与我的 vc 代码集成感到困惑
  • Android 动画 GIF

    我正在尝试使用 WebView 显示动画 GIF 它在大多数设备上运行良好 但仍有一些设备不支持动画并显示静态 GIF 如何检测设备是否支持 WebView 中的动画 GIF 以便在不支持时显示适当的消息 是的 这似乎是一个常见问题 htt
  • 下划线反跳与参数

    假设我有这个事件处理程序 var mousewheel function e blah 但是 我想消除它 所以我这样做 它按预期工作 var mousewheelDebounced debounce mousewheel 500 docum
  • React-navigation、tintColor 在 props 验证中丢失

    我已将反应导航代码放入单独的 Routes 文件中 然后将其导入到 App js 文件中 一切工作正常 但我在 Atom Nuclide 中使用 Airbnb ESLint 配置 并收到了 TintColor 错误 道具验证中缺少tintC
  • 如何除以两个原生 JavaScript BigInt 并获得小数结果

    这是我到目前为止所尝试过的 我正在寻找一个12 34 BigInt 12340000000000000000 BigInt 1000000000000000000 12n Number BigInt 12340000000000000000
  • 合并张量流数据集批次

    请考虑下面的代码 import tensorflow as tf import numpy as np simple features np array 1 1 1 2 2 2 3 3 3 4 4 4 5 5 5 simple labels