前一段时间遇到一个花式索引的问题,在搜索良久之后没有找到确切的答案,苦苦摸索许久才理解清楚,于是想要写一篇博客来详细的讲讲我对Pytorch中Tensor索引的一些理解。包括普通的索引,切片,一起花式索引。

理解Tensor的dim

我们可以从1维的tensor开始,例如:

>>> a = torch.tensor([0, 1 ,2, 3, 4])

a可以简单的理解为一个一维的数组,但还不够形象。我么先换个形象的角度,暂时抛去数组的概念,我们可以理解为现在有五个元素(数字)排列一个轴上(类似于数轴,也就是dim),那么我们自然而然的就会想要去给每一个元素一个下标,这也就是我们使用的从0开始的整数下标。那么对于这个轴的方向,每一个下标都会确定一个元素(例如下标2,对应着a里面的2)。那么直观的来看就像下面的图一样:
在这里插入图片描述

我们可以注意到,现在参与排队的元素是数字,并且只在一个方向上排队,这就构成了最简单的1维的tensor。那么从1维到2维的过程有两种理解方式:

  • 仍然是在一个方向上排队,但参与排队的元素是已经按照另一个方向排好队的等长的队伍(维度增加)
  • 排队元素仍然是一个数字,但是现在按照两个方向(两个维度)进行排队

那么我们可以看看这样一个2维的tensor,

>>> b = torch.tensor([[0, 1],
>>> 			      [2, 3], 
>>> 			      [4, 5]])

那么按照上面两种理解可以理解为:

  • 现在在dim=0的方向上有三个元素要参与排列,分别是 [ 0 , 1 ] , [ 2 , 3 ] , [ 3 , 4 ] [0, 1], [2,3],[3,4] [0,1],[2,3],[3,4]。这个时候他们自己排列的方向为dim=1,由此构成了一个2维的tensor。
  • 现在有6个元素在两个方向上dim=0,dim=1上排列,其中dim=0方向有三个位置,dim=1的方向上有两个位置。由此构成一个二维的数组。

如下图所示:
在这里插入图片描述
那么如果是一个3维或者跟高维度的tensor呢,同样的道理,在多个方向上对已经在某些方向上排列好了的一组数字进行排列。这也就是高维数组是数组的数组说法的一个体现。

回过来再看看二维tensor,如果我们把dim=1方向上的所有数字看作一个整体(1维的一个tensor, [ 0 , 1 ] , [ 2 , 3 ] , [ 4 , 5 ] [0,1],[2,3],[4,5] [0,1],[2,3],[4,5]三个),那么我们可以理解为b是这三个1维tensor的一个1维的tensor,那么这个时候如果我们把dim=0方向上的数字看作一个整体呢,那我们得到的是 [ 0 , 2 , 4 ] , [ 1 , 3 , 5 ] [0,2,4],[1,3,5] [0,2,4],[1,3,5]构成的的一个维tensor。那如果我们把两个维度看作一个整体,那我们得到的是单单一个二维tensor,不用排列。

那么如果是一个3维的或者更高维度的tensor,我们可以将多个维度(m)看作是一个整体,那么剩余的维度(n)构成了一个所有元素都是m维tensor的n维tensor。

那么接下来的索引,切片过程我们也能方便的理解。

索引

简单索引

使用数字进行索引,一般对于一个n维的tensor,索引形式为:
T [ d 0 , d 1 , . . . , d t ] , t < n . T[d_0, d_1,...,d_t], t<n. T[d0,d1,...,dt],t<n.
那么这就相当于将其看作是一个tensor的tensor,将元素按照下标取出。
例如:

>>> a[0]
tensor(0)
>>> a[3]
tensor(3)
>>> b[1]
tensor([2, 3])
>>> b[:, 1]  # 展示将dim=0作为整体
tensor([1, 3, 5])
>>> b[1, 1]
tensor(3)

用1维的list,numpy,tensor索引

将整个tensor看作是一个由n-1维tensor构成的1维tensor。将每个取出的元素排列,构成一个新的tensor。

>>> a[[1, 2 ,1]]
tensor([1, 2, 1])
>>> b[[2, 0, 2]]
tensor([[4, 5],
>>>     [0, 1],
>>>     [4, 5]])
>>> c = torch.rand([4, 3])
>>> c
tensor([[0.6478, 0.3120, 0.6656],
        [0.4470, 0.6383, 0.6878],
        [0.9854, 0.9709, 0.4868],
        [0.1797, 0.3453, 0.9005]])
>>> c[[1, 2, 1]]
tensor([[0.4470, 0.6383, 0.6878],
        [0.9854, 0.9709, 0.4868],
        [0.4470, 0.6383, 0.6878]])

用booltensor索引

使用booltensor B对T进行索引,需要满足如下条件:
B . s i z e ( ) : ( b 0 , b 1 , . . . , b t − 1 ) T . s i z e ( ) : ( d 0 , d 1 , . . . , d t − 1 , . . . , d n − 1 ) b i = d i , ∀ i < t , n ≥ t . B.size():(b_0, b_1, ...,b_{t-1}) \\ T.size():(d_0, d_1, ...,d_{t-1}, ...,d_{n-1}) \\ b_i = d_i,\forall i < t,\\ n\ge t. B.size():(b0,b1,...,bt1)T.size():(d0,d1,...,dt1,...,dn1)bi=di,i<t,nt.
那么其意义就是将True位置的元素取出,构成一个 n − t + 1 n-t+1 nt+1维的新的tensor。例如:

>>> boolt_1 = torch.tensor([False, True, False, True, True])
>>> a[boolt_1]
tensor([1, 3, 4])
>>> boolt_2 = torch.tensor([True, False, True])
>>> b[boolt_2]
tensor([[0, 1],
        [4, 5]])
>>> boolt_3 = torch.tensor([[False, True], 
>>>                         [True, False],
>>>                         [True, True]])
tensor([1, 2, 4, 5])
>>> boolt_4 = torch.tensor([[False, True], 
>>>                         [True, False]])
>>> b[boolt_4]  # 如果不符合条件
Traceback (most recent call last):
  File "<input>", line 1, in <module>
IndexError: The shape of the mask [2, 2] at index 0 does not match the shape of the indexed tensor [3, 2] at index 0

可以看到维数条件是严格要求B的所有维度的长度正好等于被索引的tensor的对应维度的长度。

切片

相信再list里面就已经学过切片的概念了,主要使用形如 [ s t a r t : e n d : s t e p ] [start : end : step] [start:end:step]的组合进行子序列的抽取,其中 : s t e p : step :step可选,默认为1, s t a r t start start e n d end end也可选,分别默认为0和len(obj)。
那么在tensor中遇到的主要是如下形式:
T [ s 0 : e 0 : s t 0 , . . . , s n − 1 : e n − 1 : s t n − 1 ] T[s_0:e_0:st_0,...,s_{n-1}:e_{n-1}:st_{n-1}] T[s0:e0:st0,...,sn1:en1:stn1]
在每个维度处的切片都相当于将对该维度进行相应的切片操作,在该维度上保留对应下标的元素(数,或者tensor,或者什么都不剩)。例如:

>>> a[1:3]
tensor([1, 2])
>>> a[0:4:2]
tensor([0, 2])
>>> b[0:2, 0:1]
tensor([[0],
        [2]])
>>> b[1:, 1:]
tensor([[3],
        [5]])
>>> d = torch.randint(5, (3, 3, 3))
>>> d
tensor([[[2, 1, 4],
         [4, 1, 1],
         [1, 2, 4]],
        [[1, 4, 4],
         [0, 3, 4],
         [1, 2, 2]],
        [[4, 4, 4],
         [1, 3, 3],
         [0, 0, 4]]])
>>> d[:, 1:2, 0:2] 
tensor([[[4, 1]],
        [[0, 3]],
        [[1, 3]]])

显然切片和索引可以进行组合,效果就是将所有切片操作的维度在切片之后构成整体,看作是排列的元素,剩余索引的维度就是取出对应的这些元素。例如:

>>> d[[2, 0, 1], 1:2, 0:2]
tensor([[[1, 3]],
        [[4, 1]],
        [[0, 3]]])

花式索引

花式索引也是索引的一种,就是使用tensor对tensor进行索引,形如:
T [ t 0 , t 1 , . . . , t n − 1 ] T[t_0, t_1,...,t_{n-1}] T[t0,t1,...,tn1]
其中的 t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti,i=0,...,n1是维度不限的long型tensor。
能够执行这一操作的先决条件是, t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti,i=0,...,n1能够广播成同一形状。广播的机制建议自行了解。
整个语句的过程大致如下:

  • 首先将 t i , i = 0 , . . . , n − 1 t_i,i=0,...,n-1 ti,i=0,...,n1广播成同一形状,如果它们不是同一形状的话,假设最终形状为 ( b 0 , b 1 , . . . , b s ) (b_0, b_1, ...,b_s) (b0,b1,...,bs)
  • 这时 [ t 0 , t 1 , . . . , t n − 1 ] [t_0, t_1,...,t_{n-1}] [t0,t1,...,tn1]这些tensor的对应位置的元素构成1组坐标,总共 b 0 × b 1 × . . . × b s b_0\times b_1\times ...\times b_s b0×b1×...×bs组坐标。
  • 每组坐标进行一次简单索引,取出的元素(可能是数,也可能是tensor)放在形状 ( b 0 , b 1 , . . . , b s ) (b_0, b_1, ...,b_s) (b0,b1,...,bs)的对应位置,例如如果是 [ t 0 , t 1 , . . . , t n − 1 ] [t_0, t_1,...,t_{n-1}] [t0,t1,...,tn1]的所有 ( 0 , 0 , . . . , 0 ) (0,0, ...,0) (0,0,...,0)构成的坐标,那么将结果放在 ( 0 , 0 , . . . , 0 ) (0,0, ...,0) (0,0,...,0)处,得到最后结果。

那么我们看个例子:

>>> idx_0 = torch.tensor([[3, 2],[1, 4]])
>>> a[idx_0]
tensor([[3, 2],
        [1, 4]])
>>> b  # 查看一下b,
tensor([[0, 1],
        [2, 3],
        [4, 5]])
>>> idx_0 = torch.tensor([[1, 0],[2, 1]])
>>> idx_1 = torch.tensor([0, 1])
>>> b[idx_0, idx_1]
tensor([[2, 1],
        [4, 3]])

分析一下过程,先是idx_1广播成

tensor([[0, 1],
		[0, 1]])

构成四组坐标 [ 1 , 0 ] , [ 0 , 1 ] , [ 2 , 0 ] , [ 1 , 1 ] [1, 0],[0, 1],[2, 0],[1, 1] [1,0],[0,1],[2,0],[1,1],对应着 2 , 1 , 4 , 3 2, 1, 4, 3 2,1,4,3,放入对应位置得到最终结果。

那么其实花式索引也是索引的一种,不过是通过多次索引,并且组成新的tensor的更复杂的索引,同理也可以和切片结合。

结语

这些关于tensor的理解,都是个人理解,希望能够帮助到有需要的人,另外如有参考本篇博客,请注明链接,请勿直接搬运。
思考过很久,因为想要写写博客,记录自己成长的过程,但又迟迟没能动手,或者是写写停停,但总算是再一次尝试了一遍,以后也可能是随缘吧,尽力吧

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐