使用大数据集在 Google Colab TPU 上训练 seq2seq 模型 - Keras

2024-05-19

我正在尝试使用 Google Colab TPU 上的 Keras 训练用于机器翻译的序列到序列模型。 我有一个可以加载到内存中的数据集,但我必须对其进行预处理才能将其提供给模型。特别是,我需要将目标单词转换为一个热向量,并且在许多示例中,我无法将整个转换加载到内存中,因此我需要生成一批数据。

我使用这个函数作为批处理生成器:

def generate_batch_bert(X_ids, X_masks, y, batch_size = 1024):
    ''' Generate a batch of data '''
    while True:
        for j in range(0, len(X_ids), batch_size):
          # batch of encoder and decoder data
          encoder_input_data_ids = X_ids[j:j+batch_size]
          encoder_input_data_masks = X_masks[j:j+batch_size]
          y_decoder = y[j:j+batch_size]
          

          # decoder target and input for teacher forcing
          decoder_input_data = y_decoder[:,:-1]
          decoder_target_seq = y_decoder[:,1:]
          
          # batch of decoder target data
          decoder_target_data = to_categorical(decoder_target_seq, vocab_size_fr)
          # keep only with the right amount of instances for training on TPU
          if encoder_input_data_ids.shape[0] == batch_size:
            yield([encoder_input_data_ids, encoder_input_data_masks, decoder_input_data], decoder_target_data)

问题是,每当我尝试运行 fit 函数时,如下所示:

model.fit(x=generate_batch_bert(X_train_ids, X_train_masks, y_train, batch_size = batch_size),
                    steps_per_epoch = train_samples//batch_size,
                    epochs=epochs,
                    callbacks = callbacks,
                    validation_data = generate_batch_bert(X_val_ids, X_val_masks, y_val, batch_size = batch_size),
                    validation_steps = val_samples//batch_size)

我收到以下错误:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_util.py:445 make_tensor_proto
    raise ValueError("None values not supported.")

ValueError: None values not supported.

不知道出了什么问题以及如何解决这个问题。

EDIT

我尝试在内存中加载较少的数据,以便转换为目标单词的一种热编码不会使内核崩溃,并且它实际上可以工作。所以我生成批次的方式显然有问题。


由于您不提供模型,因此很难判断出了什么问题 定义或任何示例数据。然而,我相当确定你是 遇到同样的情况TensorFlow 错误 https://github.com/tensorflow/tensorflow/issues/47769我最近被咬了。

解决方法是使用tensorflow.data有效的API 使用 TPU 效果更好。像这样:

from tensorflow.data import Dataset
import tensorflow as tf

def map_fn(X_id, X_mask, y):
    decoder_target_data = tf.one_hot(y[1:], vocab_size_fr)
    return (X_id, X_mask, y[:-1]), decoder_target_data
...
X_ids = Dataset.from_tensor_slices(X_ids)
X_masks = Dataset.from_tensor_slices(X_masks)
y = Dataset.from_tensor_slices(y)
ds = Dataset.zip((X_ids, X_masks, y)).map(map_fn).batch(1024)
model.fit(x = ds, ...)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用大数据集在 Google Colab TPU 上训练 seq2seq 模型 - Keras 的相关文章

  • 监控培训课程如何运作?

    我试图理解使用之间的区别tf Session and tf train MonitoredTrainingSession 以及我可能更喜欢其中之一 似乎当我使用后者时 我可以避免许多 杂务 例如初始化变量 启动队列运行程序或设置文件编写器以
  • 使用 Mac M1 在 Docker 容器内的 pip 安装中找不到 Tensorflow

    我正在尝试使用新的 Mac M1 运行一些项目 这些项目已经在英特尔处理器上运行 并被使用英特尔的其他开发人员使用 我无法构建这个简单的 Dockerfile FROM python 3 9 RUN python m pip install
  • 如何清除 tf.flags?

    如果我运行此代码两次 tf flags DEFINE integer batch size 2 batch size for training 我会得到这个错误 DuplicateFlagError The flag batch size
  • Keras Predict_classes 方法返回“列表索引超出范围”错误

    我对 CNN 和机器学习总体来说是新手 并且一直在尝试遵循 TensorFlow 的图像分类教程 现在 可以找到Google Colabhere https colab research google com drive 1gwZp7 t
  • 从 [tensorflow 1.00] 中的 softmax 层提取概率

    使用张量流 我有一个 LSTM 分类模型 以 softmax 作为最终节点 这是我的 softmax 层 with tf name scope Softmax as scope with tf variable scope Softmax
  • Tensorflow:为什么 tf.case 给我错误的结果?

    我正在尝试使用tf case https www tensorflow org api docs python tf case https www tensorflow org api docs python tf case 有条件地更新张
  • Tensorflow图像读取空

    这个问题是基于 Tensorflow图像读取与显示 https stackoverflow com questions 33648322 tensorflow image reading display 根据他们的代码 我们得到以下内容 s
  • 用于测试张量流安装的速度基准

    我怀疑我的 GPU 机器上是否正确配置了张量流 因为在我精美的 GPU 机器上训练一个简单的线性回归模型 批量大小 32 1500 个输入特征 150 个输出变量 的每次迭代速度比在笔记本电脑上慢 100 倍 我使用的是 Titan X 配
  • 为什么 scikit learn 的平均精度分数返回 nan?

    我的 Keras 模型旨在接收两个输入时间序列 将它们连接起来 通过 LSTM 提供它们 并在下一个时间步骤中进行多标签预测 有 50 个训练样本 每个样本有 24 个时间步 每个样本有 5625 个标签 有 12 个验证样本 每个样本有
  • tf.gfile 在 TensorFlow 中起什么作用?

    我见过人们使用以下几个函数tf gfile例如tf gfile GFile or tf gfile Exists 我有一个想法tf gfile处理文件 但是 我无法找到官方文档来了解它还提供了什么 如果你能帮我的话那就太好了 对于登陆这里的
  • 使用 Keras 时,验证集中未见的类别会出现错误

    我有由数值变量和分类变量组成的数据 分类变量有很多类别 因此我使用嵌入来表示这些类别 我的模型是一个简单的神经网络 我知道当你定义嵌入层时你需要通过input dim number of categories 1为了解释训练中看不见的类别
  • 使用 flow_from_dataframe y_col 的正确“值”是什么

    我正在用 pandas 读取 csv 文件 并给出存储在中的列名称colname colnames file label Read data from file data pd read csv Hand Annotations 2 csv
  • conv1D 中形状的尺寸

    我尝试过构建一个只有一层的 CNN 但遇到了一些问题 事实上 编译器告诉我 ValueError 检查模型输入时出错 预期的 conv1d 1 input 具有 3 个维度 但得到形状为 569 30 的数组 这是代码 import num
  • Tensorflow如何生成不平衡组合数据集

    我对新数据集 API tensorflow 1 4 有疑问 我有两个数据集 我需要创建一个组合的不平衡数据集 即 每个批次应包含第一个数据集中一定数量的元素和第二个数据集中一定数量的元素 例如 dataset1 tf data Datase
  • 从字符串列表创建 TfRecords 并在解码后在张量流中提供图形

    目的是创建 TfRecords 数据库 给定 我有 23 个文件夹 每个文件夹包含 7500 个图像 以及 23 个文本文件 每个文件有 7500 行描述单独文件夹中 7500 个图像的特征 我通过以下代码创建了数据库 import ten
  • 如何强制tensorflow使用所有可用的GPU?

    我有一个 8 GPU 集群 当我运行Kaggle 的一段 Tensorflow 代码 https www kaggle com keegil keras u net starter lb 0 277 scriptVersionId 2164
  • 如何将两个 keras 模型连接成一个模型?

    假设我有一个 ResNet50 模型 我希望将该模型的输出层连接到 VGG 模型的输入层 这是 ResNet 模型和 ResNet50 的输出张量 img shape 164 164 3 resnet50 model ResNet50 in
  • Tensorflow seq2seq 获取序列隐藏状态

    我不久前才开始研究tensorflow 我正在研究 seq2seq 模型 并以某种方式让教程起作用 但我一直坚持获取每个句子的状态 据我了解 seq2seq 模型采用输入序列并通过 RNN 为序列生成隐藏状态 随后 模型使用序列的隐藏状态来
  • batch_size = x.shape[0] AttributeError: 'tuple' 对象没有属性 'shape'

    该代码结合图像和掩模进行图像检测 我怎样才能纠正这个错误 batch size x shape 0 AttributeError tuple 对象没有属性 shape 这是用于训练的代码 train datagen ImageDataGen
  • Tensorboard——High-level节点的计算时间与其子节点计算时间的总和不同

    继tutorial https www tensorflow org programmers guide graph viz在 TensorFlow 上 我试图使用张量板来理解运行时统计数据 我发现代表名称范围的高级节点的计算时间不等于其子

随机推荐

  • symfony easyadmin 自定义表单生成器

    我使用 symfony 3 4 和 easycorp easyadmin bundle 1 17 配置表单 easyadmin form fields type group label Basic Information icon enve
  • 使用预训练的 word2vec 初始化 Seq2seq 嵌入

    我对使用预训练的 word2vec 初始化tensorflow seq2seq 实现感兴趣 我已经看过代码了 嵌入似乎已初始化 with tf variable scope scope or embedding attention deco
  • “gld/st_throughput”和“dram_read/write_throughput”指标之间有什么区别?

    在 CUDA 可视化分析器版本 5 中 我知道 gld st requested throughput 是应用程序请求的内存吞吐量 然而 当我试图找到硬件的实际吞吐量时 我很困惑 因为有两对似乎合格的指标 它们是 gld st throug
  • XAML 构建的本地 TFS 到 VSTS 迁移

    目前 我们在本地使用 TFS 2017 update 1 但我们必须在 VSTS 云平台上迁移 TFS 此外 我们还使用自定义构建模板在本地使用 TFS 构建服务器进行 XAML 构建 我们的问题是迁移后所有 XAML 构建定义是否都能正常
  • Android 启动器快捷方式

    我制作了一个简单的打卡 打卡时钟应用程序 我想向用户添加在主屏幕上创建快捷方式的选项 该快捷方式将切换应用程序的状态 超时 超时 但我根本不希望此快捷方式在屏幕上打开应用程序 这是我的 setupShortcut private void
  • 如何使用网格分割图像并保留透明度边界框

    我有一些 png 图像 我想将其分成几个部分 例如按网格或大小 但每个部分应具有与原始图像相同的边界框 透明度 Example 将图像分成两部分 原来的 200 89 Output 部分 1 png 200 89 第2部分 png 200
  • 执行 `EXECUTE IMMEDIATE ` Oracle 语句出现错误

    我是 Oracle 的新手 当我执行以下语句时 BEGIN EXECUTE IMMEDIATE SELECT FROM DUAL END 我得到错误为 命令中从第 2 行开始出错 立即开始执行 从双选择 结尾 错误报告 ORA 00911
  • ROOM迁移过程中如何处理索引信息

    CODE Entity tableName UserRepo indices Index value id unique true public class GitHubRepo PrimaryKey autoGenerate true p
  • 带有 Core Data 对象的动态 UITableView 高度

    过去几天我一直在试图解决一个谜团 即为什么我的批处理大小为 20 的 NSFetchedResultsController 总是在获取完成后立即错误 即加载到内存中 我的所有对象 从而导致请求需要约 20 秒 事实证明 这是因为在我的 he
  • SimaPro 项目中参数不确定的活动的蒙特卡罗 LCA 返回恒定值(无不确定性)

    我从 SimaPro 导入了一个项目 其中几乎每个活动都使用具有不确定性的参数 当我在 Brightway 中对其中任何一个运行蒙特卡洛 LCA 时 结果都是恒定的 就好像数量没有不确定性一样 代码片段显示 10 个步骤 但对于 2000
  • MediaCodec 创建输入表面

    我想使用 MediaCodec 将 Surface 编码为 H 264 使用 API 18 有一种方法可以通过调用 createInputSurface 然后在该表面上绘图来对表面中的内容进行编码 我在 createInputSurface
  • 增加雷达图中长轴标签的空间

    我想创建一个雷达图ggirahExtra ggRadar 问题是我的标签很长并且被剪掉了 我想我可以通过添加在标签和绘图之间创建更多空间margin margin 0 0 2 0 cm to element text in axis tex
  • 如何在C(Linux)中的while循环中准确地睡眠?

    在 C 代码 Linux 操作系统 中 我需要在 while 循环内准确地休眠 比如说 10000 微秒 1000 次 我尝试过usleep nanosleep select pselect和其他一些方法 但没有成功 一旦大约 50 次 它
  • 查找进程的完整路径

    我已经编写了 C 控制台应用程序 当我启动应用程序时 不使用cmd 我可以看到它列在任务管理器的进程列表中 现在我需要编写另一个应用程序 在其中我需要查找以前的应用程序是否正在运行 我知道应用程序名称和路径 所以我已将管理对象搜索器查询写入
  • 在 Flutter 中显示 iOS 的 PDF 内联文件

    我正在 flutter 中专门为 iOS 开发一个应用程序 现阶段 我需要向其中添加 PDF 文件 问题是 flutter 没有原生的方式来显示 PDF 文件 据我研究 由此tread https github com flutter fl
  • Text::平衡和多行 xml

    看来我有点失落了 我需要解析一个大的 大约 100 mb 且相当难看的 xml 文件 如果我使用parsefile 它返回错误 文档元素后的垃圾 但它会很乐意解析文件的较小元素 所以我决定将文件分解为元素并解析它们 由于不鼓励使用正则表达式
  • Java中接口作为方法参数

    前几天去面试 被问到了这样的问题 问 反转链表 给出以下代码 public class ReverseList interface NodeList int getItem NodeList nextNode void reverse No
  • nodemon 安装错误“没有可用于超时的有效版本”

    尝试在全新的节点项目中安装 nodemon 时出现此错误 我创建了一个名为 my project 的空白文件夹 然后 在其中 我执行了创建一个 package json 文件 npm init f 然后当尝试运行时 npm install
  • RemoteAuthentication 错误:OpenIdConnectAuthenticationHandler:message.State 为 null 或为空

    RemoteAuthentication 错误 OpenIdConnectAuthenticationHandler message State 为 null 或为空 即使成功获取代码 id token 和 token 后 我将 Razor
  • 使用大数据集在 Google Colab TPU 上训练 seq2seq 模型 - Keras

    我正在尝试使用 Google Colab TPU 上的 Keras 训练用于机器翻译的序列到序列模型 我有一个可以加载到内存中的数据集 但我必须对其进行预处理才能将其提供给模型 特别是 我需要将目标单词转换为一个热向量 并且在许多示例中 我