筛选和过滤
这小节与索引和切片有点类似,但倾向于从「整体」中统一筛选出「符合条件」的内容,而索引和切片更多的是依照「某种方法」切出一块内容。本小节内容同样非常重要,可以算第二个最重要的小节。主要包括以下内容:
- 条件筛选
- 提取(按条件)
- 抽样(按分布)
- 最大最小 index(特殊值)
这几个内容都很重要,使用的也非常高频。条件筛选经常用于 Mask 或异常值处理,提取则常用于结果过滤,抽样常用在数据生成(比如负样本抽样),最大最小 index 则常见于机器学习模型预测结果判定中(根据最大概率所在的 index 决定结果属于哪一类)。
rng = np.random.default_rng(42)
arr = rng.integers(1, 100, (3, 4))
arr
array([[ 9, 77, 65, 44],
[43, 86, 9, 70],
[20, 10, 53, 97]])
条件筛选
⭐⭐⭐
顾名思义,根据一定的条件对 array 进行筛选(标记)并后续处理。核心 API 是 np.where
。
⚠️ 需要注意的是:where 分别返回各维度的 index,赋值的是「不满足」条件的。
# 条件筛选,可以直接在整个 array 上使用条件
arr > 50
array([[False, True, True, False],
[False, True, False, True],
[False, False, True, True]])
# 返回满足条件的索引,因为是两个维度,所以会返回两组结果
np.where(arr > 50)
(array([0, 0, 1, 1, 2, 2]), array([1, 2, 1, 3, 2, 3]))
# 不满足条件的赋值,将 <=50 的替换为 -1
np.where(arr > 50, arr, -1)
array([[-1, 77, 65, -1],
[-1, 86, -1, 70],
[-1, -1, 53, 97]])
提取
⭐
在 array 中提取指定条件的值。
⚠️ 需要注意的是:提取和唯一值返回的都是一维向量。
# 提取指定条件的值
np.extract(arr > 50, arr)
array([77, 65, 86, 70, 53, 97])
# 唯一值,是另一种形式的提取
np.unique(arr)
array([ 9, 10, 20, 43, 44, 53, 65, 70, 77, 86, 97])
抽样
⭐⭐⭐⭐⭐
我们在跑模型时常常需要使用部分数据对整个过程快速验证,您当然可以使用 np.random
生成模拟数据。但有真实数据时,从真实数据中随机抽样会比较好。
rng = np.random.default_rng(42)
# 第一个参数是要抽样的集合,如果是一个整数,则表示从 0 到该值
# 第二个参数是样本大小
# 第三个参数表示结果是否可以重复
# 第四个参数表示出现的概率,长度和第一个参数一样
# 由于(0 1 2 3)中 2 和 3 的概率比较高,自然就选择了 2 和 3
rng.choice(4, 2, replace=False, p=[0.1, 0.2, 0.3, 0.4])
array([3, 2])
# 旧的 API
# 如果是抽样语料的 index,更多的方法是这样:
data_size = 10000
np.random.choice(data_size, 50, replace=False)
array([6339, 4894, 1531, 7814, 224, 9538, 9619, 3801, 3359, 3617, 2795,
6627, 8501, 1681, 4212, 5085, 2439, 744, 9123, 6733, 5688, 5480,
6983, 7058, 310, 1838, 5072, 746, 5873, 9372, 5953, 4944, 1780,
464, 1247, 845, 1807, 7354, 4925, 547, 2996, 3909, 7344, 9617,
8642, 661, 2453, 5475, 228, 2427])
最值 Index
⭐⭐⭐⭐⭐
这小节主要是两个 API:np.argmax(min)
和 np.argsort
,当然最常用的还是第一个,不用说,自然是可以(需要)指定 axis 的。
rng = np.random.default_rng(42)
arr = rng.uniform(1, 100, (3, 4))
arr
array([[77.62164881, 44.44896554, 86.00119407, 70.03943488],
[10.32355744, 97.58661281, 76.3528305 , 78.82036622],
[13.68324963, 45.58820785, 37.7090044 , 92.7497339 ]])
np.argmax/argmin
# 所有值中最大值的 Index,基本不这么用
np.argmax(arr)
5
# 按列(axis=0)最大值的 Index
np.argmax(arr, axis=0)
array([0, 1, 0, 2])
# 按行(axis=1)最小值的 Index
np.argmin(arr, axis=1)
array([1, 0, 0])
np.argsort
arr
array([[77.62164881, 44.44896554, 86.00119407, 70.03943488],
[10.32355744, 97.58661281, 76.3528305 , 78.82036622],
[13.68324963, 45.58820785, 37.7090044 , 92.7497339 ]])
# 默认按行(axis=1)排序的索引
np.argsort(arr)
array([[1, 3, 0, 2],
[0, 2, 3, 1],
[0, 2, 1, 3]])
# 数据按行(axis=1)排序的索引,同上
np.argsort(arr, axis=1)
array([[1, 3, 0, 2],
[0, 2, 3, 1],
[0, 2, 1, 3]])
# 数据按列(axis=0)排序索引
np.argsort(arr, axis=0)
array([[1, 0, 2, 0],
[2, 2, 1, 1],
[0, 1, 0, 2]])