Keras LSTM 输入形状的输入形状错误

2024-04-17

在 Keras 中使用时间序列时出现此错误:

ValueError: Error when checking input: expected lstm_1_input to have 3 dimensions, but got array with shape (31, 3)

这是我的功能:

def CreateModel(shape):
  """Creates Keras Model.

  Args:
    shape: (set) Dataset shape. Example: (31,3).

  Returns:
    A Keras Model.

  Raises:
    ValueError: Invalid shape
  """

  if not shape:
    raise ValueError('Invalid shape')

  logging.info('Creating model')
  model = Sequential()
  model.add(LSTM(4, input_shape=(31, 3)))
  model.add(Dense(1))
  model.compile(loss='mean_squared_error', optimizer='adam')
  return model

主要代码:

print(training_features.shape)
model = CreateModel(training_features.shape)
model.fit(
      training_features,
      training_label,
      epochs=FLAGS.epochs,
      batch_size=FLAGS.batch_size,
      verbose=FLAGS.keras_verbose_level)

完整错误:

Traceback (most recent call last):
  File "<embedded module '_launcher'>", line 149, in run_filename_as_main
  File "<embedded module '_launcher'>", line 33, in _run_code_in_main
  File "model.py", line 300, in <module>
    app.run(main)
  File "absl/app.py", line 433, in run
    _run_main(main, argv)
  File "absl/app.py", line 380, in _run_main
    sys.exit(main(argv))
  File "model.py", line 274, in main
    verbose=FLAGS.keras_verbose_level)
  File "keras/models.py", line 960, in fit
    validation_steps=validation_steps)
  File "keras/engine/training.py", line 1581, in fit
    batch_size=batch_size)
  File "keras/engine/training.py", line 1414, in _standardize_user_data
    exception_prefix='input')
  File "keras/engine/training.py", line 141, in _standardize_input_data
    str(array.shape))
ValueError: Error when checking input: expected lstm_1_input to have 3 dimensions, but got array with shape (31, 3)

代码最初来自here https://machinelearningmastery.com/time-series-prediction-with-deep-learning-in-python-with-keras/

我努力了:

training_features = numpy.reshape(
      training_features,
      (training_features.shape[0], 1, training_features.shape[1]))

但我得到:

ValueError: Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4

如果您的原始数据是(31,3)那么我认为您正在寻找的是training_features.shape =(31,3,1)。您可以通过以下行获得它...

training_features = training_features.reshape(-1, 3, 1)

这将简单地向现有数据添加一个新轴(-1 只是告诉 numpy 使用原始数据中的值计算出这个维度)。

您还需要修复模型的输入形状。 31 应该是数据中的样本数。这不包含在 Keras 中input_shape范围。你应该使用...

model.add(LSTM(4, input_shape=(3, 1)))

Keras 会自动将批量大小设置为None这意味着任意数量的样本都适用于该模型。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras LSTM 输入形状的输入形状错误 的相关文章

  • Python:在列表理解本身中引用列表理解?

    这个想法刚刚出现在我的脑海中 假设您出于某种原因想要通过 Python 中的列表理解来获取列表的唯一元素 i if i in created comprehension else 0 for i in 1 2 1 2 3 1 2 0 0 3
  • 将数据从 python pandas 数据框导出或写入 MS Access 表

    我正在尝试将数据从 python pandas 数据框导出到现有的 MS Access 表 我想用已更新的数据替换 MS Access 表 在 python 中 我尝试使用 pandas to sql 但收到错误消息 我觉得很奇怪 使用 p
  • 使 django 服务器可以在 LAN 中访问

    我已经安装了Django服务器 可以如下访问 http localhost 8000 get sms http 127 0 0 1 8000 get sms 假设我的IP是x x x x 当我这样做时 从同一网络下的另一台电脑 my ip
  • OpenCV Python cv2.mixChannels()

    我试图将其从 C 转换为 Python 但它给出了不同的色调结果 In C Transform it to HSV cvtColor src hsv CV BGR2HSV Use only the Hue value hue create
  • 如何在flask中使用g.user全局

    据我了解 Flask 中的 g 变量 它应该为我提供一个全局位置来存储数据 例如登录后保存当前用户 它是否正确 我希望我的导航在登录后在整个网站上显示我的用户名 我的观点包含 from Flask import g among other
  • 通过最小元素比较对 5 个元素进行排序

    我必须在 python 中使用元素之间的最小比较次数来建模对 5 个元素的列表进行排序的执行计划 除此之外 复杂性是无关紧要的 结果是一个对的列表 表示在另一时间对列表进行排序所需的比较 我知道有一种算法可以通过 7 次比较 总是在元素之间
  • 使用带有关键字参数的 map() 函数

    这是我尝试使用的循环map功能于 volume ids 1 2 3 4 5 ip 172 12 13 122 for volume id in volume ids my function volume id ip ip 我有办法做到这一点
  • 如何替换 pandas 数据框列中的重音符号

    我有一个数据框dataSwiss其中包含瑞士城市的信息 我想用普通字母替换带有重音符号的字母 这就是我正在做的 dataSwiss Municipality dataSwiss Municipality str encode utf 8 d
  • 根据列值突出显示数据框中的行?

    假设我有这样的数据框 col1 col2 col3 col4 0 A A 1 pass 2 1 A A 2 pass 4 2 A A 1 fail 4 3 A A 1 fail 5 4 A A 1 pass 3 5 A A 2 fail 2
  • AWS EMR Spark Python 日志记录

    我正在 AWS EMR 上运行一个非常简单的 Spark 作业 但似乎无法从我的脚本中获取任何日志输出 我尝试过打印到 stderr from pyspark import SparkContext import sys if name m
  • BeautifulSoup 中的嵌套标签 - Python

    我在网站和 stackoverflow 上查看了许多示例 但找不到解决我的问题的通用解决方案 我正在处理一个非常混乱的网站 我想抓取一些数据 标记看起来像这样 table tbody tr tr tr td td td table tr t
  • 添加不同形状的 numpy 数组

    我想添加两个不同形状的 numpy 数组 但不进行广播 而是将 缺失 值视为零 可能最简单的例子是 1 2 3 2 gt 3 2 3 or 1 2 3 2 1 gt 3 2 3 1 0 0 我事先不知道形状 我正在弄乱每个 np shape
  • 向 Altair 图表添加背景实心填充

    I like Altair a lot for making graphs in Python As a tribute I wanted to regenerate the Economist graph s in Mistakes we
  • 如何在seaborn displot中使用hist_kws

    我想在同一图中用不同的颜色绘制直方图和 kde 线 我想为直方图设置绿色 为 kde 线设置蓝色 我设法弄清楚使用 line kws 来更改 kde 线条颜色 但 hist kws 不适用于显示 我尝试过使用 histplot 但我无法为
  • 为字典中的一个键附加多个值[重复]

    这个问题在这里已经有答案了 我是 python 新手 我有每年的年份和值列表 我想要做的是检查字典中是否已存在该年份 如果存在 则将该值附加到特定键的值列表中 例如 我有一个年份列表 并且每年都有一个值 2010 2 2009 4 1989
  • Conda SafetyError:文件大小不正确

    使用创建 Conda 环境时conda create n env name python 3 6 我收到以下警告 Preparing transaction done Verifying transaction SafetyError Th
  • 发送用户注册密码,django-allauth

    我在 django 应用程序上使用 django alluth 进行身份验证 注册 我需要创建一个自定义注册表单 其中只有一个字段 电子邮件 密码将在服务器上生成 这是我创建的表格 from django import forms from
  • Python Selenium:如何在文本文件中打印网站上的值?

    我正在尝试编写一个脚本 该脚本将从 tulsaspca org 网站获取以下 6 个值并将其打印在 txt 文件中 最终输出应该是 905 4896 7105 23194 1004 42000 放置的动物 的 HTML span class
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2
  • Statsmodels.formula.api OLS不显示截距的统计值

    我正在运行以下源代码 import statsmodels formula api as sm Add one column of ones for the intercept term X np append arr np ones 50

随机推荐