torch学习 (42):tensor的索引、切片、拼接,以及变异操作 (Indexing, Slicing, Joining, Mutating)
对torch中tensor的索引、切片、拼接等操作进行说明
文章目录
- 1 argwhere
- 2 cat和concat
- 3 conj
- 4 chunk
- 5 dsplit
- 6 column_stack
- 7 dstack
- 8 gather
- 9 hstack
- 10 tensor.index_add_
- 11 index_select
- 12 masked_select
- 13 movedim和moveaxis
- 14 narrow
- 15 nonzero
- 16 permute
- 17 reshape
- 18 vstack
- 19 select
- 20 scatter
- 21 diagonal_scatter
- 22 select_scatter
- 23 slice_scatter
- 24 scatter_add
- 25 scatter_reduce
- 26 split
- 27 squeeze
- 28 stack
- 29 transpose和swapaxes和swapdims
- 30 t
- 31 take
- 32 take_along_dim
- 33 tensor_split
- 34 tile
- 35 unbind
- 36 unsqueeze
- 37 vsplit
- 38 where
1 argwhere
声明:torch.argwhere(input) → Tensor
用途:找到指定tensor中0元素的索引
import torch
torch.manual_seed(1)
t = torch.tensor([1, 0, 1])
a = torch.argwhere(t)
t = torch.tensor([[1, 0, 1], [0, 1, 1]])
b = torch.argwhere(t)
print(a)
print(b)
输出:
tensor([[0],
[2]])
tensor([[0, 0],
[0, 2],
[1, 1],
[1, 2]])
2 cat和concat
声明:torch.cat(tensors, dim=0, *, out=None) → Tensor;concat则是cat的别名
用途:指定维度拼接tensor
a = torch.randn(2, 3)
b = torch.cat((a, a, a), 0)
c = torch.cat((a, a, a), 1)
print(a)
print(b)
print(c)
输出:
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661]])
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661],
[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661],
[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661]])
tensor([[ 0.6614, 0.2669, 0.0617, 0.6614, 0.2669, 0.0617, 0.6614, 0.2669,
0.0617],
[ 0.6213, -0.4519, -0.1661, 0.6213, -0.4519, -0.1661, 0.6213, -0.4519,
-0.1661]])
3 conj
声明:torch.conj(input) → Tensor
用途:返回复数的共轭数,共轭是指实部相同虚部互为相反数的两个复数
a = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
print(a)
print(a.is_conj())
b = torch.conj(a)
print(b)
print(b.is_conj())
输出:
tensor([-1.+1.j, -2.+2.j, 3.-3.j])
False
tensor([-1.-1.j, -2.-2.j, 3.+3.j])
True
4 chunk
声明:torch.chunk(input, chunks, dim=0) → List of Tensors
用途:返回给定tensor的多个块
a = torch.randn(4, 3)
print(a)
b = a.chunk(2)
print(b)
c = a.chunk(2, dim=1)
print(c)
输出:
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661],
[-1.5228, 0.3817, -1.0276],
[-0.5631, -0.8923, -0.0583]])
(tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661]]), tensor([[-1.5228, 0.3817, -1.0276],
[-0.5631, -0.8923, -0.0583]]))
(tensor([[ 0.6614, 0.2669],
[ 0.6213, -0.4519],
[-1.5228, 0.3817],
[-0.5631, -0.8923]]), tensor([[ 0.0617],
[-0.1661],
[-1.0276],
[-0.0583]]))
5 dsplit
声明:torch.dsplit(input, indices_or_sections) → List of Tensors
用途:在dim=2上划分tensor
a = torch.arange(16.0).reshape(2, 2, 4)
print(a)
print(torch.dsplit(a, 2))
输出:
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.]],
[[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]])
(tensor([[[ 0., 1.],
[ 4., 5.]],
[[ 8., 9.],
[12., 13.]]]), tensor([[[ 2., 3.],
[ 6., 7.]],
[[10., 11.],
[14., 15.]]]))
6 column_stack
声明:torch.column_stack(tensors, *, out=None) → Tensor
用途:列方向拼接
a = torch.arange(5)
b = torch.arange(10).reshape(5, 2)
print(torch.column_stack((a, b, b)))
输出:
tensor([[0, 0, 1, 0, 1],
[1, 2, 3, 2, 3],
[2, 4, 5, 4, 5],
[3, 6, 7, 6, 7],
[4, 8, 9, 8, 9]])
7 dstack
声明:torch.dstack(tensors, *, out=None) → Tensor
用途:相当于在dim=2拼接tensor
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(torch.dstack((a,b)))
a = torch.tensor([[1],[2],[3]])
b = torch.tensor([[4],[5],[6]])
print(torch.dstack((a,b)))
输出:
tensor([[[1, 4],
[2, 5],
[3, 6]]])
tensor([[[1, 4]],
[[2, 5]],
[[3, 6]]])
8 gather
声明:torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
用途:沿着指定维度汇聚值,具体如下:
out[i][j][k] = input[index[i][j][k]][j][k] # dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # dim == 2
a = torch.tensor([[1, 2], [3, 4]])
b = torch.gather(a, 1, torch.tensor([[0, 0], [1, 0]]))
print(a)
print(b)
输出:
tensor([[1, 2],
[3, 4]])
tensor([[1, 1],
[4, 3]])
9 hstack
声明:torch.hstack(tensors, *, out=None) → Tensor
用途:水平方向堆叠
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(torch.hstack((a,b)))
a = torch.tensor([[1],[2],[3]])
b = torch.tensor([[4],[5],[6]])
print(torch.hstack((a,b)))
输出:
tensor([1, 2, 3, 4, 5, 6])
tensor([[1, 4],
[2, 5],
[3, 6]])
10 tensor.index_add_
声明:Tensor.index_add_(dim, index, source, , alpha=1) → Tensor
用途:在输入tensor的指定索引出加上sourcealpha
x = torch.ones(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 4, 2])
a = x.index_add_(0, index, t)
print(a)
b = x.index_add_(0, index, t, alpha=-2)
print(b)
输出:
tensor([[ 2., 3., 4.],
[ 1., 1., 1.],
[ 8., 9., 10.],
[ 1., 1., 1.],
[ 5., 6., 7.]])
tensor([[ 0., -1., -2.],
[ 1., 1., 1.],
[-6., -7., -8.],
[ 1., 1., 1.],
[-3., -4., -5.]])
11 index_select
声明:torch.index_select(input, dim, index, *, out=None) → Tensor
用途:选取指定维度下的指定索引的tensor元素
x = torch.randn(3, 4)
print(x)
indices = torch.tensor([0, 2])
a = torch.index_select(x, 0, indices)
b = torch.index_select(x, 1, indices)
print(a)
print(b)
输出:
tensor([[ 0.6614, 0.2669, 0.0617, 0.6213],
[-0.4519, -0.1661, -1.5228, 0.3817],
[-1.0276, -0.5631, -0.8923, -0.0583]])
tensor([[ 0.6614, 0.2669, 0.0617, 0.6213],
[-1.0276, -0.5631, -0.8923, -0.0583]])
tensor([[ 0.6614, 0.0617],
[-0.4519, -1.5228],
[-1.0276, -0.8923]])
12 masked_select
声明:torch.masked_select(input, mask, *, out=None) → Tensor
用途:masked是一个布尔tensor,返回其为真所对应索引的元素,并组成一个新的1D tensor
x = torch.randn(3, 4)
print(x)
mask = x.ge(0.5)
print(mask)
a = torch.masked_select(x, mask)
print(a)
输出:
tensor([[ 0.6614, 0.2669, 0.0617, 0.6213],
[-0.4519, -0.1661, -1.5228, 0.3817],
[-1.0276, -0.5631, -0.8923, -0.0583]])
tensor([[ True, False, False, True],
[False, False, False, False],
[False, False, False, False]])
tensor([0.6614, 0.6213])
13 movedim和moveaxis
声明:torch.movedim(input, source, destination) → Tensor;等价于moveaxis
用途:指定维度之间进行交换
t = torch.randn(3,2,1)
print(torch.movedim(t, 1, 0).shape)
print(torch.movedim(t, (1, 2), (0, 1)).shape)
输出:
torch.Size([2, 3, 1])
torch.Size([2, 1, 3])
14 narrow
声明:torch.narrow(input, dim, start, length) → Tensor
用途:返回tensor的指定维度下的指定区间
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(torch.narrow(x, 0, 0, 2))
print(torch.narrow(x, 1, 1, 2))
输出:
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[2, 3],
[5, 6],
[8, 9]])
15 nonzero
声明:torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors
用途:返回tensor的非零值的索引
print(torch.nonzero(torch.tensor([1, 1, 1, 0, 1])))
print(torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
[0.0, 0.4, 0.0, 0.0],
[0.0, 0.0, 1.2, 0.0],
[0.0, 0.0, 0.0, -0.4]])))
print(torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True))
print(torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
[0.0, 0.4, 0.0, 0.0],
[0.0, 0.0, 1.2, 0.0],
[0.0, 0.0, 0.0, -0.4]]), as_tuple=True))
print(torch.nonzero(torch.tensor(5), as_tuple=True))
输出:
tensor([[0],
[1],
[2],
[4]])
tensor([[0, 0],
[1, 1],
[2, 2],
[3, 3]])
(tensor([0, 1, 2, 4]),)
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
(tensor([0]),)
16 permute
声明:torch.permute(input, dims) → Tensor
用途:序列变换,返回指定维度下的新视角
x = torch.randn(2, 3, 5)
print(x.size())
print(torch.permute(x, (2, 0, 1)).size())
输出:
torch.Size([2, 3, 5])
torch.Size([5, 2, 3])
17 reshape
声明:torch.reshape(input, shape) → Tensor
用途:改变形状
a = torch.arange(4.)
print(torch.reshape(a, (2, 2)))
b = torch.tensor([[0, 1], [2, 3]])
print(torch.reshape(b, (-1,)))
输出:
tensor([[0., 1.],
[2., 3.]])
tensor([0, 1, 2, 3])
18 vstack
声明:torch.vstack(tensors, *, out=None) → Tensor;别名row_stack
用途:竖直方向堆叠
19 select
声明:torch.select(input, dim, index) → Tensor
用途:等价于切片操作
20 scatter
声明:torch.scatter(input, dim, index, src) → Tensor;等价于Tensor.scatter_(dim, index, src, reduce=None) → Tensor
用途:将src上元素按顺序散布在input的给定维度的给定index上
src = torch.arange(1, 11).reshape((2, 5))
print(src)
index = torch.tensor([[0, 1, 2, 0]])
a = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
print(a)
输出:
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
21 diagonal_scatter
声明:torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) → Tensor
用途:对角散布
a = torch.zeros(3, 3)
b = torch.diagonal_scatter(a, torch.ones(3), 0)
print(b)
b = torch.diagonal_scatter(a, torch.ones(2), 1)
print(b)
输出:
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
tensor([[0., 1., 0.],
[0., 0., 1.],
[0., 0., 0.]])
22 select_scatter
声明:torch.select_scatter(input, src, dim, index) → Tensor
用途:选择散布
a = torch.zeros(2, 2)
b = torch.ones(2)
print(a.select_scatter(b, 0, 0))
输出:
tensor([[1., 1.],
[0., 0.]])
23 slice_scatter
声明:torch.slice_scatter(input, src, dim=0, start=None, end=None, step=1) → Tensor
用途:切片散布
a = torch.zeros(8, 8)
b = torch.ones(2, 8)
print(a.slice_scatter(b, start=6))
输出:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1.]])
24 scatter_add
声明:torch.scatter_add(input, dim, index, src) → Tensor;等价于scatter_add_
用途:散布叠加
src = torch.ones((2, 5))
index = torch.tensor([[0, 1, 2, 0, 0]])
print(torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src))
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
print(torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src))
输出:
tensor([[1., 0., 0., 1., 1.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
tensor([[2., 0., 0., 1., 1.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 1., 1.]])
25 scatter_reduce
声明:torch.scatter_reduce(input, dim, index, reduce, *, output_size=None) → Tensor
用途:散布下的reduce操作
input = torch.tensor([1, 2, 3, 4, 5, 6])
index = torch.tensor([0, 1, 0, 1, 2, 1])
print(torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)) # sum, prod, mean, amax, amin
输出:
tensor([ 4, 12, 5])
26 split
声明:torch.split(tensor, split_size_or_sections, dim=0)
用途:切片
a = torch.arange(10).reshape(5, 2)
print(a)
print(torch.split(a, 2))
print(torch.split(a, [1,4]))
输出:
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
(tensor([[0, 1],
[2, 3]]), tensor([[4, 5],
[6, 7]]), tensor([[8, 9]]))
(tensor([[0, 1]]), tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
27 squeeze
声明:torch.squeeze(input, dim=None, *, out=None) → Tensor
用途:去除所有维度为1的维度后的显示;如果指定维度,则只对指定维度生效
x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
y = torch.squeeze(x)
print(y.size())
y = torch.squeeze(x, 0)
print(y.size())
y = torch.squeeze(x, 1)
print(y.size())
输出:
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 1, 2])
28 stack
声明:torch.stack(tensors, dim=0, *, out=None) → Tensor
用途:堆叠tensor
29 transpose和swapaxes和swapdims
声明:torch.transpose(input, dim0, dim1) → Tensor;等价于swapaxes和swapdims
用途:转置指定维度
30 t
声明:torch.t(input) → Tensor
用途:转置2Dtensor
31 take
声明:torch.take(input, index) → Tensor
用途:将input视作1D,并返回给定index元素
src = torch.tensor([[4, 3, 5],
[6, 7, 8]])
print(torch.take(src, torch.tensor([0, 2, 5])))
输出:
tensor([4, 5, 8])
32 take_along_dim
声明:torch.take_along_dim(input, indices, dim, *, out=None) → Tensor
用途:沿维度选取
t = torch.tensor([[10, 30, 20], [60, 40, 50]])
sorted_idx = torch.argsort(t, dim=1)
print(torch.take_along_dim(t, sorted_idx, dim=1))
输出:
tensor([[10, 20, 30],
[40, 50, 60]])
33 tensor_split
声明:torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors
用途:切片
x = torch.arange(8)
print(torch.tensor_split(x, 3))
输出:
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))
34 tile
声明:torch.tile(input, dims) → Tensor
用途:指定维度复制
x = torch.tensor([1, 2, 3])
x = x.tile((2,))
print(x)
y = torch.tensor([[1, 2], [3, 4]])
y = torch.tile(y, (2, 2))
print(y)
输出:
tensor([1, 2, 3, 1, 2, 3])
tensor([[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]])
35 unbind
声明:torch.unbind(input, dim=0) → seq
用途:去除维度
x = torch.unbind(torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]))
print(x)
输出:
(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))
36 unsqueeze
声明:torch.unsqueeze(input, dim) → Tensor
用途:插入维度
x = torch.tensor([1, 2, 3, 4])
a = torch.unsqueeze(x, 0)
b = torch.unsqueeze(x, 1)
print(a)
print(b)
输出:
tensor([[1, 2, 3, 4]])
tensor([[1],
[2],
[3],
[4]])
37 vsplit
声明:torch.vsplit(input, indices_or_sections) → List of Tensors
用途:竖直方向切片
38 where
声明:torch.where(condition, x, y) → Tensor
用途:满足条件返回x,否则返回y
x = torch.randn(3, 2)
y = torch.ones(3, 2)
a = torch.where(x > 0, x, y)
print(a)
输出:
tensor([[0.6614, 0.2669],
[0.0617, 0.6213],
[1.0000, 1.0000]])
更多推荐
所有评论(0)