我在 TensorFlow 中有一些代码,它采用一个基本模型,用一些数据对其进行微调(训练),然后使用该模型来predict()
使用一些其他数据。所有这些都封装在一个main()
模块的方法并且工作正常。
然而,当我在不同的基本模型上循环运行此代码时,我最终会在(例如 7 个基本模型)之后出现 OOM。这是预期的吗?我希望Python会在每次之后进行清理main()
称呼。 TensorFlow 不这样做吗?我怎样才能强迫它呢?
Edit:这是一个 MWE,显示的不是 OOM 崩溃,而是内存消耗增加:
import gc
import os
import numpy as np
import psutil
import tensorflow as tf
tf.get_logger().setLevel("ERROR") # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
(model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
history = model.fit(
x=(x := tf.zeros((1, *model.input.shape[1:]))),
y=(y := tf.zeros((1, *model.output.shape[1:]))),
verbose=0,
)
prediction = model.predict(x)
_ = gc.collect()
# tf.keras.backend.clear_session()
print(f"rss {i}: {process.memory_info().rss >> 20} MB")
在我的计算机(CPU)上,它打印
rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...
With tf.keras.backend.clear_session()
未注释,它更好,但还不完美:
rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB
切换顺序gc.collect()
and tf.keras.backend.clear_session()
也没有帮助。