在 Numba 中计算 numpy 数组中非零值的数量

2023-12-21

很简单。我正在尝试计算用 Numba 编译的 NumPy jit 中的数组中非零值的数量(njit())。我尝试过的以下操作是 Numba 不允许的。

  1. a[a != 0].size
  2. np.count_nonzero(a)
  3. len(a[a != 0])
  4. len(a) - len(a[a == 0])

如果还有更快、更 Pythonic 和优雅的方式,我不想使用 for 循环。

对于想要查看完整代码示例的评论者......

import numpy as np
from numba import njit

@njit()
def n_nonzero(a):
    return a[a != 0].size

您还可以考虑计算非零值:

import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

我知道这似乎是错误的,但请耐心等待:

import numpy as np
import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])

@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()

np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c

%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

它实际上比np.count_nonzero,由于某种原因可能会变得相当慢:

%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Numba 中计算 numpy 数组中非零值的数量 的相关文章

随机推荐