如何在 tf.data.Dataset.map 中使用预训练的 keras 模型进行推理?

2024-01-09

我有一个预先训练的模型,我正在尝试构建另一个模型,该模型将前一个模型的输出作为输入。我不想端到端地训练模型,只想使用第一个模型进行推理。第一个模型的训练使用tf.data.Dataset管道,我的第一个倾向是将模型集成为另一个dataset.map()在管道尾部进行操作,但我遇到了问题。我在此过程中遇到了 20 个不同的错误,每个错误都与前一个无关。批量归一化层似乎尤其是一个痛点。

下面是一个说明该问题的最小入门示例。它是用 R 编写的,但也欢迎用 python 提供答案。

我使用的是tensorflow-gpu版本1.13.1和kerastf.keras

library(reticulate)
library(tensorflow)
library(keras)
library(tfdatasets)
use_implementation("tensorflow")

model_weights_path <- 'model-weights.h5'

arr <- function(...) 
  np_array(array(seq_len(prod(unlist(c(...)))), unlist(c(...))), dtype = 'float32')

new_model <- function(load_weights = TRUE) {
  model <- keras_model_sequential() %>% 
    layer_conv_1d(5, 5, activation = 'relu', input_shape = shape(150, 10)) %>%
    layer_batch_normalization() %>%
    layer_flatten() %>%
    layer_dense(10, activation = 'softmax')
  if (load_weights)
    load_model_weights_hdf5(model, model_weights_path)
  freeze_weights(model)
  model
}

if(!file.exists(model_weights_path)) {
  model <- new_model(FALSE) 
  save_model_weights_hdf5(model, model_weights_path)
}

model <- new_model()

data <- arr(20, 150, 10)
ds <- tfdatasets::tensors_dataset(data) %>% 
  dataset_repeat()

ds2 <- ds %>% 
  dataset_map(function(x) {
    model(x)
  })

try(nb <- next_batch(ds2))

sess <- k_get_session()
it <- make_iterator_initializable(ds2)
sess$run(iterator_initializer(it))
nb <- it$get_next()

try(sess$run(nb))

sess$run(tf$initialize_all_variables())

try(sess$run(nb))

也许这不会直接回答你的问题,因为我不熟悉 R。但我最近使用构建了一个输入管道tf.data.

The generate_images函数映射使用.map并使用经过训练的生成器模型来生成新图像。

gen_model = tf.keras.models.load_model(artifact_dir+'/'+generators[-1], compile=False)

NOISE_DIM = 100

def generate_images(l):
    # generate images using the trained generator
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    images = gen_model(noise)

    # prepare the images for resize_and_preprocess function
    images = tf.squeeze(images, axis=-1)
    images = images*0.5+0.5
    images = tf.image.convert_image_dtype(images, dtype=tf.uint8)

    return images

genloader = tf.data.Dataset.from_tensors([1])

genloader = (
    genloader
    .map(generate_images, num_parallel_calls=AUTO)
    .map(resize_and_preprocess, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

关于批量归一化,它在训练和推理阶段的表现有所不同。在基于 Python 的 TensorFlow 中,需要通过training=False当使用具有批量归一化层的预训练模型时。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 tf.data.Dataset.map 中使用预训练的 keras 模型进行推理? 的相关文章

  • 我应该使用 Python 双端队列还是列表作为堆栈? [复制]

    这个问题在这里已经有答案了 我想要一个可以用作堆栈的 Python 对象 使用双端队列还是列表更好 元素数量较少还是数量较多有什么区别 您的情况可能会根据您的应用程序和具体用例而有所不同 但在一般情况下 列表非常适合堆栈 append is
  • 使用主题交换运行多个 Celery 任务

    我正在用 Celery 替换一些自制代码 但很难复制当前的行为 我期望的行为如下 创建新用户时 应向tasks与交换user created路由键 该消息应该触发两个 Celery 任务 即send user activate email
  • 在 R 中使用 Huggingface Transformer 模型

    我正在尝试在 R 中使用不同的 Huggingface 模型 这是通过 reticulate 导入 Transformer 包来实现的 谢谢 https rpubs com eR ic transfoRmers https rpubs co
  • python multiprocessing 设置生成进程等待

    是否可以生成一些进程并将生成进程设置为等待生成的进程完成 下面是我用过的一个例子 import multiprocessing import time import sys def daemon p multiprocessing curr
  • MongoEngine 查询具有以列表中指定的前缀开头的属性的对象的列表

    我需要在 Mongo 数据库中查询具有以列表中任何前缀开头的特定属性的元素 现在我有一段这样的代码 query mymodel terms term in query terms 并且这会匹配在列表 term 上有一个项目的对象 该列表中的
  • Tensorboard SyntaxError:语法无效

    当我尝试制作张量板时 出现语法错误 尽管开源代码我还是无法理解 我尝试搜索张量板的代码 但不清楚 即使我不擅长Python 我这样写路径C Users jh902 Documents logs因为我正在使用 Windows 10 但我不确定
  • 使用 dplyr::filter 的整洁方式是什么?

    使用下面的函数调用foo c b 输出以内联方式显示 正确的写作方式是什么df gt filter x gt x 我已经包含了一个使用的示例mutate以整洁的风格与之对比filter foo lt function variables x
  • Python 3:将字符串转换为变量[重复]

    这个问题在这里已经有答案了 我正在从 txt 文件读取文本 并且需要使用我读取的数据之一作为类实例的变量 class Sports def init self players 0 location name self players pla
  • Java 和 Python 可以在同一个应用程序中共存吗?

    我需要一个 Java 实例直接从 Python 实例数据存储中获取数据 我不知道这是否可能 数据存储是否透明 唯一 或者每个实例 如果它们确实可以共存 都有其单独的数据存储 总结一下 Java 应用程序如何从 Python 应用程序的数据存
  • 当字段是数字时怎么说...在 mongodb 中匹配?

    所以我的结果中有一个名为 城市 的字段 结果已损坏 有时它是一个实际名称 有时它是一个数字 以下代码显示所有记录 db zips aggregate project city substr city 0 1 sort city 1 我需要修
  • 尽管我已在 python ctypes 中设置了信号处理程序,但并未调用它

    我尝试过使用 sigaction 和 ctypes 设置信号处理程序 我知道它可以与python中的信号模块一起使用 但我想尝试学习 当我向该进程发送 SIGTERM 时 但它没有调用我设置的处理程序 只打印 终止 为什么它不调用处理程序
  • pandas - 包含时间序列数据的堆积条形图

    我正在尝试使用时间序列数据在 pandas 中创建堆积条形图 DATE TYPE VOL 0 2010 01 01 Heavy 932 612903 1 2010 01 01 Light 370 612903 2 2010 01 01 Me
  • Django REST Framework - CurrentUserDefault 使用

    我正在尝试使用CurrentUserDefault一个序列化器的类 user serializers HiddenField default serializers CurrentUserDefault 文档说 为了使用它 请求 必须作为
  • 如何使用 Python 3 检查目录是否包含文件

    我到处寻找这个答案但找不到 我正在尝试编写一个脚本来搜索特定的子文件夹 然后检查它是否包含任何文件 如果包含 则写出该文件夹的路径 我已经弄清楚了子文件夹搜索部分 但检查文件却难倒了我 我发现了有关如何检查文件夹是否为空的多个建议 并且我尝
  • ggplot2、R 中的单条形条形图

    我有以下数据和代码 gt ddf var1 var2 1 aa 73 2 bb 18 3 cc 9 gt gt dput ddf structure list var1 c aa bb cc var2 c 73L 18L 9L Names
  • python 中的“槽包装器”是什么?

    object dict 和其他地方的隐藏方法设置为这样的
  • 每当使用 import cv2 时 OpenCV 都会出错

    我在终端上使用 pip3 install opencv contrib python 安装了 cv2 并且它工作了 但是每当我尝试导入 cv2 或运行导入了 cv2 的 vscode 文件时 在 python IDLE 上它都会说 Trac
  • 重新分配唯一值 - pandas DataFrame

    我在尝试着assign unique值在pandas df给特定的个人 For the df below Area and Place 会一起弥补unique不同的价值观jobs 这些值将分配给个人 总体目标是使用尽可能少的个人 诀窍在于这
  • 如何将 ggrough 图表另存为 .png

    说我正在使用R包裹ggrough https xvrdm github io ggrough https xvrdm github io ggrough 我有这个代码 取自该网页 library ggplot2 library ggroug
  • 如何将Python3设置为Mac上的默认Python版本?

    有没有办法将 Python 3 8 3 设置为 macOS Catalina 版本 10 15 2 上的默认 Python 版本 我已经完成的步骤 看看它安装在哪里 ls l usr local bin python 我得到的输出是这样的

随机推荐