过滤(减少)NumPy 数组

2024-07-04

假设我有一个 NumPy 数组arr我想根据(可广播)函数的真值进行逐元素过滤(减少),例如 我只想获取低于某个阈值的值k:

def cond(x):
    return x < k

有几种方法,例如:

  1. 使用发电机:np.fromiter((x for x in arr if cond(x)), dtype=arr.dtype)(这是使用列表理解的内存高效版本:np.array([x for x in arr if cond(x)])因为np.fromiter() https://numpy.org/doc/stable/reference/generated/numpy.fromiter.html将直接生成一个 NumPy 数组,而不需要分配中间的 Pythonlist)
  2. 使用布尔掩码:arr[cond(arr)]
  3. 使用整数索引:arr[np.nonzero(cond(arr))](或等效地使用np.where() https://numpy.org/doc/stable/reference/generated/numpy.where.html因为它默认为np.nonzero() https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html只有一个条件)
  4. Using explicit looping with:
    • 单遍和最终复制/调整大小
    • 两遍:第一步确定结果的大小,另一遍实际执行计算

(最后两种方法可以加速Cython https://cython.org/ or Numba https://numba.pydata.org/)

哪个最快?内存效率怎么样?


(编辑:直接使用np.nonzero()代替np.where()根据@ShadowRanger 评论)


Summary

使用基于循环的方法进行单遍和复制,并通过 Numba 加速,可在速度、内存效率和灵活性方面提供最佳的整体权衡。 如果条件函数的执行足够快,则两次传递 (filter2_nb())可能会更快,但无论如何它们的内存效率更高。 此外,对于足够大的输入,调整大小而不是复制(filter_resize_xnb()) 导致执行速度更快。

如果提前知道数据类型(和条件函数)并且可以编译,Cython 加速似乎会更快。 类似的条件硬编码很可能也会导致与 Numba 加速相当的加速。

当涉及仅基于 NumPy 的方法时,布尔掩码或整数索引的速度相当,哪一种速度更快主要取决于过滤因子,即通过过滤条件的值的部分。

The np.fromiter()方法是much速度较慢(这将超出图中的图表),但不会产生大型临时对象。

请注意,以下测试旨在提供对不同方法的一些见解,因此应持保留态度。 最相关的假设是条件是可广播的并且它最终会计算得非常快。


定义

  1. 使用发电机:
def filter_fromiter(arr, cond):
    return np.fromiter((x for x in arr if cond(x)), dtype=arr.dtype)
  1. 使用布尔掩码:
def filter_mask(arr, cond):
    return arr[cond(arr)]
  1. 使用整数索引:
def filter_idx(arr, cond):
    return arr[np.nonzero(cond(arr))]

4a.使用显式循环,单遍和最终复制/调整大小

  • Cython 通过复制加速(预编译条件)
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


import numpy as np


cdef long NUM = 1048576
cdef long MAX_VAL = 1048576
cdef long K = 1048576 // 2


cdef int cond_cy(long x, long k=K):
    return x < k


cdef size_t _filter_cy(long[:] arr, long[:] result, size_t size):
    cdef size_t j = 0
    for i in range(size):
        if cond_cy(arr[i]):
            result[j] = arr[i]
            j += 1
    return j


def filter_cy(arr):
    result = np.empty_like(arr)
    new_size = _filter_cy(arr, result, arr.size)
    return result[:new_size].copy()
  • Cython 加速并调整大小(预编译条件)
def filter_resize_cy(arr):
    result = np.empty_like(arr)
    new_size = _filter_cy(arr, result, arr.size)
    result.resize(new_size)
    return result
  • Numba 通过复制加速
import numba as nb


@nb.njit
def cond_nb(x, k=K):
    return x < k


@nb.njit
def filter_nb(arr, cond_nb):
    result = np.empty_like(arr)
    j = 0
    for i in range(arr.size):
        if cond_nb(arr[i]):
            result[j] = arr[i]
            j += 1
    return result[:j].copy()
  • Numba 通过调整大小来加速
@nb.njit
def _filter_out_nb(arr, out, cond_nb):
    j = 0
    for i in range(arr.size):
        if cond_nb(arr[i]):
            out[j] = arr[i]
            j += 1
    return j


def filter_resize_xnb(arr, cond_nb):
    result = np.empty_like(arr)
    j = _filter_out_nb(arr, result, cond_nb)
    result.resize(j, refcheck=False)  # unsupported in NoPython mode
    return result
  • 使用生成器进行 Numba 加速,并且np.fromiter()
@nb.njit
def filter_gen_nb(arr, cond_nb):
    for i in range(arr.size):
        if cond_nb(arr[i]):
            yield arr[i]


def filter_gen_xnb(arr, cond_nb):
    return np.fromiter(filter_gen_nb(arr, cond_nb), dtype=arr.dtype)

4b.使用显式循环进行两次传递:一次确定结果的大小,一次实际执行计算

  • Cython 加速(预编译条件)
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


cdef size_t _filtered_size_cy(long[:] arr, size_t size):
    cdef size_t j = 0
    for i in range(size):
        if cond_cy(arr[i]):
            j += 1
    return j


def filter2_cy(arr):
    cdef size_t new_size = _filtered_size_cy(arr, arr.size)
    result = np.empty(new_size, dtype=arr.dtype)
    new_size = _filter_cy(arr, result, arr.size)
    return result
  • Numba 加速
@nb.njit
def filter2_nb(arr, cond_nb):
    j = 0
    for i in range(arr.size):
        if cond_nb(arr[i]):
            j += 1
    result = np.empty(j, dtype=arr.dtype)
    j = 0
    for i in range(arr.size):
        if cond_nb(arr[i]):
            result[j] = arr[i]
            j += 1
    return result

时序基准

(基于生成器的filter_fromiter()方法比其他方法慢得多 - 大约慢。 2个数量级。 从列表理解中可以预期类似(也许稍微差一些)的性能。 对于使用非加速代码的任何显式循环都是如此。)

时间将取决于输入数组的大小和过滤项目的百分比。

作为输入大小的函数

第一张图将时间作为输入大小的函数(对于约 50% 的过滤因子——即 50% 的元素出现在结果中):

一般来说,具有一种加速形式的显式循环会导致最快的执行,但根据输入大小略有变化。

在 NumPy 中,整数索引方法基本上与布尔掩码相同。

使用的好处np.fromiter()(无预分配)可以通过编写 Numba 加速生成器来获得,该生成器比其他方法慢(在一个数量级内),但比纯 Python 循环快得多。

作为填充的函数

第二张图将计时作为通过过滤器的项目的函数(对于约 100 万个元素的固定输入大小):

第一个观察结果是,当接近 50% 填充时,所有方法都是最慢的,而填充较少或较多时,它们会更快,并且在没有填充时速度最快(滤除值的最高百分比、通过值的最低百分比,如图表的 x 轴)。

同样,具有某种加速方式的显式循环会带来最快的执行速度。

在 NumPy 中,整数索引和布尔掩码方法再次基本相同。

(完整代码可用here https://colab.research.google.com/drive/1SkZhw7wJrPTPWEJbhi1NWQZpScoG8918)


内存注意事项

基于生成器的filter_fromiter()方法只需要最少的临时存储,与输入的大小无关。 从记忆角度来看,这是最有效的方法。 使用 Numba 加速生成器可以有效地加速此方法。

Cython / Numba 两遍方法具有类似的内存效率,因为输出的大小是在第一遍期间确定的。 这里需要注意的是,计算条件必须很快,这些方法才能快速。

在内存方面,Cython 和 Numba 的单遍解决方案都需要输入大小的临时数组。 因此,与两次传递或基于生成器的传递相比,这些传递的内存效率不是很高。

然而,与掩码相比,它们具有相似的渐近临时内存占用量,但常数项通常大于掩码。

布尔掩码解决方案需要一个输入大小但类型相同的临时数组bool,在 NumPy 中为 1 个字节,因此这比典型 64 位系统上 NumPy 数组的默认大小小约 8 倍。

整数索引解决方案与第一步中的布尔掩码切片具有相同的要求(内部np.nonzero()调用),它被转换为一系列ints(通常int64在 64 位系统上)第二步(输出np.nonzero())。 因此,第二步具有可变的内存需求,具体取决于过滤元素的数量。


Remarks

  • 布尔掩码和整数索引都需要某种形式的条件,能够生成布尔掩码(或者索引列表);在上面的实现中,条件是可广播的
  • 在指定不同的过滤条件时,生成器和 Numba 加速方法也是最灵活的
  • Numba 加速方法需要条件与 Numba 兼容才能在 NoPython 模式下访问 Numba 加速
  • Cython 解决方案需要指定数据类型才能快速运行,或者需要付出额外的努力来进行多种类型分派,并付出额外的努力(此处未探讨)以获得与其他方法相同级别的灵活性
  • 对于 Numba 和 Cython,过滤条件可以是硬编码的,从而导致微小但明显的速度差异
  • 单遍解决方案需要额外的代码来处理未使用的(但最初分配的)内存。
  • NumPy 方法做NOT返回输入的视图,但是一个副本,作为结果高级索引 https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing:
arr = np.arange(100)
k = 50
print('`arr[arr > k]` is a copy: ', arr[arr > k].base is None)
# `arr[arr > k]` is a copy:  True
print('`arr[np.where(arr > k)]` is a copy: ', arr[np.where(arr > k)].base is None)
# `arr[np.where(arr > k)]` is a copy:  True
print('`arr[:k]` is a copy: ', arr[:k].base is None)
# `arr[:k]` is a copy:  False

(已编辑:基于@ShadowRanger、@PaulPanzer、@max9111 和@DavidW 评论的各种改进。)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

过滤(减少)NumPy 数组 的相关文章

随机推荐