在看许多pytorch的代码时,为了计算上的方便,通常会用到unsqueeze函数,一直不得要领,这次专门去做个实验学习一下。

官方文档对这个函数描述如下,就是在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度

需要注意的是dim的范围是[-input.dim()-1, input.dim()+1),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1,类似于使用负数对python列表进行切片。

 

下面使用一个二维矩阵看下dim不同时呈现出的效果:

# 创建一个3*4的全1二维tensor
a = torch.ones(3,4)
'''
运行结果
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
'''

在0维度上插入一个维度,可以看到现在a的形状变为[1, 3, 4],第0维度的大小默认是1

a = a.unsqueeze(0)
print(a.shape)
'''
运行结果
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

torch.Size([1, 3, 4])
'''

在最后一个维度上插入一个维度,形状变为[3, 4, 1]

a = a.unsqueeze(a.dim())
print(a.shape)
'''
运行结果
tensor([[[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.]]])

torch.Size([3, 4, 1])
'''

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐