TensorFlow中重复训练和预测时如何避免OOM错误?

2023-12-03

我在 TensorFlow 中有一些代码,它采用一个基本模型,用一些数据对其进行微调(训练),然后使用该模型来predict()使用一些其他数据。所有这些都封装在一个main()模块的方法并且工作正常。

然而,当我在不同的基本模型上循环运行此代码时,我最终会在(例如 7 个基本模型)之后出现 OOM。这是预期的吗?我希望Python会在每次之后进行清理main()称呼。 TensorFlow 不这样做吗?我怎样才能强迫它呢?

Edit:这是一个 MWE,显示的不是 OOM 崩溃,而是内存消耗增加:

import gc
import os

import numpy as np
import psutil
import tensorflow as tf

tf.get_logger().setLevel("ERROR")  # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
    (model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
    history = model.fit(
        x=(x := tf.zeros((1, *model.input.shape[1:]))),
        y=(y := tf.zeros((1, *model.output.shape[1:]))),
        verbose=0,
    )
    prediction = model.predict(x)
    _ = gc.collect()
    # tf.keras.backend.clear_session()
    print(f"rss {i}: {process.memory_info().rss >> 20} MB")

在我的计算机(CPU)上,它打印

rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...

With tf.keras.backend.clear_session()未注释,它更好,但还不完美:

rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB

切换顺序gc.collect() and tf.keras.backend.clear_session()也没有帮助。


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow中重复训练和预测时如何避免OOM错误? 的相关文章

随机推荐

  • keytool 和 openssl 证书指纹不匹配

    我试图在 META INF 内对 Android 开发者证书进行指纹识别 以用于研究目的 我发现在某些情况下 keytool 和 openssl 的输出会给我同一证书提供不同的 SHA1 指纹 使用密钥工具 keytool princert
  • OpenMP 开销计算

    给定 n 个线程 有没有一种方法可以计算在 OpenMP 中实现特定指令所需的开销量 例如周期数 例如 给出下面的代码 pragma omp parallel pragma omp for for int i 0 i lt m i a i
  • 安装allure pytest适配器后出错

    我正在尝试在 Windows 8 机器上使用 Pytest 3 6xx 配置 Allure 2 6 0 我能够运行 pytest 并生成 jUnit xml 报告文件 稍后我可以将其传递给 allure allure 服务器 jUnitXm
  • 是否可以检测用户何时切换到不同的浏览器选项卡?

    我试图检测用户何时从当前浏览器选项卡切换到另一个选项卡 监听 window onblur 在 Firefox 中可以很好地检测用户何时将焦点切换到另一个窗口 但当用户切换到另一个选项卡时它似乎不会触发 然而 当从另一个选项卡切换到有问题的选
  • Web 服务请求调用 SOAP 请求缺少空参数

    我对 Web 服务和 C 都很陌生 所以如果我的问题太简单 请原谅我 我四处搜寻 但找不到答案 至少根据我的关键词找到了答案 我尝试通过 C Visual Web Developer 2010 Express 调用 Web 服务 但收到错误
  • Coldfusion 中的哈希用于安全支付网关

    我正在尝试在 Coldfusion 中创建一个哈希密码 以便我们的安全支付网关接受交易 不幸的是 支付网关拒绝接受我生成的哈希值 该表单发送交易的所有元素 并发送基于五个不同字段生成的哈希值 在 PHP 中它是 我认为 Coldfusion
  • Java Web 应用程序指定入口点

    我有一些 Java Web 应用程序 现在它从 index jsp 页面开始 我有自己的课程 代码如下 import java io import javax servlet import javax servlet http public
  • 当我尝试将双精度型转换为浮点数时,为什么会出现错误?

    我在将双精度型转换为浮点数时遇到了一些问题 代码 float volume 0 5 Double i Volume Value volume float i 100F Bass BASS SetVolume volume 正如你所看到的 我
  • Visual Studio 中的 aspx 页面设计视图有用吗?

    我从来没有真正发现 Visual Studio 中的设计视图在开发 aspx 页面时有用 所以我基本上从不使用它 我是否遗漏了某些东西 或者这只是那些不是特别有用的功能之一 你使用设计视图吗 如果是这样 你觉得它有用吗 如果没有 为什么不呢
  • 如何根据XML文件自动生成WPF控件?

    我有一个 Xml 文件 它告诉我必须添加到表单中的控件 但此 Xml 会动态更改 我需要更新表单 目前 我可以读取XML文件 但我不知道是否可以基于该文件自动创建表单 对的 这是可能的 WPF 提供了多种在 Xaml 或代码中创建控件的方法
  • Excel 的独立代码

    Can VBA编写代码以对任何操作执行操作Excel file 当我在中创建项目时视觉工作室 它要求一个Excel要链接到它的文件 我写的所有代码都在ThisWorkbook vb因此仅作用于Excel链接到项目的文件 Ideally I
  • 如何全屏滑动选定的网格图像

    我创建了一个网格视图图像应用程序 我想在图像滑动中显示所选图像 实际上我在我的应用程序中实现了图像滑动but问题是图像滑动从第一张图像开始 而不是从选定的图像开始 example 如果我选择第三张图像 则图像滑动应该从第三张图像开始 而不是
  • 缺少必需参数:aws_access_key_id、aws_secret_access_key

    我目前正在尝试在终端中运行我的测试套件 但出现以下错误 Missing required arguments aws access key id aws secret access key ArgumentError 我在我的项目中使用 C
  • Android HttpClient:NetworkOnMainThreadException

    我有下面的一些代码 protected void testConnection String url DefaultHttpClient httpclient new DefaultHttpClient HttpGet httpget ne
  • 用户帐户“root”的指定密码无效,或无法连接到数据库服务器

    我在 Windows Server 2012R2 上使用 Windows 平台安装程序 5 0 安装 WordPress 时遇到此错误 目前我在该服务器上有一个带有 mySQL 的 php 站点 运行良好 几个月前 作为设置该网站的一部分
  • 正则表达式将给定单词替换为两侧的空格或根本不替换

    我正在使用 PHP 中的一些代码 从搜索引擎获取引用数据 为我提供用户输入的查询 然后 我想从该字符串中删除某些停用词 如果存在 但是 该单词两端可能有也可能没有空格 例如 我一直使用 str replace 删除一个单词 如下所示 key
  • 绘制图像分类模型的混淆矩阵

    我用 keras 构建了一个图像分类 CNN 虽然模型本身运行良好 它可以正确预测新数据 但我在绘制模型的混淆矩阵和分类报告时遇到问题 我使用 ImageDataGenerator 训练了模型 train path DATASET TRAI
  • 我如何获取全局javascript变量中的ajax内容

    我想将内容放入javascript全局定义的变量中 我使用ajax调用获得的内容 http pastebin com TqiJx3PA 感谢您的任何建议 Pastebin 代码已经做到了这一点 我猜你实际面临的问题是存在的 因为你的 aja
  • 字符串格式为 yyyy-MM-dd HH:mm:ss Iphone

    我有一个 nsstring 见下文 NSString Mydate 9 8 2011 以月 日 年的格式 我希望这个字符串的格式yyyy MM dd HH mm ss 例如 2011 09 08 15 51 57 这样我需要以后面的格式在标
  • TensorFlow中重复训练和预测时如何避免OOM错误?

    我在 TensorFlow 中有一些代码 它采用一个基本模型 用一些数据对其进行微调 训练 然后使用该模型来predict 使用一些其他数据 所有这些都封装在一个main 模块的方法并且工作正常 然而 当我在不同的基本模型上循环运行此代码时