Tf.Print() 不打印张量的形状?

2023-12-24

我使用 Tensorflow 编写了一个简单的分类程序并获取输出,但我尝试打印模型参数、特征和偏差的张量形状。 函数定义:

import tensorflow as tf, numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def get_weights(n_features, n_labels):
#    Return weights
    return tf.Variable( tf.truncated_normal((n_features, n_labels)) )

def get_biases(n_labels):
    # Return biases
    return tf.Variable( tf.zeros(n_labels))

def linear(input, w, b):
    #  Linear Function (xW + b)
#     return np.dot(input,w) + b 
    return tf.add(tf.matmul(input,w), b)

def mnist_features_labels(n_labels):
    """Gets the first <n> labels from the MNIST dataset
    """
    mnist_features = []
    mnist_labels = []
    mnist = input_data.read_data_sets('dataset/mnist', one_hot=True)

    # In order to make quizzes run faster, we're only looking at 10000 images
    for mnist_feature, mnist_label in zip(*mnist.train.next_batch(10000)):

        # Add features and labels if it's for the first <n>th labels
        if mnist_label[:n_labels].any():
            mnist_features.append(mnist_feature)
            mnist_labels.append(mnist_label[:n_labels])

    return mnist_features, mnist_labels

图形创建:

# Number of features (28*28 image is 784 features)
n_features = 784
# Number of labels
n_labels = 3

# Features and Labels
features = tf.placeholder(tf.float32)
labels = tf.placeholder(tf.float32)

# Weights and Biases
w = get_weights(n_features, n_labels)
b = get_biases(n_labels)

# Linear Function xW + b
logits = linear(features, w, b)

# Training data
train_features, train_labels = mnist_features_labels(n_labels)

print("Total {0} data points of Training Data, each having {1} features \n \
      Total {2} number of labels,each having 1-hot encoding {3}".format(len(train_features),len(train_features[0]),\
                                                                     len(train_labels),train_labels[0]
                                                                      )
     )

# global variables initialiser
init= tf.global_variables_initializer()

with tf.Session() as session:

    session.run(init)

问题就在这里:

#            shapes =tf.Print ( tf.shape(features), [tf.shape(features),
#                                                     tf.shape(labels),
#                                                     tf.shape(w),
#                                                     tf.shape(b),
#                                                     tf.shape(logits)
#                                                     ], message= "The shapes are:" )
#         print("Verify shapes",shapes)
    logits = tf.Print(logits, [tf.shape(features),
                           tf.shape(labels),
                           tf.shape(w),
                           tf.shape(b),
                           tf.shape(logits)],
                  message= "The shapes are:")
    print(logits)

我在看here https://stackoverflow.com/questions/33633370/how-to-print-the-value-of-a-tensor-object-in-tensorflow/36296783#36296783,但没发现有多大用处。

    # Softmax
    prediction = tf.nn.softmax(logits)

    # Cross entropy
    # This quantifies how far off the predictions were.
    # You'll learn more about this in future lessons.
    cross_entropy = -tf.reduce_sum(labels * tf.log(prediction), reduction_indices=1)

    # Training loss
    # You'll learn more about this in future lessons.
    loss = tf.reduce_mean(cross_entropy)

    # Rate at which the weights are changed
    # You'll learn more about this in future lessons.
    learning_rate = 0.08

    # Gradient Descent
    # This is the method used to train the model
    # You'll learn more about this in future lessons.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    # Run optimizer and get loss
    _, l = session.run(
        [optimizer, loss],
        feed_dict={features: train_features, labels: train_labels})

# Print loss
print('Loss: {}'.format(l))

我得到的输出是:

Extracting dataset/mnist/train-images-idx3-ubyte.gz
Extracting dataset/mnist/train-labels-idx1-ubyte.gz
Extracting dataset/mnist/t10k-images-idx3-ubyte.gz
Extracting dataset/mnist/t10k-labels-idx1-ubyte.gz
Total 3118 data points of Training Data, each having 784 features 
       Total 3118 number of labels,each having 1-hot encoding [0. 1. 0.]
Tensor("Print_22:0", shape=(?, 3), dtype=float32)
Loss: 5.339271068572998

谁能帮助我理解,为什么我看不到张量的形状?


这不是你使用的方式tf.Print。它是一个本身不执行任何操作(仅返回输入)的操作,但会打印所请求的张量作为副作用。你应该做类似的事情

logits = tf.Print(logits, [tf.shape(features),
                           tf.shape(labels),
                           tf.shape(w),
                           tf.shape(b),
                           tf.shape(logits)],
                  message= "The shapes are:")

现在,每当logits被评估(因为它将用于计算损失/梯度),将打印形状信息。

您现在所做的只是打印返回值tf.Printop,这只是它的输入(tf.shape(features)).

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tf.Print() 不打印张量的形状? 的相关文章

随机推荐

  • 如何使用java httpclient实现大文件的HTTP Post分块上传?

    我有一个巨大的文件要上传 另一端的服务器确实支持分块上传 有没有具体的例子说明如何做到这一点 或者还有其他库可以做到这一点 使用 HttpClient 4 来自 Apache HttpPost post new HttpPost url M
  • 如何通过Java编码调用浏览器?

    我想通过 Java 接口调用浏览器 例如 Internet Explorer Firefox Google Chrome Opera 我还需要将一些 Web 链接传递给这个 Java 进程 如何实现这一目标 您可以使用桌面API http
  • WebView getScrollY() 始终返回 0

    我尝试使用 webview 的滚动位置来确定 SwipeRefreshLayout 是否应该能够刷新 除了某些网站 例如https jobs lever co memebox https jobs lever co memebox getS
  • 所以...NoSQL 的事情

    我一直在研究 MongoDB 并且着迷 看来 尽管我不得不怀疑 作为以稍微不同的方式组织数据库的交换 我获得了与免费的 CPU 和 RAM 一样多的性能 它看起来优雅 灵活 但我不会像使用 Rails 那样以快速换取它 那么有什么问题呢 关
  • 为什么这个简单的 Spark 程序不利用多核?

    因此 我在 16 核多核系统上运行这个简单的程序 我运行它 通过发布以下内容 spark submit master local pi py 该程序的代码如下 pi py from pyspark import SparkContext i
  • WSO2 ESB 覆盖 ContentType 属性

    我正在开发 WSO2 ESB 代理服务 该服务涉及通过 ESB 上的 SOAP 端点公开内部 RESTful 服务 我的 RESTful 服务需要 Content type application rdf xml 我尝试使用文档中提到的所有
  • 如果函数创建并返回一个对象,它是否应该在自动释放池中

    我对 Objective C 还是很陌生 据我所知 任何我没有从 alloc new copy 或 mutableCopy 获得的对象都应该被假定在自动释放池中 我认为这也意味着 如果我创建一个创建并返回对象的新实例的函数 我应该在返回之前
  • 在lstm语言模型中使用预训练的word2vec?

    我用tensorflow来训练LSTM语言模型 代码来自here https github com tensorflow models blob master tutorials rnn ptb ptb word lm py 根据文章her
  • 将多个 json 数据添加到 panda 数据帧

    我正在使用 api 获取 3 个 json 数据 我想将这些数据添加到 1 个 panda 数据帧 这是我的代码 我传入的书籍中包含书籍 id 作为 x 这 3 个 id 返回了 3 个不同的 json 对象 其中包含所有书籍信息 for
  • WPF DataGrid CanUserAddRows = True

    我似乎在向 a 添加行时遇到问题DataGrid通过界面本身 这是用户界面的屏幕截图 正如您所看到的 在数据 库中找到了 0 行 因此没有任何内容显示在数据库中DataGrid在右侧 但我喜欢那里有一个空行 用于手动添加行 这DataGri
  • 获取 iOS Swift 中的顶级 ViewController

    我想实现一个单独的 ErrorHandler 类 它显示某些事件的错误消息 此类的行为应该从不同的其他类中调用 当发生错误时 会有一个UIAlertView作为输出 此 AlertView 的显示应始终位于顶部 因此 无论错误从哪里抛出 最
  • 如何将 RPC 与 Volttron 结合使用

    我想在我的 volttron 应用程序中使用 RPC 调用 但我无法让任何调用正常工作 所有调用都会失败 并出现 没有到主机的路由 错误
  • 为什么 Django 开发服务器会挂在这个管理工具 JS 文件上?

    使用 Django 管理工具时 它会定期挂起并停止响应请求 直到重新启动为止 每当它挂起时 日志中的最后一行是 获取 admin jsi18n HTTP 1 1 200 2158 挂起似乎发生在 POST 之后 例如查看添加对象的结果时 据
  • flowtype如何用可选字段注释联合

    如何在流程中实现以下目标 export type Response err string data Array data Array 我想表达一种类型 它返回错误和可选数据 或者不返回错误字段 如果没有 但是 我用它作为 return er
  • “npx tsc --version”报告虚拟机内不同的 TypeScript 版本

    我希望能够跑步npx tsc在我的主机 来宾操作系统上的项目中 但客人正在使用不同的 旧的 版本tsc 我不确定它是从哪里来的 我的设置 主机操作系统 Windows 10 来宾操作系统 Debian 9 我正在使用 VirtualBox
  • 使用 IDisposable 取消订阅事件

    我有一个处理来自 WinForms 控件的事件的类 根据用户正在执行的操作 我引用该类的一个实例并创建一个新实例来处理同一事件 我需要首先从事件中取消订阅旧实例 很简单 如果可能的话 我想以非专有的方式执行此操作 这似乎是 IDisposa
  • JQuery 自动完成:如何强制从列表中选择(键盘)

    我正在使用 JQuery UI 自动完成 一切都按预期进行 但是当我使用键盘上的向上 向下键循环时 我注意到文本框按预期填充了列表中的项目 但是当我到达列表末尾并再次按下向下箭头时这时 我输入的原始术语就会出现 这基本上允许用户提交该条目
  • git p4克隆/同步:如何添加新的P4路径

    我创建了一个 P4 客户端视图规范 并用它制作了一个 git p4 克隆 并定期同步 P4 的更改 效果非常好 有一天 我想向我克隆的 Git 存储库添加另一个 P4 路径 但它卡住了 即使我添加了 git p4 克隆使用的客户端视图规范的
  • mysql 自动终止查询

    mysql 是否有可能自动终止耗时超过 20 秒的查询 我猜您正在寻找名为 mk kill 的 maatkit 实用程序 它将杀死符合某些条件的查询
  • Tf.Print() 不打印张量的形状?

    我使用 Tensorflow 编写了一个简单的分类程序并获取输出 但我尝试打印模型参数 特征和偏差的张量形状 函数定义 import tensorflow as tf numpy as np from tensorflow examples