在循环中评估 Tensorflow 操作非常慢

2023-12-30

我试图通过编码一些简单的问题来学习张量流:我试图使用直接采样蒙特卡罗方法找到 pi 的值。

运行时间比我想象的要长得多for loop去做这个。我看过其他关于类似事情的帖子,并且我尝试遵循解决方案,但我认为我仍然一定做错了什么。

下面附上我的代码:

import tensorflow as tf
import numpy as np
import time

n_trials = 50000

tf.reset_default_graph()


x = tf.random_uniform(shape=(), name='x')
y = tf.random_uniform(shape=(), name='y')
r = tf.sqrt(x**2 + y**2)

hit = tf.Variable(0, name='hit')

# perform the monte carlo step
is_inside = tf.cast(tf.less(r, 1), tf.int32)
hit_op = hit.assign_add(is_inside) 

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Make sure no new nodes are added to the graph
    sess.graph.finalize()

    start = time.time()   

    # Run monte carlo trials  -- This is very slow
    for _ in range(n_trials):
        sess.run(hit_op)

    hits = hit.eval()
    print("Pi is {}".format(4*hits/n_trials))
    print("Tensorflow operation took {:.2f} s".format((time.time()-start)))

>>> Pi is 3.15208
>>> Tensorflow operation took 8.98 s

相比之下,做一个for loopnumpy 中的类型解决方案速度快一个数量级

start = time.time()   
hits = [ 1 if np.sqrt(np.sum(np.square(np.random.uniform(size=2)))) < 1 else 0 for _ in range(n_trials) ]
a = 0
for hit in hits:
    a+=hit
print("numpy operation took {:.2f} s".format((time.time()-start)))
print("Pi is {}".format(4*a/n_trials))

>>> Pi is 3.14032
>>> numpy operation took 0.75 s

下面附上的图表显示了不同次数的试验的总体执行时间的差异。

请注意:我的问题不是关于“如何最快地执行此任务”,我认识到有更多有效的方法来计算 Pi。我仅使用它作为基准测试工具来根据我熟悉的东西(numpy)检查张量流的性能。


速度慢与 Python 和 Tensorflow 之间的一些通信开销有关sess.run,它在循环内执行多次。我建议使用tf.while_loop在 Tensorflow 中执行计算。这将是一个更好的比较numpy.

import tensorflow as tf
import numpy as np
import time

n_trials = 50000

tf.reset_default_graph()

hit = tf.Variable(0, name='hit')

def body(ctr):
    x = tf.random_uniform(shape=[2], name='x')
    r = tf.sqrt(tf.reduce_sum(tf.square(x))
    is_inside = tf.cond(tf.less(r,1), lambda: tf.constant(1), lambda: tf.constant(0))
    hit_op = hit.assign_add(is_inside)
    with tf.control_dependencies([hit_op]):
        return ctr + 1

def condition(ctr):
    return ctr < n_trials

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    result = tf.while_loop(condition, body, [tf.constant(0)])

    start = time.time()
    sess.run(result)

    hits = hit.eval()
    print("Pi is {}".format(4.*hits/n_trials))
    print("Tensorflow operation took {:.2f} s".format((time.time()-start)))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在循环中评估 Tensorflow 操作非常慢 的相关文章

  • 使用 ElementTree 时出现未定义实体错误

    我有一组 XML 文件 需要读取它们并将其格式化为单个 CSV 文件 为了读取 XML 文件 我使用了提到的解决方案here https stackoverflow com questions 5530857 parse xml file
  • 将 SSH 密钥文件与 Fabric 结合使用

    如何配置结构以使用 SSH 密钥文件连接到远程主机 例如 Amazon EC2 实例 由于某种原因 找到一个带有 SSH 密钥文件使用示例的简单 fabfile 并不容易 我写了一个博客文章 http blog y3xz com post
  • Python 和图形数据库。使用 java lib 包装器还是 REST api?

    我想问你在Python中使用图数据库 Neo4j 的最佳方法 你觉得我应该使用 neo4j python embedded neo4j python 嵌入式 http docs neo4j org chunked milestone pyt
  • Keras 模型无法预测是否在线程中调用

    我尝试在线程应用程序中使用 keras 和可用模型 VGG16 执行预测 但是 如果我在主线程中调用预测 一切都会正常 但是如果我在线程函数内部进行预测 无论我使用threading multiprocessing 它只是在预测过程中停止
  • 使用 QuantLib 计算带有下限的 FloatingRateBond 的现金流量

    对 QuantLib 非常陌生 所以猜测这是一个菜鸟错误 很高兴了解这个强大的库 所以感谢作者和贡献者 如果没有下限参数 我可以在没有定价器的情况下为 FloatingRateBond 生成现金流量金额 所以我不明白为什么包含下限参数需要定
  • 在 MAC OS X 10.9 上安装 NLTK 确实很困难

    我是 Python Mac OS 新手 我正在寻找 NLTK 教科书 但我在安装它时遇到了一些问题 我一直在寻找解决方案 但不幸的是 所有解决方案似乎都不适合我 或者我误解了如何使用它们 我遇到的基本问题是 尽管按照说明进行操作 NLTK
  • 如何在Python中增加文件名

    我正在尝试保存大量需要分成不同文件的数据 如下所示 数据 1 dat 数据 2 dat 数据 3 dat 数据 4 dat 我如何在Python中实现这个 from itertools import count filename data
  • 从 Robot Framework 访问 python 类的变量

    我有一个 python 文件 例如 Animals py 在里面我定义了 3 个不同的类 如下所示 Animals py class Animal listAnimal dog cat lt def init self Animal con
  • 如何查看Databricks中的所有数据库和表

    我想列出 Azure Databricks 中每个数据库中的所有表 所以我希望输出看起来像这样 Database Table name Database1 Table 1 Database1 Table 2 Database1 Table
  • 通过 Scrapy 抓取 Google Analytics

    我一直在尝试使用 Scrapy 从 Google Analytics 获取一些数据 尽管我是一个完全的 Python 新手 但我已经取得了一些进展 我现在可以通过 Scrapy 登录 Google Analytics 但我需要发出 AJAX
  • Python 中的 Lanczos 插值与 2D 图像

    我尝试重新缩放 2D 图像 灰度 图像大小为 256x256 所需输出为 224x224 像素值范围从 0 到 1300 我尝试了两种使用 Lanczos 插值来重新调整它们的方法 首先使用PIL图像 import numpy as np
  • OpenCV Python cv2.mixChannels()

    我试图将其从 C 转换为 Python 但它给出了不同的色调结果 In C Transform it to HSV cvtColor src hsv CV BGR2HSV Use only the Hue value hue create
  • 如何在flask中使用g.user全局

    据我了解 Flask 中的 g 变量 它应该为我提供一个全局位置来存储数据 例如登录后保存当前用户 它是否正确 我希望我的导航在登录后在整个网站上显示我的用户名 我的观点包含 from Flask import g among other
  • Django:按钮链接

    我是一名 Django 新手用户 尝试创建一个按钮 单击该按钮会链接到我网站中的另一个页面 我尝试了一些不同的例子 但似乎没有一个对我有用 举个例子 为什么这不起作用
  • 使用 matplotlib 绘制时间序列数据并仅在年初显示年份

    rcParams date autoformatter month b n Y 我正在使用 matpltolib 来绘制时间序列 如果我按上述方式设置 rcParams 则生成的图会在每个刻度处标记月份名称和年份 我怎样才能将其设置为仅在每
  • Python - StatsModels、OLS 置信区间

    在 Statsmodels 中 我可以使用以下方法拟合我的模型 import statsmodels api as sm X np array 22000 13400 47600 7400 12000 32000 28000 31000 6
  • Flask 会话变量

    我正在用 Flask 编写一个小型网络应用程序 当两个用户 在同一网络下 尝试使用应用程序时 我遇到会话变量问题 这是代码 import os from flask import Flask request render template
  • 如何使用 Ansible playbook 中的 service_facts 模块检查服务是否存在且未安装在服务器中?

    我用过service facts检查服务是否正在运行并启用 在某些服务器中 未安装特定的软件包 现在 我如何知道这个特定的软件包没有安装在该特定的服务器上service facts module 在 Ansible 剧本中 它显示以下错误
  • 如何替换 pandas 数据框列中的重音符号

    我有一个数据框dataSwiss其中包含瑞士城市的信息 我想用普通字母替换带有重音符号的字母 这就是我正在做的 dataSwiss Municipality dataSwiss Municipality str encode utf 8 d
  • 测试 python Counter 是否包含在另一个 Counter 中

    如何测试是否是pythonCounter https docs python org 2 library collections html collections Counter is 包含在另一个中使用以下定义 柜台a包含在计数器中b当且

随机推荐