筛选NumPy数组

Filtering a NumPy Array(筛选NumPy数组)

本文介绍了筛选NumPy数组的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个NumPy数组arr,我想要按元素进行筛选,例如 我只想获取低于某个阈值k的值。

有几种方法,例如:

  1. 使用生成器:np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
  2. 使用布尔掩码切片:arr[arr < k]
  3. 使用np.where()arr[np.where(arr < k)]
  4. 使用np.nonzero()arr[np.nonzero(arr < k)]
  5. 使用基于Cython的自定义实现
  6. 使用基于Numba的自定义实现

哪一个是最快的?内存效率如何?


(编辑:根据@ShadowRanger评论增加np.nonzero())

推荐答案

摘要

以下测试旨在提供对不同方法的一些见解,应持保留态度。 这里测试的不完全是一般过滤,而只是应用阈值,这具有计算条件相当快的显著特征。如果该条件意味着昂贵的计算,则将获得非常不同的结果。

基本上,两遍加速(无论是使用Numba还是Cython--只要您事先知道类型)将是最快和更高效的内存,除了非常大的输入,对于这种情况,单遍Numba/Cython更快(代价是使用更大的临时内存)。使用np.where()/np.nonzero()而不是直接使用掩码可能会导致计算略微加快,而且只要输出的大小小于输入的50%,通常不会有任何影响(可能除了较大的临时内存占用)。np.fromiter()方法要慢得多,但不会生成大型临时对象。


定义

  1. 使用生成器:
def filter_fromiter(arr, k):
    return np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
  1. 使用布尔掩码切片:
def filter_mask(arr, k):
    return arr[arr < k]
  1. 使用np.where()
def filter_where(arr, k):
    return arr[np.where(arr < k)]
  1. 使用np.nonzero()
def filter_nonzero(arr, k):
    return arr[np.nonzero(arr < k)]
  1. 使用基于Cython的自定义实现:
    • 单程filter_cy()
    • 两遍filter2_cy()
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


cimport numpy as cnp
cimport cython as ccy

import numpy as np
import cython as cy


cdef long NUM = 1048576
cdef long MAX_VAL = 1048576
cdef long K = 1048576 // 2


cdef int smaller_than_cy(long x, long k=K):
    return x < k


cdef size_t _filter_cy(long[:] arr, long[:] result, size_t size, long k):
    cdef size_t j = 0
    for i in range(size):
        if smaller_than_cy(arr[i]):
            result[j] = arr[i]
            j += 1
    return j


cpdef filter_cy(arr, k):
    result = np.empty_like(arr)
    new_size = _filter_cy(arr, result, arr.size, k)
    result.resize(new_size)
    return result


cdef size_t _filtered_size(long[:] arr, size_t size, long k):
    cdef size_t j = 0
    for i in range(size):
        if smaller_than_cy(arr[i]):
            j += 1
    return j


cpdef filter2_cy(arr, k):
    cdef size_t new_size = _filtered_size(arr, arr.size, k)
    result = np.empty(new_size, dtype=arr.dtype)
    new_size = _filter_cy(arr, result, arr.size, k)
    return result

import functools


filter_np_cy = functools.partial(filter_cy, k=K)
filter_np_cy.__name__ = 'filter_np_cy'


filter2_np_cy = functools.partial(filter2_cy, k=K)
filter2_np_cy.__name__ = 'filter2_np_cy'
  1. 使用基于Numba的自定义实现
    • 单程filter_np_nb()
    • 两遍filter2_np_nb()
import numba as nb
import functools


@nb.njit
def filter_func(x, k):
    return x < k


@nb.njit
def filter_nb(arr, result, k):
    j = 0
    for i in range(arr.size):
        if filter_func(arr[i], k):
            result[j] = arr[i]
            j += 1
    return j


def filter_np_nb(arr, k=K):
    result = np.empty_like(arr)
    j = filter_nb(arr, result, k)
    result.resize(j, refcheck=False)
    return result


@nb.njit
def filter2_nb(arr, k):
    j = 0
    for i in range(arr.size):
        if filter_func(arr[i], k):
            j += 1
    result = np.empty(j, dtype=arr.dtype)
    j = 0
    for i in range(arr.size):
        if filter_func(arr[i], k):
            result[j] = arr[i]
            j += 1
    return result


filter2_np_nb = functools.partial(filter2_nb, k=K)
filter2_np_nb.__name__ = 'filter2_np_nb'

时间基准

基于生成器的filter_fromiter()方法比其他方法慢得多(大约2个数量级,因此在图表中省略)。

计时将取决于输入数组大小和筛选项目的百分比。

作为输入大小的函数

第一个图将计时作为输入大小的函数(对于约50%过滤出的元素):

总的来说,基于Numba的方法始终是最快的,紧随其后的是Cython方法。在这些方法中,两遍方法通常是最快的,但在输入非常大的情况下,单遍方法往往会占据主导地位。在NumPy中,基于np.where()和基于np.nonzero()的方法基本上是相同的(除了非常小的输入,对于np.nonzero()似乎稍微慢一些),它们都比布尔掩码切片快,除了非常小的输入(低于~100个元素),其中布尔掩码切片更快。 此外,对于非常小的输入,基于Cython的解决方案比基于NumPy的解决方案要慢。

作为填充的函数

第二个图将计时作为通过过滤器的项目的函数来处理(对于大约100万个元素的固定输入大小):

第一个观察结果是,所有方法在接近50%的填充值时都是最慢的,而在填充较少或较多的情况下,它们更快,在不填充的情况下最快(滤出值的最高百分比,通过值的最低百分比,如图表的x轴所示)。 同样,Numba和Cython版本通常都比基于NumPy的版本更快,其中Numba几乎总是最快的,而Cython在图表的最右边部分击败了Numba。 对于较大的填充值,两次通过方法的边际速度收益不断增加,直到大约。50%,之后单传占据速度领奖台。 在NumPy中,基于np.where()np.nonzero()的方法基本上是相同的。 与基于NumPy的解决方案相比,当填充小于60%时,np.where()/np.nonzero()解决方案优于布尔掩码切片,之后布尔掩码切片速度最快。

(完整代码可用here)


内存注意事项

基于生成器的filter_fromiter()方法只需要最少的临时存储空间,与输入的大小无关。 就内存而言,这是最有效的方法。 具有类似内存效率的是Cython/Numba两遍方法,因为输出的大小是在第一遍期间确定的。

在内存端,Cython和Numba的单遍解决方案都需要输入大小的临时数组。 因此,与两遍或基于生成器的遍相比,它们的内存效率不是很高。然而,与掩码相比,它们具有相似的渐近临时内存占用,但常量项通常大于掩码。

布尔掩码切片解决方案需要一个输入大小但类型为bool的临时数组,在NumPy中为1位,因此这比典型64位系统上的NumPy数组的默认大小小约64倍。

基于np.nonzero()/np.where()的解决方案与第一步中的布尔掩码切片(在np.nonzero()/np.where()中)具有相同的要求,后者在第二步(np.nonzero()/np.where()的输出)中被转换为一系列int(在64位系统上通常为int64))。因此,第二步具有不同的内存要求,具体取决于筛选的元素的数量。


备注

  • 在指定不同的过滤条件时,生成器方法也是最灵活的
  • Cython解决方案需要指定数据类型以使其速度更快,否则需要为多个类型分派额外的工作
  • 对于Numba和Cython,可以将筛选条件指定为泛型函数(因此不需要进行硬编码),但必须在其各自的环境中指定,并且必须注意确保正确编译以提高速度,否则会出现明显的减速。
  • 单遍解决方案需要额外的代码来处理未使用的(但最初分配的)内存。
  • NumPy方法不返回输入的视图,而是返回一个副本,结果是advanced indexing:
arr = np.arange(100)
k = 50
print('`arr[arr > k]` is a copy: ', arr[arr > k].base is None)
# `arr[arr > k]` is a copy:  True
print('`arr[np.where(arr > k)]` is a copy: ', arr[np.where(arr > k)].base is None)
# `arr[np.where(arr > k)]` is a copy:  True
print('`arr[:k]` is a copy: ', arr[:k].base is None)
# `arr[:k]` is a copy:  False

(编辑:包括基于np.nonzero()的解决方案和修复了内存泄漏,避免了单次通过的Cython/Numba版本的复制,包括两次通过的Cython/Numba版本--基于@ShadowRanger、@PaulPanzer、@Max9111和@DavidW的评论。)

这篇关于筛选NumPy数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!

本文标题为:筛选NumPy数组

基础教程推荐