将 SageMaker 管道模式与 tfrecords 的 s3 目录结合使用

2024-02-14

我打电话给sagemaker.tensorflow.TensorFlow.fit()当我使用时无限期挂起,没有错误消息Pipe代替File as the input_mode。我相应地替换了TensorFlowDataset with Pipemodedataset。此次培训在File模式成功完成。

我的数据由两个 s3 存储桶组成,每个存储桶中有多个 tfrecord 文件。尽管广泛查看了文档,但我对如何使用并没有信心Pipemodedataset在这种情况下 - 具体来说,如何设置channel.

这是我的 Sagemaker 笔记本设置:

hyperparameters = {
    "batch-size": 1,
    "pipe_mode": 1,
}

estimator_config = {
    "entry_point": "tensorflow_train.py",
    "source_dir": "source",
    "framework_version": "2.3",
    "py_version": "py37",
    "instance_type": "ml.p3.2xlarge",
    "instance_count": 1,
    "role": sagemaker.get_execution_role(),
    "hyperparameters": hyperparameters,
    "output_path": f"s3://{bucket_name}",
    "input_mode": "Pipe",
}

tf_estimator = TensorFlow(**estimator_config)

s3_data_channels = {
    "training": f"s3://{bucket_name}/data/training",
    "validation": f"s3://{bucket_name}/data/validation",
}

tf_estimator.fit(s3_data_channels)

如果我要跑aws s3 ls on the s3_data_channels,我会得到 tfrecord 文件的列表。

这是我设置数据集的方式(请参阅 if / else 语句,具体取决于是否pipe_mode被选中:

import tensorflow as tf

if __name__ == "__main__":

    arg_parser = argparse.ArgumentParser()
    ...
    arg_parser.add_argument("--pipe_mode", type=int, default=0)

    arg_parser.add_argument("--train_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAINING"))
    arg_parser.add_argument(
        "--validation_dir", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION")
    )
    arg_parser.add_argument("--model_dir", type=str)
    args, _ = arg_parser.parse_known_args()

    AUTOTUNE = tf.data.experimental.AUTOTUNE

    if args.pipe_mode == 1:
        from sagemaker_tensorflow import PipeModeDataset
        train_ds = PipeModeDataset(channel="training", record_format='TFRecord')
        val_ds = PipeModeDataset(channel="validation", record_format='TFRecord')

    else:
        train_files = tf.data.Dataset.list_files(args.train_dir + '/*tfrecord')
        val_files = tf.data.Dataset.list_files(args.validation_dir + '/*tfrecord')
        train_ds = tf.data.TFRecordDataset(filenames=train_files, num_parallel_reads=AUTOTUNE)
        val_ds = tf.data.TFRecordDataset(filenames=val_files, num_parallel_reads=AUTOTUNE)

    train_ds = (
        train_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
        .batch(args.batch_size)
        .prefetch(AUTOTUNE)
    )

    val_ds = (
        val_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
        .batch(args.batch_size)
        .prefetch(AUTOTUNE)
    )
    ...

我遇到了同样的问题,使用管道模式时 model.fit() 无限期地卡住了。经过一些研究并尝试了许多改变,它通过定义解决了每个时期的步骤拟合模型时。

我想当使用文件模式时它已经知道每个时期会有多少步,但是使用管道模式你必须手动指定它

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将 SageMaker 管道模式与 tfrecords 的 s3 目录结合使用 的相关文章

  • 如何在对象检测 API Tensorflow 中仅检测人体

    我在用tensorflow对象检测 API 用于检测对象 它在我的 Windows 系统中运行良好 我如何对其进行更改以仅检测提到的对象 例如 我只想检测人类而不是所有对象 根据此中的第 1 条评论answer https stackove
  • 如何将 std::vector 转换为张量而不在 C++ 中的张量流中进行复制?

    在c 中 多维矩阵存储在std vector
  • 如何在nodejs(tensorflow.js)中训练模型?

    我想做一个图像分类器 但我不会python Tensorflow js 使用我熟悉的 javascript 可以用它来训练模型吗 训练步骤是什么 坦白说 我不知道从哪里开始 我唯一想到的是如何加载 mobilenet 它显然是一组预先训练的
  • ValueError:形状(无,50)和(无,1)在 Tensorflow 和 Colab 中不兼容

    我正在使用 LSTM 训练 Tensorflow 模型以进行预测维护 对于每个实例 我创建一个矩阵 50 4 其中 50 是历史序列的长度 4 是每个记录的特征数量 因此为了训练模型 我使用例如 55048 50 4 张量和 55048 1
  • 修改Keras中的层权重

    我正在尝试修改 Keras 中某个层的输出 我有一个编码器 它将时间序列转换为潜在空间 之后 对于每个压缩的时间序列 我想向时间序列添加一些数字 例如我有 input d Input 100 h1 d Reshape 100 1 input
  • 导入tensorflow时,出现以下错误:没有名为“numpy.core._multiarray_umath”的模块

    我已经安装了 Ancaconda3 和 Tensorflow 当我尝试在 python shell 中导入 Tensorflow 时 收到以下错误 ModuleNotFoundError 没有名为 numpy core multiarray
  • 用于测试张量流安装的速度基准

    我怀疑我的 GPU 机器上是否正确配置了张量流 因为在我精美的 GPU 机器上训练一个简单的线性回归模型 批量大小 32 1500 个输入特征 150 个输出变量 的每次迭代速度比在笔记本电脑上慢 100 倍 我使用的是 Titan X 配
  • tf.gfile 在 TensorFlow 中起什么作用?

    我见过人们使用以下几个函数tf gfile例如tf gfile GFile or tf gfile Exists 我有一个想法tf gfile处理文件 但是 我无法找到官方文档来了解它还提供了什么 如果你能帮我的话那就太好了 对于登陆这里的
  • 从 Keras 检查点加载

    我正在 Keras 中训练一个模型 我使用以下代码保存了所有内容 filepath project model hdh5 checkpoint ModelCheckpoint project model hdf5 monitor loss
  • conv1D 中形状的尺寸

    我尝试过构建一个只有一层的 CNN 但遇到了一些问题 事实上 编译器告诉我 ValueError 检查模型输入时出错 预期的 conv1d 1 input 具有 3 个维度 但得到形状为 569 30 的数组 这是代码 import num
  • Tensorflow新Op CUDA内核内存管理

    我已经使用 GPU CUDA 内核在 Tensorflow 中实现了一个相当复杂的新 Op 该操作需要大量动态内存分配 这些变量不是张量 并且在操作完成后被释放 更具体地说 它涉及使用哈希表 现在我正在使用cudaMalloc and cu
  • 将 Pytorch 模型 .pth 转换为 onnx 模型

    我有一个预训练的模型 其格式为 pth 扩展名 我想将其转换为 Tensorflow protobuf 但我没有找到任何方法来做到这一点 我见过 onnx 可以将模型从 pytorch 转换为 onnx 然后从 onnx 转换为 Tenso
  • Tensorflow seq2seq 获取序列隐藏状态

    我不久前才开始研究tensorflow 我正在研究 seq2seq 模型 并以某种方式让教程起作用 但我一直坚持获取每个句子的状态 据我了解 seq2seq 模型采用输入序列并通过 RNN 为序列生成隐藏状态 随后 模型使用序列的隐藏状态来
  • 有没有办法在bigquery中使用kmeans、tensorflow保存的模型?

    我知道这有点愚蠢 因为 BigQueryML 现在为 Kmeans 提供了良好的初始化 尽管如此 我还是需要在张量流中训练一个模型 然后将其传递给 BigQuery 进行预测 我保存了模型 一切正常 直到我尝试将其上传到 bigquery
  • 需要 TensorFlow 依赖项。如何在 Windows 上运行 TensorFlow

    我有兴趣让 TensorFlow 在 Windows 上运行 但目前我意识到这是不可能的 因为某些依赖项无法在 Windows 上使用 例如巴泽尔 之所以出现这种需求 是因为据我目前了解 从 TensorFlow 访问 GPU 的唯一方法是
  • 交换keras中的张量轴

    我想将图像批次的张量轴从 batch size row col ch 交换为 批次大小 通道 行 列 在 numpy 中 这可以通过以下方式完成 X batch np moveaxis X batch 3 1 我该如何在 Keras 中做到
  • 在tensorflow.js中对张量进行分区、屏蔽或过滤

    我有 2 个相同长度的张量 data and groupIds 我想分开data通过相应的值分成几组groupId 例如 const data tf tensor 1 2 3 4 5 const groupIds tf tensor 0 1
  • 如何解释tf.map_fn的结果?

    看代码 import tensorflow as tf import numpy as np elems tf ones 1 2 3 dtype tf int64 alternates tf map fn lambda x x x x el
  • GPU 上的张量流:尽管 cuda 的 deviceQuery 返回“PASS”结果,但没有已知设备

    注 这个问题最初是在github上问的 https github com tensorflow tensorflow issues 7648 issuecomment 280866214 但被要求改为在这里 我在 GPU 上运行 Tenso
  • Keras CNN 回归模型损失低,准确度为 0

    我在 keras 中遇到这个 NN 回归模型的问题 我正在研究一个汽车数据集 以根据 13 个维度预测价格 简而言之 我已将其读取为 pandas 数据帧 将数值转换为浮点数 缩放值 然后对分类值使用 one hot 编码 这创建了很多新列

随机推荐

  • 似乎无法解析 KeyedByTypeCollection?

    我正在使用 NET 可移植来创建一个库 并且在尝试创建 KeyedByTypeCollection 的实例时遇到了一些问题 我检查了我的参考资料 NET Portable Subset System Collections Generic
  • React Native - SQLite 找不到预填充的数据库文件

    我正在尝试使用https github com andpor react native sqlite storage https github com andpor react native sqlite storage对于 SQLite
  • 我们如何在没有第三个变量和算术运算符的情况下交换两个数字?

    我们如何在没有第三个变量和算术运算符的情况下交换两个数字 XOR算不算算术运算符 如果没有 那么 X X XOR Y Y X XOR Y X X XOR Y 将此伪代码转换为可编译的 Java 代码作为练习留给读者 关于 java 标签
  • docker-compose如何引用其他目录中的文件

    有这个 dockerfile FROM python 3 8 3 alpine ENV MICRO SERVICE home app microservice RUN addgroup S APP USER adduser S APP US
  • CordovaWebView 与 android 中的 onBackPressed 方法混淆

    正如标题所说CordovaWebView and onBackPressed在 android 中组合起来会产生奇怪的结果 我有混合应用程序 我的主要活动有DrawerLayout and CordovaWebView 我的 onBackP
  • android ndk数据保存/加载

    我正在致力于将 PC OpenGL 应用程序移植到 Android 上 我选择使用 NDK android native app glue 框架 据我了解 它允许我继续使用 C 甚至不编写任何 JAVA 代码行 听起来很有希望 对我来说第一
  • 将加密的 csv 导入 Python 3

    因此 我计划使用 Jupyter Notebook Python 3 进行一些数据分析 出于协作原因 我想将数据存储在 github 存储库上 但数据集很敏感 因此 我想将数据 当前为 csv 作为加密文件存储在存储库上 然后在运行时解密
  • .NET ORM、不可变值对象、结构、默认构造函数和只读属性

    我刚刚开始使用 NET ORM 甚至还没有在 Entity Framework 和 NHibernate 之间做出决定 但在这两种情况下 我都遇到了一个问题 因为他们似乎希望我以各种方式损害域模型的完整性 特别是在 C 对象设计的更精细的方
  • (numpy) __array_wrap__ 有什么作用?

    我第一次深入 SciPy LinAlg 模块 我看到了这个函数 def makearray a new asarray a wrap getattr a array prepare new array wrap return new wra
  • 'quietly = TRUE' 何时在 require() 函数中真正起作用?

    我正在尝试编写一组函数来检查丢失的 R 软件包 并在必要时安装它们 StackOverflow 上有一些很好的代码可以做到这一点 从这里开始 https stackoverflow com questions 4090169 elegant
  • 在更改视图的可见性时应用动画

    我的应用程序中有一个 Horizo ntalScrollView 并且我经常使用它的可见性 可见和消失 所以我想要的是 我可以应用某种动画或其他东西 使其开始以滑动的方式变得可见和不可见 而不是突然使其可见和不可见吗 任何帮助或建议将不胜感
  • 本地主机上的 Django/Celery 多个队列 - 路由不起作用

    我跟着芹菜docs http celery readthedocs org en latest userguide routing html manual routing在我的开发机器上定义 2 个队列 我的芹菜设置 CELERY ALWA
  • 从 Eclipse 中删除插件的正确方法

    上次 我遇到了从 Eclipse 中删除插件的问题 症状 1 如果删除通过已安装菜单 无法正确重新安装并且有多个视角 例如对于 SQL 资源管理器 在Open Perspective menu 2 如果通过文件系统删除 手动从plugins
  • 人员 API 谷歌配额限制

    我正在研究 People API 这仅适用于 google 用户 有人知道吗 我一天 分钟可以免费询问多少次 一般配额限制是多少 超过门槛需要花费多少钱 Thanks 有两种不同的 People API 您可以在云控制台中查看两者的配额 G
  • 具有基本身份验证的 Webclient / HttpWebRequest 返回 404 未找到有效 URL

    编辑 我想回来指出问题根本不在我这边 而是与另一家公司的代码有关 我正在尝试使用基本身份验证来打开页面 我不断收到 404 页面未找到错误 我可以将我的网址复制并粘贴到浏览器中 它工作正常 如果我尚未登录他们的网站 它会弹出一个凭据框 否则
  • ASP.NET Core 默认调试启动 URL

    使用 ASP NET Core Web API 模板时 默认调试启动 URL 以某种方式设置为api values 此默认配置在哪里以及如何更改它 我能找到的有关此启动 URL 声明位置的文档非常少 这个里面有简短的提及博客文章 https
  • vuelidate 异步验证器 - 如何去抖?

    因此 我的电子邮件 用户表单元素上的异步验证器存在问题 每次输入字母时 它都会检查有效性 如果电子邮件有 30 个字符 那么就超过 30 个电话 有人知道消除 vuelidate 自定义验证器的最佳方法吗 当我尝试使用 debounce 时
  • 传统 For 循环与增强型 For 循环 [重复]

    这个问题在这里已经有答案了 这段代码 import java util import java io class TestClass public static void main String args throws Exception
  • Visual Basic .NET 中的 UInt32 数据类型是什么?

    是什么UInt32VB NET 中的数据类型 有人可以告诉我它的位长度和之间的区别吗UInt32 and Int32 它是整数还是浮点数 它是一个无符号 32 位整数 U 表示无符号 Int 表示整数 32 换 32 或者你可以看看文档 h
  • 将 SageMaker 管道模式与 tfrecords 的 s3 目录结合使用

    我打电话给sagemaker tensorflow TensorFlow fit 当我使用时无限期挂起 没有错误消息Pipe代替File as the input mode 我相应地替换了TensorFlowDataset with Pip