今天讲二维和三维数组如何索引,重点在二维,三维是从二维拓展而来;另外,末尾我们还介绍了一个工具,专门用来索引,即torch.index_select(),或者a.index_select()。

二维

二维数组很容易索引,例如:

c=torch.rand(2,3)
print(c)

在这里插入图片描述

#取行
print(c[[0,1],:])
#取列
print(c[:,[1,2]])

在这里插入图片描述

#取行列,所以就是某个元素了。
print(c[[0],[0]])#取(0,0)
print(c[[0,0,0],[0,1,2]])#取(0,0),(0,1),(0,2)

在这里插入图片描述
需要注意的是,上述写成下述则不是我们想要的:

index=torch.tensor([[0,0,0],[0,1,2]])
print(c[index])

在这里插入图片描述
索引的时候我们最好是使用list,上述可以改成:

index=torch.tensor([[0,0,0],[0,1,2]])
index=index.tolist()
print(c[index])

在这里插入图片描述

三维

三维其实可以从二维中递推而得到,基本原理还是一样,之前是有两个位置,所以是一个逗号,现在变成三个位置,所以是2个逗号。
这里只做简单的演示即可:
在这里插入图片描述
在这里插入图片描述

index_select

a=torch.rand(3,4)
print(a)

在这里插入图片描述

indices = torch.tensor([0, 2])
print(a.index_select(0,indices))#取0行和2行。
print(a.index_select(1,indices))#取0列和2列。

在这里插入图片描述

gather

最近又看到一个函数,属实懵逼了,上面这些索引好像就够了,但是这个也能索引。

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])

比如说我们想要索引出5,7,9。那么用gather怎么实现呢?我们可以指定维度为行,dim=1,然后再指定[2,1,0],这就表示[0,2],[1,1],[2,0]了;那么如果我们先指定维度为列,dim=0,那么似乎就无法索引出[5,7,9]这个顺序了,其会索引出[9,7,5]的顺序。我们来看一下:

index = torch.tensor([[2], [1], [0]])
torch.gather(tensor_0,1,index)

tensor([[5],
[7],
[9]])

可以看到,相当于他会帮你补齐另外一个坐标,[2],那么由于维度是1,从而是[0,2],而不是[2,0]。另外,一个硬性要求是:index必须和tensor_0维度相同,即都是二维的。这个要求我觉得挺奇葩的。

index = torch.tensor([[2,1,0]])
torch.gather(tensor_0,0,index)

tensor([[9, 7, 5]])

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐