使用多核CPU的正确方法是什么jax.pmap
?
以下示例在 CPU 核心后端上为 SPMD 创建环境变量,测试 JAX 是否识别设备,并尝试设备锁定。
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
import jax as jx
import jax.numpy as jnp
jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2
jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]
def sfunc(x): while True: pass
jx.pmap(sfunc)(jnp.arange(2))
从 jupyter 内核执行并观察htop
显示只有一个核心被锁定
我收到相同的输出htop
省略前两行并运行时:
$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py
更换sfunc
with
def sfunc(x): return 2.0*x
并打电话
jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)
确实返回一个SharedDeviecArray
.
显然我没有正确配置 JAX/XLA 以使用两个核心。我缺少什么以及我可以做什么来诊断问题?
据我所知,您正在正确配置核心(参见例如问题#2714 https://github.com/google/jax/issues/2714)。问题出在你的测试函数上:
def sfunc(x): while True: pass
该函数陷入无限循环在跟踪时,不在运行时。跟踪发生在单个 CPU 上的主机 Python 进程中(请参阅如何在 JAX 中思考 https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html了解 JAX 转换中跟踪的概念)。
如果您想在运行时观察 CPU 使用情况,则必须使用一个完成跟踪并开始运行的函数。为此,您可以使用任何实际产生结果的长时间运行的函数。这是一个简单的例子:
def sfunc(x):
for i in range(100):
x = (x @ x)
return x
jx.pmap(sfunc)(jnp.zeros((2, 1000, 1000)))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)