Pytorch中Tensor的索引,切片以及花式索引(fancy indexing)
对pytorch中tensor的索引切片,以及更为高级的花式索引的理解
前一段时间遇到一个花式索引的问题,在搜索良久之后没有找到确切的答案,苦苦摸索许久才理解清楚,于是想要写一篇博客来详细的讲讲我对Pytorch中Tensor索引的一些理解。包括普通的索引,切片,一起花式索引。
理解Tensor的dim
我们可以从1维的tensor开始,例如:
>>> a = torch.tensor([0, 1 ,2, 3, 4])
a可以简单的理解为一个一维的数组,但还不够形象。我么先换个形象的角度,暂时抛去数组的概念,我们可以理解为现在有五个元素(数字)排列一个轴上(类似于数轴,也就是dim),那么我们自然而然的就会想要去给每一个元素一个下标,这也就是我们使用的从0开始的整数下标。那么对于这个轴的方向,每一个下标都会确定一个元素(例如下标2,对应着a里面的2)。那么直观的来看就像下面的图一样:
我们可以注意到,现在参与排队的元素是数字,并且只在一个方向上排队,这就构成了最简单的1维的tensor。那么从1维到2维的过程有两种理解方式:
- 仍然是在一个方向上排队,但参与排队的元素是已经按照另一个方向排好队的等长的队伍(维度增加)
- 排队元素仍然是一个数字,但是现在按照两个方向(两个维度)进行排队
那么我们可以看看这样一个2维的tensor,
>>> b = torch.tensor([[0, 1],
>>> [2, 3],
>>> [4, 5]])
那么按照上面两种理解可以理解为:
- 现在在dim=0的方向上有三个元素要参与排列,分别是 [ 0 , 1 ] , [ 2 , 3 ] , [ 3 , 4 ] [0, 1], [2,3],[3,4] [0,1],[2,3],[3,4]。这个时候他们自己排列的方向为dim=1,由此构成了一个2维的tensor。
- 现在有6个元素在两个方向上dim=0,dim=1上排列,其中dim=0方向有三个位置,dim=1的方向上有两个位置。由此构成一个二维的数组。
如下图所示:
那么如果是一个3维或者跟高维度的tensor呢,同样的道理,在多个方向上对已经在某些方向上排列好了的一组数字进行排列。这也就是高维数组是数组的数组说法的一个体现。
回过来再看看二维tensor,如果我们把dim=1方向上的所有数字看作一个整体(1维的一个tensor, [ 0 , 1 ] , [ 2 , 3 ] , [ 4 , 5 ] [0,1],[2,3],[4,5] [0,1],[2,3],[4,5]三个),那么我们可以理解为b是这三个1维tensor的一个1维的tensor,那么这个时候如果我们把dim=0方向上的数字看作一个整体呢,那我们得到的是 [ 0 , 2 , 4 ] , [ 1 , 3 , 5 ] [0,2,4],[1,3,5] [0,2,4],[1,3,5]构成的的一个维tensor。那如果我们把两个维度看作一个整体,那我们得到的是单单一个二维tensor,不用排列。
那么如果是一个3维的或者更高维度的tensor,我们可以将多个维度(m)看作是一个整体,那么剩余的维度(n)构成了一个所有元素都是m维tensor的n维tensor。
那么接下来的索引,切片过程我们也能方便的理解。
索引
简单索引
使用数字进行索引,一般对于一个n维的tensor,索引形式为:
T
[
d
0
,
d
1
,
.
.
.
,
d
t
]
,
t
<
n
.
T[d_0, d_1,...,d_t], t<n.
T[d0,d1,...,dt],t<n.
那么这就相当于将其看作是一个tensor的tensor,将元素按照下标取出。
例如:
>>> a[0]
tensor(0)
>>> a[3]
tensor(3)
>>> b[1]
tensor([2, 3])
>>> b[:, 1] # 展示将dim=0作为整体
tensor([1, 3, 5])
>>> b[1, 1]
tensor(3)
用1维的list,numpy,tensor索引
将整个tensor看作是一个由n-1维tensor构成的1维tensor。将每个取出的元素排列,构成一个新的tensor。
>>> a[[1, 2 ,1]]
tensor([1, 2, 1])
>>> b[[2, 0, 2]]
tensor([[4, 5],
>>> [0, 1],
>>> [4, 5]])
>>> c = torch.rand([4, 3])
>>> c
tensor([[0.6478, 0.3120, 0.6656],
[0.4470, 0.6383, 0.6878],
[0.9854, 0.9709, 0.4868],
[0.1797, 0.3453, 0.9005]])
>>> c[[1, 2, 1]]
tensor([[0.4470, 0.6383, 0.6878],
[0.9854, 0.9709, 0.4868],
[0.4470, 0.6383, 0.6878]])
用booltensor索引
使用booltensor B对T进行索引,需要满足如下条件:
B
.
s
i
z
e
(
)
:
(
b
0
,
b
1
,
.
.
.
,
b
t
−
1
)
T
.
s
i
z
e
(
)
:
(
d
0
,
d
1
,
.
.
.
,
d
t
−
1
,
.
.
.
,
d
n
−
1
)
b
i
=
d
i
,
∀
i
<
t
,
n
≥
t
.
B.size():(b_0, b_1, ...,b_{t-1}) \\ T.size():(d_0, d_1, ...,d_{t-1}, ...,d_{n-1}) \\ b_i = d_i,\forall i < t,\\ n\ge t.
B.size():(b0,b1,...,bt−1)T.size():(d0,d1,...,dt−1,...,dn−1)bi=di,∀i<t,n≥t.
那么其意义就是将True位置的元素取出,构成一个
n
−
t
+
1
n-t+1
n−t+1维的新的tensor。例如:
>>> boolt_1 = torch.tensor([False, True, False, True, True])
>>> a[boolt_1]
tensor([1, 3, 4])
>>> boolt_2 = torch.tensor([True, False, True])
>>> b[boolt_2]
tensor([[0, 1],
[4, 5]])
>>> boolt_3 = torch.tensor([[False, True],
>>> [True, False],
>>> [True, True]])
tensor([1, 2, 4, 5])
>>> boolt_4 = torch.tensor([[False, True],
>>> [True, False]])
>>> b[boolt_4] # 如果不符合条件
Traceback (most recent call last):
File "<input>", line 1, in <module>
IndexError: The shape of the mask [2, 2] at index 0 does not match the shape of the indexed tensor [3, 2] at index 0
可以看到维数条件是严格要求B的所有维度的长度正好等于被索引的tensor的对应维度的长度。
切片
相信再list里面就已经学过切片的概念了,主要使用形如
[
s
t
a
r
t
:
e
n
d
:
s
t
e
p
]
[start : end : step]
[start:end:step]的组合进行子序列的抽取,其中
:
s
t
e
p
: step
:step可选,默认为1,
s
t
a
r
t
start
start和
e
n
d
end
end也可选,分别默认为0和len(obj)。
那么在tensor中遇到的主要是如下形式:
T
[
s
0
:
e
0
:
s
t
0
,
.
.
.
,
s
n
−
1
:
e
n
−
1
:
s
t
n
−
1
]
T[s_0:e_0:st_0,...,s_{n-1}:e_{n-1}:st_{n-1}]
T[s0:e0:st0,...,sn−1:en−1:stn−1]
在每个维度处的切片都相当于将对该维度进行相应的切片操作,在该维度上保留对应下标的元素(数,或者tensor,或者什么都不剩)。例如:
>>> a[1:3]
tensor([1, 2])
>>> a[0:4:2]
tensor([0, 2])
>>> b[0:2, 0:1]
tensor([[0],
[2]])
>>> b[1:, 1:]
tensor([[3],
[5]])
>>> d = torch.randint(5, (3, 3, 3))
>>> d
tensor([[[2, 1, 4],
[4, 1, 1],
[1, 2, 4]],
[[1, 4, 4],
[0, 3, 4],
[1, 2, 2]],
[[4, 4, 4],
[1, 3, 3],
[0, 0, 4]]])
>>> d[:, 1:2, 0:2]
tensor([[[4, 1]],
[[0, 3]],
[[1, 3]]])
显然切片和索引可以进行组合,效果就是将所有切片操作的维度在切片之后构成整体,看作是排列的元素,剩余索引的维度就是取出对应的这些元素。例如:
>>> d[[2, 0, 1], 1:2, 0:2]
tensor([[[1, 3]],
[[4, 1]],
[[0, 3]]])
花式索引
花式索引也是索引的一种,就是使用tensor对tensor进行索引,形如:
T
[
t
0
,
t
1
,
.
.
.
,
t
n
−
1
]
T[t_0, t_1,...,t_{n-1}]
T[t0,t1,...,tn−1]
其中的
t
i
,
i
=
0
,
.
.
.
,
n
−
1
t_i,i=0,...,n-1
ti,i=0,...,n−1是维度不限的long型tensor。
能够执行这一操作的先决条件是,
t
i
,
i
=
0
,
.
.
.
,
n
−
1
t_i,i=0,...,n-1
ti,i=0,...,n−1能够广播成同一形状。广播的机制建议自行了解。
整个语句的过程大致如下:
- 首先将 t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti,i=0,...,n−1广播成同一形状,如果它们不是同一形状的话,假设最终形状为 ( b 0 , b 1 , . . . , b s ) (b_0, b_1, ...,b_s) (b0,b1,...,bs)
- 这时 [ t 0 , t 1 , . . . , t n − 1 ] [t_0, t_1,...,t_{n-1}] [t0,t1,...,tn−1]这些tensor的对应位置的元素构成1组坐标,总共 b 0 × b 1 × . . . × b s b_0\times b_1\times ...\times b_s b0×b1×...×bs组坐标。
- 每组坐标进行一次简单索引,取出的元素(可能是数,也可能是tensor)放在形状 ( b 0 , b 1 , . . . , b s ) (b_0, b_1, ...,b_s) (b0,b1,...,bs)的对应位置,例如如果是 [ t 0 , t 1 , . . . , t n − 1 ] [t_0, t_1,...,t_{n-1}] [t0,t1,...,tn−1]的所有 ( 0 , 0 , . . . , 0 ) (0,0, ...,0) (0,0,...,0)构成的坐标,那么将结果放在 ( 0 , 0 , . . . , 0 ) (0,0, ...,0) (0,0,...,0)处,得到最后结果。
那么我们看个例子:
>>> idx_0 = torch.tensor([[3, 2],[1, 4]])
>>> a[idx_0]
tensor([[3, 2],
[1, 4]])
>>> b # 查看一下b,
tensor([[0, 1],
[2, 3],
[4, 5]])
>>> idx_0 = torch.tensor([[1, 0],[2, 1]])
>>> idx_1 = torch.tensor([0, 1])
>>> b[idx_0, idx_1]
tensor([[2, 1],
[4, 3]])
分析一下过程,先是idx_1广播成
tensor([[0, 1],
[0, 1]])
构成四组坐标 [ 1 , 0 ] , [ 0 , 1 ] , [ 2 , 0 ] , [ 1 , 1 ] [1, 0],[0, 1],[2, 0],[1, 1] [1,0],[0,1],[2,0],[1,1],对应着 2 , 1 , 4 , 3 2, 1, 4, 3 2,1,4,3,放入对应位置得到最终结果。
那么其实花式索引也是索引的一种,不过是通过多次索引,并且组成新的tensor的更复杂的索引,同理也可以和切片结合。
结语
这些关于tensor的理解,都是个人理解,希望能够帮助到有需要的人,另外如有参考本篇博客,请注明链接,请勿直接搬运。
思考过很久,因为想要写写博客,记录自己成长的过程,但又迟迟没能动手,或者是写写停停,但总算是再一次尝试了一遍,以后也可能是随缘吧,尽力吧
更多推荐
所有评论(0)