张量的合并操作

张量的合并操作类似与列表的追加元素,可以拼接、也可以堆叠。

1.拼接方法:.cat()

PyTorch中,可以使用.cat()方法实现张量的拼接,不改变张量形状,并且返回结果是原张量的视图
(1).cat()方法的使用,第一个参数和第二个参数:为目标张量,第三个参数:0表示行数增加,1表示列数增加,注意观察张量形状

a = torch.zeros(2, 3)  #创建2行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.]])
b = torch.ones(2, 3)   #创建2行3列元素全部为1(浮点型)的二维张量
#结果为:tensor([[1., 1., 1.],
                 [1., 1., 1.]])
c = torch.zeros(3, 3)  #创建3行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.]])
torch.cat([a, b])     # 按照行进行拼接,dim默认取值为0,行数增加,观察仔细
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.],
                 [1., 1., 1.],
                 [1., 1., 1.]])            
torch.cat([a, b], 1)   # 按照列进行拼接,列数增加,观察仔细,矩阵的嵌套
#结果为:tensor([[0., 0., 0., 1., 1., 1.],
                 [0., 0., 0., 1., 1., 1.]])
torch.cat([a, c], 1)         # 形状不匹配时将报错                 
#报错
RuntimeError                              Traceback (most recent call last)
<ipython-input-153-8bdd1a857266> in <module>
----> 1 torch.cat([a, c], 1)               # 形状不匹配时将报错
RuntimeError: Sizes of tensors must match except in dimension 1. Got 2 and 3 in dimension 0 (The offending index is 1)

注意理解:拼接的本质是实现元素的堆积,也就是构成a、b两个二维张量的各一维张量的堆积,最终还是构成二维向量。

2.堆叠方法:.stack()

和拼接不同,堆叠不是将元素拆分重装,而是简单的将各参与堆叠的对象分装到一个更高维度的张量里
(1).stack()方法的使用,第一个参数和第二个参数:为目标张量,第三个参数:0表示行数增加,1表示列数增加,注意观察张量形状

a = torch.zeros(2, 3)  #创建2行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.]])
b = torch.ones(2, 3)   #创建2行3列元素全部为1(浮点型)的二维张量
#结果为:tensor([[1., 1., 1.],
                 [1., 1., 1.]])
c = torch.zeros(3, 3)  #创建3行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.]])
torch.stack([a, b])     # 堆叠之后,生成一个三维张量
#结果为:tensor([[[0., 0., 0.],
                  [0., 0., 0.]],

                 [[1., 1., 1.],
                  [1., 1., 1.]]])     
torch.stack([a, b]).shape  #查看堆叠后的形状
#结果为:torch.Size([2, 2, 3])
表示:其是3维张量,由2个二维张量组成,每个二维张量有2个一维张量组成,每个一维张量有3个元素
torch.cat([a, b])
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.],
                 [1., 1., 1.],
                 [1., 1., 1.]])

注意对比二者区别,拼接之后维度不变,堆叠之后维度升高。拼接是把一个个元素单独提取出来之后再放到二维张量中,而堆叠则是直接将两个二维张量封装到一个三维张量中,因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同。

a = torch.zeros(2, 3)  #创建2行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.]])
c = torch.zeros(3, 3)  #创建3行3列元素全部为零(浮点型)的二维张量
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.]])
torch.cat([a, c])    # 横向拼接时,对行数没有一致性要求
#结果为:tensor([[0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.],
                 [0., 0., 0.]]) 
torch.stack([a, c])    # 维度不匹配时,堆叠也会报错     
#报错:
RuntimeError                              Traceback (most recent call last)
<ipython-input-167-0311d15e051e> in <module>
----> 1 torch.stack([a, c])               # 维度不匹配时也会报错

RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3, 3] at entry 1                          

以上不是本人的浅显见解,还请他人多多指导,更正错误。

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐