Filtering a NumPy Array(筛选NumPy数组)
问题描述
假设我有一个NumPy数组arr
,我想要按元素进行筛选,例如
我只想获取低于某个阈值k
的值。
有几种方法,例如:
- 使用生成器:
np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
- 使用布尔掩码切片:
arr[arr < k]
- 使用
np.where()
:arr[np.where(arr < k)]
- 使用
np.nonzero()
:arr[np.nonzero(arr < k)]
- 使用基于Cython的自定义实现
- 使用基于Numba的自定义实现
哪一个是最快的?内存效率如何?
(编辑:根据@ShadowRanger评论增加np.nonzero()
)
推荐答案
摘要
以下测试旨在提供对不同方法的一些见解,应持保留态度。 这里测试的不完全是一般过滤,而只是应用阈值,这具有计算条件相当快的显著特征。如果该条件意味着昂贵的计算,则将获得非常不同的结果。
基本上,两遍加速(无论是使用Numba还是Cython--只要您事先知道类型)将是最快和更高效的内存,除了非常大的输入,对于这种情况,单遍Numba/Cython更快(代价是使用更大的临时内存)。使用np.where()
/np.nonzero()
而不是直接使用掩码可能会导致计算略微加快,而且只要输出的大小小于输入的50%,通常不会有任何影响(可能除了较大的临时内存占用)。np.fromiter()
方法要慢得多,但不会生成大型临时对象。定义
- 使用生成器:
def filter_fromiter(arr, k):
return np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
- 使用布尔掩码切片:
def filter_mask(arr, k):
return arr[arr < k]
- 使用
np.where()
:
def filter_where(arr, k):
return arr[np.where(arr < k)]
- 使用
np.nonzero()
def filter_nonzero(arr, k):
return arr[np.nonzero(arr < k)]
- 使用基于Cython的自定义实现:
- 单程
filter_cy()
- 两遍
filter2_cy()
- 单程
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
cimport numpy as cnp
cimport cython as ccy
import numpy as np
import cython as cy
cdef long NUM = 1048576
cdef long MAX_VAL = 1048576
cdef long K = 1048576 // 2
cdef int smaller_than_cy(long x, long k=K):
return x < k
cdef size_t _filter_cy(long[:] arr, long[:] result, size_t size, long k):
cdef size_t j = 0
for i in range(size):
if smaller_than_cy(arr[i]):
result[j] = arr[i]
j += 1
return j
cpdef filter_cy(arr, k):
result = np.empty_like(arr)
new_size = _filter_cy(arr, result, arr.size, k)
result.resize(new_size)
return result
cdef size_t _filtered_size(long[:] arr, size_t size, long k):
cdef size_t j = 0
for i in range(size):
if smaller_than_cy(arr[i]):
j += 1
return j
cpdef filter2_cy(arr, k):
cdef size_t new_size = _filtered_size(arr, arr.size, k)
result = np.empty(new_size, dtype=arr.dtype)
new_size = _filter_cy(arr, result, arr.size, k)
return result
import functools
filter_np_cy = functools.partial(filter_cy, k=K)
filter_np_cy.__name__ = 'filter_np_cy'
filter2_np_cy = functools.partial(filter2_cy, k=K)
filter2_np_cy.__name__ = 'filter2_np_cy'
- 使用基于Numba的自定义实现
- 单程
filter_np_nb()
- 两遍
filter2_np_nb()
- 单程
import numba as nb
import functools
@nb.njit
def filter_func(x, k):
return x < k
@nb.njit
def filter_nb(arr, result, k):
j = 0
for i in range(arr.size):
if filter_func(arr[i], k):
result[j] = arr[i]
j += 1
return j
def filter_np_nb(arr, k=K):
result = np.empty_like(arr)
j = filter_nb(arr, result, k)
result.resize(j, refcheck=False)
return result
@nb.njit
def filter2_nb(arr, k):
j = 0
for i in range(arr.size):
if filter_func(arr[i], k):
j += 1
result = np.empty(j, dtype=arr.dtype)
j = 0
for i in range(arr.size):
if filter_func(arr[i], k):
result[j] = arr[i]
j += 1
return result
filter2_np_nb = functools.partial(filter2_nb, k=K)
filter2_np_nb.__name__ = 'filter2_np_nb'
时间基准
基于生成器的filter_fromiter()
方法比其他方法慢得多(大约2个数量级,因此在图表中省略)。
计时将取决于输入数组大小和筛选项目的百分比。
作为输入大小的函数
第一个图将计时作为输入大小的函数(对于约50%过滤出的元素):
总的来说,基于Numba的方法始终是最快的,紧随其后的是Cython方法。在这些方法中,两遍方法通常是最快的,但在输入非常大的情况下,单遍方法往往会占据主导地位。在NumPy中,基于np.where()
和基于np.nonzero()
的方法基本上是相同的(除了非常小的输入,对于np.nonzero()
似乎稍微慢一些),它们都比布尔掩码切片快,除了非常小的输入(低于~100个元素),其中布尔掩码切片更快。
此外,对于非常小的输入,基于Cython的解决方案比基于NumPy的解决方案要慢。
作为填充的函数
第二个图将计时作为通过过滤器的项目的函数来处理(对于大约100万个元素的固定输入大小):
第一个观察结果是,所有方法在接近50%的填充值时都是最慢的,而在填充较少或较多的情况下,它们更快,在不填充的情况下最快(滤出值的最高百分比,通过值的最低百分比,如图表的x轴所示)。
同样,Numba和Cython版本通常都比基于NumPy的版本更快,其中Numba几乎总是最快的,而Cython在图表的最右边部分击败了Numba。
对于较大的填充值,两次通过方法的边际速度收益不断增加,直到大约。50%,之后单传占据速度领奖台。
在NumPy中,基于np.where()
和np.nonzero()
的方法基本上是相同的。
与基于NumPy的解决方案相比,当填充小于60%时,np.where()
/np.nonzero()
解决方案优于布尔掩码切片,之后布尔掩码切片速度最快。
(完整代码可用here)
内存注意事项
基于生成器的filter_fromiter()
方法只需要最少的临时存储空间,与输入的大小无关。
就内存而言,这是最有效的方法。
具有类似内存效率的是Cython/Numba两遍方法,因为输出的大小是在第一遍期间确定的。
在内存端,Cython和Numba的单遍解决方案都需要输入大小的临时数组。 因此,与两遍或基于生成器的遍相比,它们的内存效率不是很高。然而,与掩码相比,它们具有相似的渐近临时内存占用,但常量项通常大于掩码。
布尔掩码切片解决方案需要一个输入大小但类型为bool
的临时数组,在NumPy中为1位,因此这比典型64位系统上的NumPy数组的默认大小小约64倍。
np.nonzero()
/np.where()
的解决方案与第一步中的布尔掩码切片(在np.nonzero()
/np.where()
中)具有相同的要求,后者在第二步(np.nonzero()
/np.where()
的输出)中被转换为一系列int
(在64位系统上通常为int64
))。因此,第二步具有不同的内存要求,具体取决于筛选的元素的数量。
备注
- 在指定不同的过滤条件时,生成器方法也是最灵活的
- Cython解决方案需要指定数据类型以使其速度更快,否则需要为多个类型分派额外的工作
- 对于Numba和Cython,可以将筛选条件指定为泛型函数(因此不需要进行硬编码),但必须在其各自的环境中指定,并且必须注意确保正确编译以提高速度,否则会出现明显的减速。
- 单遍解决方案需要额外的代码来处理未使用的(但最初分配的)内存。
- NumPy方法不返回输入的视图,而是返回一个副本,结果是advanced indexing:
arr = np.arange(100)
k = 50
print('`arr[arr > k]` is a copy: ', arr[arr > k].base is None)
# `arr[arr > k]` is a copy: True
print('`arr[np.where(arr > k)]` is a copy: ', arr[np.where(arr > k)].base is None)
# `arr[np.where(arr > k)]` is a copy: True
print('`arr[:k]` is a copy: ', arr[:k].base is None)
# `arr[:k]` is a copy: False
(编辑:包括基于np.nonzero()
的解决方案和修复了内存泄漏,避免了单次通过的Cython/Numba版本的复制,包括两次通过的Cython/Numba版本--基于@ShadowRanger、@PaulPanzer、@Max9111和@DavidW的评论。)
这篇关于筛选NumPy数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:筛选NumPy数组
基础教程推荐
- 使用Python匹配Stata加权xtil命令的确定方法? 2022-01-01
- 症状类型错误:无法确定关系的真值 2022-01-01
- 使用 Google App Engine (Python) 将文件上传到 Google Cloud Storage 2022-01-01
- 如何在Python中绘制多元函数? 2022-01-01
- 如何在 Python 中检测文件是否为二进制(非文本)文 2022-01-01
- 哪些 Python 包提供独立的事件系统? 2022-01-01
- 将 YAML 文件转换为 python dict 2022-01-01
- 合并具有多索引的两个数据帧 2022-01-01
- 使 Python 脚本在 Windows 上运行而不指定“.py";延期 2022-01-01
- Python 的 List 是如何实现的? 2022-01-01