具有 tf 数据集输入的 Tensorflow keras

2023-11-25

我是张量流 keras 和数据集的新手。谁能帮我理解为什么下面的代码不起作用?

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.keras.utils import multi_gpu_model
from tensorflow.python.keras import backend as K


data = np.random.random((1000,32))
labels = np.random.random((1000,10))
dataset = tf.data.Dataset.from_tensor_slices((data,labels))
print( dataset)
print( dataset.output_types)
print( dataset.output_shapes)
dataset.batch(10)
dataset.repeat(100)

inputs = keras.Input(shape=(32,))  # Returns a placeholder tensor

# A layer instance is callable on a tensor, and returns a tensor.
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(64, activation='relu')(x)
predictions = keras.layers.Dense(10, activation='softmax')(x)

# Instantiate the model given inputs and outputs.
model = keras.Model(inputs=inputs, outputs=predictions)

# The compile step specifies the training configuration.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
          loss='categorical_crossentropy',
          metrics=['accuracy'])

# Trains for 5 epochs
model.fit(dataset, epochs=5, steps_per_epoch=100)

它失败并出现以下错误:

model.fit(x=dataset, y=None, epochs=5, steps_per_epoch=100)
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 1510, in fit
validation_split=validation_split)
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 994, in _standardize_user_data
class_weight, batch_size)
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 1113, in _standardize_weights
exception_prefix='input')
File "/home/wuxinyu/pyEnv/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_utils.py", line 325, in standardize_input_data
'with shape ' + str(data_shape))
ValueError: Error when checking input: expected input_1 to have 2 dimensions, but got array with shape (32,)

根据 tf.keras 指南,我应该能够直接将数据集传递给 model.fit,如本例所示:

输入 tf.data 数据集

使用数据集 API 扩展到大型数据集或多设备训练。将 tf.data.Dataset 实例传递给 fit 方法:

# Instantiates a toy dataset instance:
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat()

不要忘记指定steps_per_epoch打电话时fit在数据集上。

model.fit(数据集,epochs=10,steps_per_epoch=30) 这里,fit方法使用steps_per_epoch参数——这是模型在进入下一个纪元之前运行的训练步骤数。由于数据集产生批量数据,因此此代码片段不需要batch_size。

数据集也可用于验证:

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32).repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(32).repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30,
      validation_data=val_dataset,
      validation_steps=3)

我的代码有什么问题,正确的方法是什么?


对于您最初的问题,即为什么会收到错误:

Error when checking input: expected input_1 to have 2 dimensions, but got array with shape (32,)

您的代码中断的原因是因为您没有应用.batch()回到dataset变量,像这样:

dataset = dataset.batch(10)

你只是简单地打电话dataset.batch().

这会打破,因为没有batch()输出张量没有批量处理,即你得到了形状(32,)代替(1,32).

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

具有 tf 数据集输入的 Tensorflow keras 的相关文章

随机推荐

  • 每个端点的不同服务行为

    情况 我们正在某些 WCF 服务上实施不同类型的安全性 客户端证书 用户名和密码以及匿名 我们有 2 个 ServiceBehaviorConfigurations 一种用于 httpBinding 一种用于 wsHttpBinding 我
  • 有没有办法可靠地检测CPU核心总数?

    我需要一种可靠的方法来检测计算机上有多少个 CPU 核心 我正在创建一个数值密集型模拟 C 应用程序 并希望创建最大数量的运行线程作为核心 我已经尝试了互联网上建议的许多方法 例如Environment ProcessorCount 使用W
  • 对路径的访问被拒绝 - File.Move 失败,但 File.Delete 有效

    我正在尝试执行一个简单的File Move操作但我得到 System UnauthorizedAccessException 异常 对路径的访问被拒绝 据我所知 没有任何东西正在使用我试图移动的文件 包含文件夹也已关闭 我可以通过文件资源管
  • 以编程方式设置 iPhone 模拟器位置

    我刚刚更新到 XCode 4 2 发现了一个很酷的功能 可以让我手动设置设备位置 有谁知道如何以编程方式完成同样的事情 我想在一些单元测试中设置位置 以下 AppleScript 将允许您设置 iOS 模拟器的位置 应该可以将这种脚本集成到
  • 将 javascript 数组传递给 servlet

    我已经看过有关此主题的先前问题 但我的问题尚未解决 我将数组从 javascript 传递到 servlet JavaScript 代码 var action new Array function getProtAcionValues ro
  • prism/mvvm:将列绑定到 DataGrid

    我正在使用标准的 NET DataGrid 如下所示
  • 将二维数组表示为一维数组[重复]

    这个问题在这里已经有答案了 可能的重复 实现矩阵 使用数组的数组 2D 还是一维数组 哪个更有效 二维数组与一维数组的性能 有一天 我正在查看我朋友的一个分子动力学代码库 他将一些二维数据表示为一维数组 因此 他不必使用两个索引 而只需要跟
  • C++:'cout << 指针 << ++pointer' 生成编译器警告

    我这里有一个C 学习演示 char c M short s 10 long l 1002 char cptr c short sptr s long lptr l cout lt lt cptr t lt lt static cast
  • Rails 中的路径解析

    我正在寻找解析路由路径的方法 如下所示 ActionController Routing new post path parse gt controller gt posts action gt index 应该是相反的url for Up
  • 如何在 Cython 中声明 2D 列表

    我正在尝试编译这种代码 def my func double c int m cdef double f m m f c for x in range m for y in range m 这引发了 Error compiling Cyth
  • 在MySQL查询中将部分非数字文本转换为数字

    是否可以在 MySQL 查询中将文本转换为数字 我有一个带有标识符的列 该标识符由名称和数字组成 格式为 名称 数字 该列具有 VARCHAR 类型 我想根据数字 具有相同名称的行 对行进行排序 但列是根据字符顺序排序的 即 name 1
  • wget 中的递归下载如何工作?

    wget 用于镜像站点 但我想知道该实用程序如何下载该域的所有 URL wget r www xyz com wget如何下载域xyz的所有URL 它是否像爬虫一样访问索引页面并解析它并提取链接 简短回答 通常 是的 Wget 会抓取所有
  • 在浏览器中显示致命/通知错误

    嗯 我刚刚开始使用 hhvm hack 但我想向浏览器显示错误 但没有成功 我将ini设置设置如下 error reporting E ALL ini set display errors 1 根据 var dumpini get值被设置为
  • 如何从 Adwords API 中提取数据并将其放入 Pandas Dataframe 中

    我正在使用 Python 从 Google AdWords API 中提取数据 我想将该数据放入 Pandas DataFrame 中 以便我可以对数据进行分析 我正在使用 Google 提供的示例here 下面是我尝试将输出读取为 pan
  • 使用反射调用静态方法时如何通过 ref 传递参数?

    我使用反射在对象上调用静态方法 MyType GetMethod MyMethod BindingFlags Static Invoke null new object Parameter1 Parameter2 如何通过引用而不是通过值传
  • Twitter Streaming API - urllib3.exceptions.ProtocolError: ('连接中断:IncompleteRead

    使用 tweepy 运行一个 python 脚本 该脚本在英语推文的随机样本中流式传输 使用 twitter 流 API 一分钟 然后交替搜索 使用 twitter 搜索 API 一分钟 然后返回 我发现的问题是 大约 40 多秒后 流媒体
  • Windows系统上的IOS编程[重复]

    这个问题在这里已经有答案了 我有兴趣学习 IOS 编程 但目前我无法访问 Macintosh 系统 只是想知道是否有适用于 Windows 的 IOS SDK 的等效版本 不过我有一部 iPhone 您可以查看GNUStep这是一个跨平台的
  • 设置 Z 索引不起作用。容器后面的按钮(HTML - CSS)

    I using Metro css windows 8 style and have a problem I have container with alerts the blue in the picture and above ther
  • MS Access VBA - 在数据表子表单中显示动态构建的 SQL 结果

    我在 MS Office 应用程序 用于自动化和 ETL 流程 中拥有多年使用 VBA 的经验 但直到最近才需要处理 MS Access 中的表单 我正在为我设计的数据库设计一些简单的数据提取表单 并专注于看似简单的任务 目标 我需要一个数
  • 具有 tf 数据集输入的 Tensorflow keras

    我是张量流 keras 和数据集的新手 谁能帮我理解为什么下面的代码不起作用 import tensorflow as tf import tensorflow keras as keras import numpy as np from