Tensor:索引操作
索引操作Tensor支持与numpy.ndarray类似的索引操作,语法上也类似,下面通过一些例子,讲解常用的索引操作。如无特殊说明,索引出来的结果与原tensor共享内存,也即修改一个,另一个会跟着修改。In[31]:a = t.randn(3, 4)aOut[31]:tensor([[ 1.1741,1.4335, -0.8156,0.7622],[-0.6334, -1.4628, -0.7
索引操作
Tensor支持与numpy.ndarray类似的索引操作,语法上也类似,下面通过一些例子,讲解常用的索引操作。如无特殊说明,索引出来的结果与原tensor共享内存,也即修改一个,另一个会跟着修改。
In [31]:
a = t.randn(3, 4)
a
Out[31]:
tensor([[ 1.1741, 1.4335, -0.8156, 0.7622], [-0.6334, -1.4628, -0.7428, 0.0410], [-0.6551, 1.0258, 2.0572, 0.3923]])
In [32]:
a[0] # 第0行(下标从0开始)
Out[32]:
tensor([ 1.1741, 1.4335, -0.8156, 0.7622])
In [33]:
a[:, 0] # 第0列
Out[33]:
tensor([ 1.1741, -0.6334, -0.6551])
In [34]:
a[0][2] # 第0行第2个元素,等价于a[0, 2]
Out[34]:
tensor(-0.8156)
In [35]:
a[0, -1] # 第0行最后一个元素
Out[35]:
tensor(0.7622)
In [36]:
a[:2] # 前两行
Out[36]:
tensor([[ 1.1741, 1.4335, -0.8156, 0.7622],
[-0.6334, -1.4628, -0.7428, 0.0410]])
In [37]:
a[:2, 0:2] # 前两行,第0,1列
Out[37]:
tensor([[ 1.1741, 1.4335],
[-0.6334, -1.4628]])
In [38]:
print(a[0:1, :2]) # 第0行,前两列
print(a[0, :2]) # 注意两者的区别:形状不同
Out[38]:
tensor([[1.1741, 1.4335]])
tensor([1.1741, 1.4335])
In [39]:
# None类似于np.newaxis, 为a新增了一个轴
# 等价于a.view(1, a.shape[0], a.shape[1])
a[None].shape
Out[39]:
torch.Size([1, 3, 4])
In [40]:
a[None].shape # 等价于a[None,:,:]
Out[40]:
torch.Size([1, 3, 4])
In [41]:
a[:,None,:].shape
Out[41]:
torch.Size([3, 1, 4])
In [42]:
a[:,None,:,None,None].shape
Out[42]:
torch.Size([3, 1, 4, 1, 1])
In [43]:
a > 1 # 返回一个ByteTensor
Out[43]:
tensor([[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 1, 1, 0]], dtype=torch.uint8)
In [44]:
a[a>1] # 等价于a.masked_select(a>1)
# 选择结果与原tensor不共享内存空间
Out[44]:
tensor([1.1741, 1.4335, 1.0258, 2.0572])
In [45]:
a[t.LongTensor([0,1])] # 第0行和第1行
Out[45]:
tensor([[ 1.1741, 1.4335, -0.8156, 0.7622],
[-0.6334, -1.4628, -0.7428, 0.0410]])
其它常用的选择函数如表3-2所示。
表3-2常用的选择函数
函数 | 功能 |
---|---|
index_select(input, dim, index) | 在指定维度dim上选取,比如选取某些行、某些列 |
masked_select(input, mask) | 例子如上,a[a>0],使用ByteTensor进行选取 |
non_zero(input) | 非0元素的下标 |
gather(input, dim, index) | 根据index,在dim维度上选取数据,输出的size与index一样 |
gather
是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:
out[i][j] = input[index[i][j]][j] # dim=0
out[i][j] = input[i][index[i][j]] # dim=1
三维tensor的gather
操作同理,下面举几个例子。
In [46]:
a = t.arange(0, 16).view(4, 4) a
Out[46]:
tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]])
In [47]:
# 选取对角线的元素 index = t.LongTensor([[0,1,2,3]]) a.gather(0, index)
Out[47]:
tensor([[ 0, 5, 10, 15]])
In [48]:
# 选取反对角线上的元素 index = t.LongTensor([[3,2,1,0]]).t() a.gather(1, index)
Out[48]:
tensor([[ 3], [ 6], [ 9], [12]])
In [49]:
# 选取反对角线上的元素,注意与上面的不同 index = t.LongTensor([[3,2,1,0]]) a.gather(0, index)
Out[49]:
tensor([[12, 9, 6, 3]])
In [50]:
# 选取两个对角线上的元素 index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t() b = a.gather(1, index) b
Out[50]:
tensor([[ 0, 3], [ 5, 6], [10, 9], [15, 12]])
与gather
相对应的逆操作是scatter_
,gather
把数据从input中按index取出,而scatter_
是把取出的数据再放回去。注意scatter_
函数是inplace操作。
out = input.gather(dim, index) -->近似逆操作 out = Tensor() out.scatter_(dim, index)
In [51]:
# 把两个对角线元素放回去到指定位置 c = t.zeros(4,4) c.scatter_(1, index, b.float())
Out[51]:
tensor([[ 0., 0., 0., 3.], [ 0., 5., 6., 0.], [ 0., 9., 10., 0.], [12., 0., 0., 15.]])
对tensor的任何索引操作仍是一个tensor,想要获取标准的python对象数值,需要调用tensor.item()
, 这个方法只对包含一个元素的tensor适用
In [52]:
a[0,0] #依旧是tensor)
Out[52]:
tensor(0)
In [53]:
a[0,0].item() # python float
Out[53]:
0
In [54]:
d = a[0:1, 0:1, None] print(d.shape) d.item() # 只包含一个元素的tensor即可调用tensor.item,与形状无关
torch.Size([1, 1, 1])
Out[54]:
0
In [55]:
# a[0].item() -> # raise ValueError: only one element tensors can be converted to Python scalars
高级索引
PyTorch在0.2版本中完善了索引操作,目前已经支持绝大多数numpy的高级索引1。高级索引可以看成是普通索引操作的扩展,但是高级索引操作的结果一般不和原始的Tensor共享内存。
In [56]:
x = t.arange(0,27).view(3,3,3) x
Out[56]:
tensor([[[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8]], [[ 9, 10, 11], [12, 13, 14], [15, 16, 17]], [[18, 19, 20], [21, 22, 23], [24, 25, 26]]])
In [57]:
x[[1, 2], [1, 2], [2, 0]] # x[1,1,2]和x[2,2,0]
Out[57]:
tensor([14, 24])
In [58]:
x[[2, 1, 0], [0], [1]] # x[2,0,1],x[1,0,1],x[0,0,1]
Out[58]:
tensor([19, 10, 1])
In [59]:
x[[0, 2], ...] # x[0] 和 x[2]
Out[59]:
tensor([[[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8]], [[18, 19, 20], [21, 22, 23], [24, 25, 26]]])
更多推荐
所有评论(0)