如何在张量流中将“张量”转换为“numpy”数组?

2023-11-24

我正在尝试在 tesnorflow2.0 版本中将张量转换为 numpy。由于 tf2.0 已启用急切执行,因此它应该默认工作并且在正常运行时也工作。当我在 tf.data.Dataset API 中执行代码时,它给出了一个错误

“属性错误:‘张量’对象没有属性‘numpy’”

我在张量流变量之后尝试过“.numpy()”,而对于“.eval()”我无法获得默认会话。

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
# tf.executing_eagerly()
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from model.utils import  get_noise
import cv2


def random_noise(input_image):
  img_out = get_noise(input_image)
  return img_out


def load_denoising(image_file):
  image = tf.io.read_file(image_file)
  image = tf.image.decode_png(image)
  real_image = image
  input_image = random_noise(image.numpy())
  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)
  return input_image, real_image


def load_image_train(image_file):
  input_image, real_image = load_denoising(image_file)
  return input_image, real_image

这很好用

inp, re = load_denoising('/data/images/train/18.png')
# Check for correct run
plt.figure()
plt.imshow(inp)
print(re.shape,"  ", inp.shape)

这会产生提到的错误

train_dataset = tf.data.Dataset.list_files('/data/images/train/*.png')
train_dataset = train_dataset.map(load_image_train,num_parallel_calls=tf.data.experimental.AUTOTUNE)

注意:random_noise有cv2和sklearn函数


你不能使用.numpy张量上的方法,如果该张量将用于tf.data.Dataset.map call.

The tf.data.Dataset引擎盖下的对象通过创建静态图来工作:这意味着您不能使用.numpy()因为tf.Tensor处于静态图上下文中的对象没有此属性。

因此,该行input_image = random_noise(image.numpy())应该input_image = random_noise(image).

但代码可能会再次失败,因为random_noise calls get_noise来自model.utils包裹。 如果get_noise函数是使用 Tensorflow 编写的,那么一切都会正常。否则,就行不通。

解决方案?仅使用 Tensorflow 原语编写代码。

例如,如果你的函数get_noise只是用输入图像的表面创建随机噪声,您可以将其定义为:

def get_noise(image):
    return tf.random.normal(shape=tf.shape(image))

仅使用 Tensorflow 原语,它就可以工作。

希望此概述有所帮助!

P.S:您可能有兴趣查看文章“分析 tf.function 以发现 AutoGraph 的优点和微妙之处” - 它们涵盖了这一方面(也许第 3 部分是与您的场景相关的部分):part 1 part 2 part 3

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在张量流中将“张量”转换为“numpy”数组? 的相关文章

随机推荐