• 1. 增加维度:

可以使用 unsqueeze(index) 给tensor增加某一个指定位置的维度,tensor的维度从0开始。若index>=0,则表示在第index维插入一个新的维度;若index<0,则表示在倒数第index维插入一个新的维度。

实例:

import torch 

a = torch.randn(8,4,9)
b = a.unsqueeze(0)
c = a.unsqueeze(1)
d = a.unsqueeze(-1)
e = a.unsqueeze(-2)

print('a.shape:',a.shape)
print('b.shape:',b.shape)
print('c.shape:',c.shape)
print('d.shape:',d.shape)
print('e.shape:',e.shape)

  • 2. 减少维度: 

可以使用squeeze(index) 取删除指定的size=1的维度,将多余的[]去掉。若index为空,则表示删除所有的size=1的维度;若index=0,则表示删除第0维的size=0的维度;若index=1,但第1维的size≠1,则删除失败。

实例:

import torch 

a = torch.randn(1,4,1,2)
b = a.squeeze()
c = a.squeeze(0)
d = a.squeeze(1)
e = a.squeeze(2)

print('a.shape:',a.shape)
print('b.shape:',b.shape)
print('c.shape:',c.shape)
print('d.shape:',d.shape)
print('e.shape:',e.shape)

  • 3. 改变维度:

可以使用reshape()改变tensor的维度。若某一维的值为正,则代表该维度的size大小;若某一维值为-1,则表示该维度的size大小不确定,要视其他维度而定。

实例:

import torch 

a = torch.randn(2,4,3,5)
b = a.reshape(2*4,3,5)
c = a.reshape(2*4*3,5)
d = a.reshape(-1,3,5)
e = a.reshape(-1,5)

print('a.shape:',a.shape)
print('b.shape:',b.shape)
print('c.shape:',c.shape)
print('d.shape:',d.shape)
print('e.shape:',e.shape)

  • 4. 维度变换1(交换2个维度):

可以使用transpose(dim1,dim2) 交换指定的两个维度dim1维和dim2维,但这种变换会使得存储不再连续,因此要加contiguous() 使其连续。

实例:

import torch 

a = torch.randn(2,4,3,5)
b = a.transpose(0,1).contiguous()
c = a.transpose(1,3).contiguous()

print('a.shape:',a.shape)
print('b.shape:',b.shape)
print('c.shape:',c.shape)

  • 5. 维度变换2(变换所有维度):

可以使用permute(dim1,dim2,dim3,...) 一次性变换所有维度的顺序

实例:

import torch 

a = torch.randn(2,4,3,5)
b = a.permute(1,0,3,2)

print('a.shape:',a.shape)
print('b.shape:',b.shape)

  • 6. 维度变换3(2维转置):

可以使用.t() 对二维tensor(也就是矩阵)进行转置操作。

实例:

import torch 

a = torch.randn(2,4)
b = a.t()

print('a.shape:',a.shape)
print('b.shape:',b.shape)

 

  • 7. 维度重复:

可以使用repeat()将每个维度重复到指定的次数。

实例:

import torch 

a = torch.randn(2,4,3,5)
b = a.repeat(1,3,3,2)

print('a.shape:',a.shape)
print('b.shape:',b.shape)

  •  8. 维度扩展:

可以使用expand() 对size=1 的维度进行维度扩展,若某一维不需要扩展,则可取-1

实例:

import torch 

a = torch.randn(1,4,1,1)
b = a.expand(3,-1,6,6)

print('a.shape:',a.shape)
print('b.shape:',b.shape)

 

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐