Python numpy使用记录3.数组筛选切片,np.where

前言

如果我想提取数组中大于某个阈值的所有元素,可以使用数组筛选后提取。

本篇记录数组筛选的方法,np.where

np.where的使用

np.where是numpy中用于元素筛选的函数,有两种使用方法。

1.筛选替换

函数原型:np.where(condition, x, y)condition表示数组与筛选条件,x表示满足条件的替换值,y表示不满足条件的替换值,函数返回替换矩阵(同维度)。举个例子:

import numpy as np

a = np.arange(12).reshape([3, 4])
# array([[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]])

b = np.where(a>5, 1, 0)
# array([[0, 0, 0, 0],
#        [0, 0, 1, 1],
#        [1, 1, 1, 1]])

判断条件是数组a大于阈值5,当元素满足条件时,把这个元素替换成1,否则替换成0。

2.筛选提取

函数原型np.where(condition),当只有condition参数时,函数返回满足条件的元素的多维索引。举个例子:

import numpy as np

a = np.arange(12).reshape([3, 4])
# array([[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]])

b = np.where(a>5)
# (array([1, 1, 2, 2, 2, 2]), array([2, 3, 0, 1, 2, 3]))

返回的是一个元组,包含两个数组。这两个数组是满足条件的元素的各维度索引,把两个数组拼接起来更加直观:

d = np.stack(b, axis=-1)
# array([[1 2]
#        [1 3]
#        [2 0]
# 		 [2 1]
# 		 [2 2]
# 		 [2 3]])

拼接后的结果就是符合条件的元素索引数组。

通过np.where得到筛选后的索引数组,就可以通过tuple索引实现筛选提取啦:

a[c]
# array([6, 7 ,8, 9, 10, 11])

注意:由于np.where返回的tuple长度与原数组相同,因此tuple索引坐标的维度与原数组也是相同的,提取出的数组必然ndim=1

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐