如何使用有状态 LSTM 和 batch_size > 1 布置训练数据

2024-03-29

背景

我想在 Keras 中对“有状态”LSTM 进行小批量训练。我的输入训练数据位于一个大矩阵“X”中,其维度为 m x n,其中

m = number-of-subsequences
n = number-of-time-steps-per-sequence

X 的每一行都包含一个子序列,该子序列接续前一行上的子序列离开的位置。因此,给定一长串数据,

Data = ( t01, t02, t03, ... )

其中“tK”表示原始数据中位置 K 的标记,序列在 X 中布局如下:

X = [
  t01 t02 t03 t04
  t05 t06 t07 t08
  t09 t10 t11 t12
  t13 t14 t15 t16
  t17 t18 t19 t20
  t21 t22 t23 t24
]

Question

我的问题是,当我使用有状态 LSTM 对以这种方式布置的数据进行小批量训练时会发生什么。具体来说,小批量训练通常一次对“连续”的行组进行训练。因此,如果我使用大小为 2 的小批量,则 X 将分为三个小批量 X1、X2 和 X3,其中

X1 = [
  t01 t02 t03 t04
  t05 t06 t07 t08
]

X2 = [
  t09 t10 t11 t12
  t13 t14 t15 t16
]

X3 = [
  t17 t18 t19 t20
  t21 t22 t23 t25
]

请注意,这种类型的小批量处理与训练并不相符statefulLSTM,因为通过处理前一批的最后一列产生的隐藏状态不是与后续批次的第一列之前的时间步相对应的隐藏状态。

要看到这一点,请注意小批量将按照从左到右的方式进行处理,如下所示:

------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t09 t10 t11 t12 | t17 t18 t19 t20
t05 t06 t07 t08 | t13 t14 t15 t16 | t21 t22 t23 t24

暗示着

- Token t04 comes immediately before t09
- Token t08 comes immediately before t13
- Token t12 comes immediately before t17
- Token t16 comes immediately before t21

但我希望小批量对行进行分组,以便我们在小批量之间获得这种时间对齐:

------ X1 ------+------- X2 ------+------- X3 -----
t01 t02 t03 t04 | t05 t06 t07 t08 | t09 t10 t11 t12
t13 t14 t15 t16 | t17 t18 t19 t20 | t21 t22 t23 t24

在 Keras 中训练 LSTM 时实现此目标的标准方法是什么?

感谢您在这里的任何指点。


解决方案 1 - 批量大小 = 1

好吧,既然看起来你实际上只有一个序列(虽然分开了,但它仍然是一个序列,对吧?),你确实必须使用等于 1 的批量大小进行训练。

如果您不想更改或重新组织数据,只需:

 X = X.reshape((-1,length,features))

     #where
         #length = 4 by your description    
         #features = 1 (if you have only one var over time, as it seems)

解决方案 2 - 重新组合长度 = 8

仍在使用一个批量大小为 1,重塑输入数据(在将其传递给模型之前),使其具有双倍长度。

最终结果将与您使用所描述的大小为 2 的小批量进行训练完全相同。(但请确保在模型的输入形状中将批量大小设置为 1,否则这会给您带来错误的结果)。

X = X.reshape((-1, 2 * length, features)) 

这会给你:

X = [
  [t01 t02 t03 t04 t05 t06 t07 t08]
  [t09 t10 t11 t12 t13 t14 t15 t16]
  [t17 t18 t19 t20 t21 t22 t23 t24]
]

解决方案 3 - 仅当您实际上有两个不同的序列时才可能

根据你的描述,你似乎只有一个序列。如果您确实有两个不同/独立的序列,那么您可以制作一批大小为 2 的批次。

如果将序列一分为二(并失去它们之间的连接)不是问题,您可以重新排列数据:

X = X.reshape((2,-1,length, features))

Then:

X0 = X[:,0]
X1 = X[:,1]
...

您可以尝试将其分组在一个数组中:

X = X.reshape((2,-1,length, features))
X = np.swapaxes(X,0,1).reshape((-1,length,features))

Then:

X0 = X[0]
X1 = X[1]
...

你可以尝试通过完整的X只要在模型中明确将批量大小设置为 2 即可进行训练输入形状.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用有状态 LSTM 和 batch_size > 1 布置训练数据 的相关文章

  • Keras 中批量大小可变的batch_dot

    我正在尝试编写一个层来合并 2 个张量formula https i stack imgur com I49aj png x 0 和x 1 的形状都是 1 500 M是500 500的矩阵 我希望输出为 500 500 我认为这在理论上是可
  • Tensorflow - 获取队列中的样本数量?

    对于性能监控 我想关注当前排队的示例 我正在平衡用于填充队列的线程数量和队列的最佳最大大小 我如何获得这些信息 我正在使用一个tf train batch 但我猜这些信息可能在下面的某个地方FIFOQueue 我本以为这是一个局部变量 但我
  • Keras如何在Relu激活函数中使用max_value

    keras activation py 中定义的 Relu 函数为 def relu x alpha 0 max value None return K relu x alpha alpha max value max value 它有一个
  • sigmoid激活函数可以用来解决Keras中的回归问题吗?

    我已经用 R 实现了简单的神经网络 但这是我第一次用 Keras 实现 所以希望得到一些建议 我在 Keras 中开发了一个神经网络函数来预测汽车销量 数据集可用here https github com allmydatasets dat
  • 使用张量流导出神经网络的权重

    我使用张量流工具编写了神经网络 一切正常 现在我想导出神经网络的最终权重以制定单一的预测方法 我怎样才能做到这一点 您需要在训练结束时使用以下命令保存模型tf train Saver https www tensorflow org ver
  • 在基本 Tensorflow 2.0 中运行简单回归

    我正在学习 Tensorflow 2 0 我认为在 Tensorflow 中实现最基本的简单线性回归是一个好主意 不幸的是 我遇到了几个问题 我想知道这里是否有人可以提供帮助 考虑以下设置 import tensorflow as tf 2
  • Scipy 稀疏 CSR 矩阵到 TensorFlow SparseTensor - 小批量梯度下降

    我有一个 Scipy 稀疏 CSR 矩阵 它是根据 SVM Light 格式的稀疏 TF IDF 特征矩阵创建的 特征数量巨大且稀疏 所以我必须使用 SparseTensor 否则速度太慢 例如 特征数量为 5 示例文件如下所示 0 4 1
  • 在 Tensorflow2 中将图冻结为 pb

    我们通过图形冻结保存来自 TF1 的许多模型 tf train write graph self session graph def some path get graph definitions with weights output g
  • tf-models:official.vision.detection Mask-RCNN 无效参数:indices[1,63] = [1, -1] 未索引到参数形状 [2,100,112,112]

    我正在尝试根据此处提供的官方 MaskRCNN 模型训练 Mask RCNN 模型 张量流 模型 https github com tensorflow models tree master official vision detectio
  • Keras model.summary() 结果 - 了解参数数量

    我有一个简单的神经网络模型 用于使用 Keras Theano 后端 从用 python 编写的 28x28px 图像中检测手写数字 model0 Sequential number of epochs to train for nb ep
  • Tensorflow 训练期间 GPU 使用率非常低

    我正在尝试为 10 类图像分类任务训练一个简单的多层感知器 这是 Udacity 深度学习课程作业的一部分 更准确地说 任务是对各种字体呈现的字母进行分类 数据集称为 notMNIST 我最终得到的代码看起来相当简单 但无论如何我在训练期间
  • 如何将体积补丁存储到 HDF5 中?

    我有一个尺寸的体积数据256x128x256 由于内存有限 我无法将整个数据直接输入到 CAFFE 因此 我会随机选择n sample补丁50x50x50从体积数据中提取并将其存储到 HDF5 中 我成功地从原始数据及其标签中随机提取了补丁
  • UnimplementedError:图形执行错误:在张量流上运行 nn

    我一直遇到这个错误 我不知道为什么 特别是因为我完全遵循某人的代码并且该人在运行此错误时没有错误 img shape 128 128 3 load pretrained model base model tf keras applicati
  • 如何用Python构建游戏神经网络?

    我是神经网络初学者 我想通过教计算机下跳棋来学习神经网络的基础知识 其实我想学的游戏是盛气凌人 http en wikipedia org wiki Domineering and Hex http en wikipedia org wik
  • Keras 错误:预计会看到 1 个数组

    当我尝试在 keras 中训练 MLP 模型时出现以下错误 我使用的是 keras 版本1 2 2 检查模型输入时出错 您输入的 Numpy 数组列表 传递给您的模型的尺寸不是模型预期的尺寸 预期的 查看 1 个数组 但得到以下 12859
  • GPU 上的 AWS SageMaker [已关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在尝试在 AWS 上训练神经网络 Tensorflow 我有一些 AWS 积分 据我了解 AWS
  • softmax_cross_entropy_with_logits 的 PyTorch 等效项

    我想知道 TensorFlow 是否有等效的 PyTorch 损失函数softmax cross entropy with logits TensorFlow 是否有等效的 PyTorch 损失函数softmax cross entropy
  • 如何使用文本和?

    我一直在关注this https github com tensorflow models tree master textsum使用 textsum 的链接 我已经使用提供的命令训练了模型 但我在 textsum log root 目录中
  • 使用 Keras 和 fit_generator 绘制 TensorBoard 分布和直方图

    我正在使用 Keras 使用 fit generator 函数训练 CNN 这似乎是一个已知问题 https github com fchollet keras issues 3358TensorBoard 在此设置中不显示直方图和分布 有
  • 使用决策树

    我知道 tl dr 我将尝试解释我的问题 而不会用大量蹩脚的代码来打扰您 我正在做一项学校作业 我们有蓝精灵的图片 我们必须通过前景背景分析来找到它们 我有一个 Java 决策树 其中包含所有数据 HSV 直方图 1 一个节点 然后尝试找到

随机推荐

  • android + eclipse + maven + actionbarsherlock

    我读了很多关于 actionbarsherlock maven android 的东西 但我见过的解决方案都不适合我 我确信我已经非常接近解决方案 但我不明白 我需要一些帮助 所以这是我的问题 我尝试创建一个依赖于 Actionbarshe
  • 如何删除空值?

    如何删除底部计数中的空值 即 我只想查看实际销售单位的产品 我尝试过非空和非空但没有成功 with member Measures Amount Sold as Measures Internet Sales Amount format s
  • 为什么“超时”不适用于管道?

    以下命令行调用timeout 这没有意义 只是出于测试原因 无法按预期工作 它会等待 10 秒 并且在 3 秒后不会停止命令的运行 为什么 timeout 3 ls sleep 10 您的命令正在执行的操作正在运行timeout 3 ls并
  • 在 Windows 上的 XAMPP 中哪里可以更改 lower_case_table_names=2 的值?

    我正在使用 Windows 7 和 XAMPP 我正在尝试导出数据库 在此过程中表名称将转换为小写 我搜索了很多 我知道我必须改变的值lower case table names from 0 to 2 但是我必须在哪里更改这个值 在哪个文
  • 将 TypeScript 网站从 GitHub 部署到 Azure

    我有一个 NET 网站 其中包含一些 TypeScript 文件 我尝试将其从 GitHub 部署为 Azure 网站 但收到与 TypeScript 相关的错误 在我看来 这可能与我使用最新版本 1 0 有关 而 kudu 版本只有 0
  • Google 端点和公共 Api 密钥

    要使用 Google 服务 您可以使用 OAuth 身份验证 或者 如果您不需要用户登录 则可以使用公共 api 密钥 将授权域定义为请求的来源 现在 我正在使用 google 端点编写自己的 API 并且我将允许用户通过公共 api 密钥
  • 使用sessionStorage有什么好处? [复制]

    这个问题在这里已经有答案了 只是想知道在存储要在 Javascript 轮播中使用的 HTML 内容时使用 HTML5 的 sessionStorage 的实际好处是什么 与性能有关吗 加载时间 带宽 是的 您将使用更少的带宽 这会提高性能
  • 使用 ggdendro 在树状图的片段下显示变量标签

    我的问题与安德里的有关answer https i stack imgur com JW0m1 png我之前的问题 我的问题是是否可以在树状图的相应段下显示变量标签和汽车标签 library ggplot2 library ggdendro
  • 扩展 Android 的默认 Gmail/电子邮件应用程序

    我想通过插入 ContentProvider 或使用意图过滤器来扩展 Android 平台的默认 Gmail 电子邮件应用程序 本质上 我希望能够扫描传入的电子邮件以查找将在我的 Android 应用程序中触发事件的特殊规则 如果自动扫描电
  • 立即终止无循环线程,无需中止或挂起

    我正在实现一个协议库 这里有一个简化的描述 main 函数中的主线程将始终检查网络流 在 tcpclient 内 上是否有某些数据可用 假设响应是收到的消息 线程是正在运行的线程 thread new Thread new ThreadSt
  • 在 Sparklyr 中创建虚拟变量?

    我正在尝试扩展我的一些 ML 管道 我喜欢 Sparklyr 打开的 rstudio spark 和 h2o 的组合 http spark rstudio com http spark rstudio com 我试图弄清楚的一件事是如何使用
  • 多个组的可反应聚合函数

    使用 Rreactable包中 我试图使用两个 groupBy 变量显示标记读数的百分比 在较低级别的分组中 这是计算正确的百分比 但在分组的第二 外部 级别上 它没有显示正确的百分比 这是数据 dat lt structure list
  • PHP:查询 MySQL 最快的方法是什么?因为 PDO 慢得令人痛苦

    我需要执行一个简单的查询 从字面上看 我需要执行的是 SELECT price sqft zipcode FROM homes WHERE home id X 当我使用 PHP 时PDO 我读过的是连接到 MySQL 数据库的推荐方法 简单
  • 如何通过id查找页面上的控件

    有没有一种简单的方法可以通过 id 在任何嵌套容器中 在 ASP NET 中查找控件 除了遍历整个控件树之外 像这个例子 TextBox tb new TextBox ID textboxId panel3 Controls Add tb
  • Spring Boot如何选择外部化的Spring属性文件

    我有这个配置需要用于 Spring Boot 应用程序 server port 8085 server servlet context path authserver data source spring jpa hibernate ddl
  • Windows 上的 Python 包:pip 还是本机安装程序?

    与使用打包安装程序 exe msi 相比 使用 pip 在 Windows 上安装 python 软件包的相对优点是什么 对于初学者来说 有些对我来说不起作用 MySQLdb 是 我的新规则 Try pip or easy install
  • NodeJS + Mysql 与 Docker Compose 2

    我正在尝试构建一个 docker compose 文件来在本地部署连接到 mysql 服务器的 NodeJS 应用程序 我已经尝试了所有方法 在 Stackoverflow 中阅读了大量教程和一些问题 但我不断收到 ECONNREFUSED
  • Apache .htaccess:如何在 Firefox 上用斜杠重写反斜杠?

    如何重写反斜杠 带斜线 在火狐浏览器上 Chrome IE Safari Opera 已构建浏览器用斜杠重写反斜杠 但 Firefox 3 6 13 回归404错误页面 Why Firefox returns 404 error page
  • 使用history.pushState()更新URL中的参数

    我在用history pushState在我的页面上进行 AJAX 调用后 将一些参数附加到当前页面 URL 现在 在基于用户操作的同一页面上 我想使用相同或附加的参数集再次更新页面 URL 所以我的代码如下所示 var pageUrl w
  • 如何使用有状态 LSTM 和 batch_size > 1 布置训练数据

    背景 我想在 Keras 中对 有状态 LSTM 进行小批量训练 我的输入训练数据位于一个大矩阵 X 中 其维度为 m x n 其中 m number of subsequences n number of time steps per s