torch.view函数用法
这句话一般出现在model类的forward函数中,具体位置一般都是在调用分类器之前。分类器是一个简单的nn.Linear()结构,输入输出都是维度为一的值,x=x.view(x.size(0),-1)这句话的出现就是为了将前面多维度的tensor展平成一维。view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。在使用pytorch定义神经网络时,经常会看到类似如下的
·
view
一、手动调整size
view( )相当于reshape、resize,对Tensor的形状进行调整。
例:
import torch
x1 = torch.arange(0,16)
print("x1:",x1)
#a1: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
--------------------------------------------------------------------------------------------
x2 = x1.view(8, 2)
x3 = x1.view(2, 8)
x4 = x1.view(4, 4)
print("x2:",x2)
print("x3:",x3)
print("x4:",x4)
x2: tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
x3: tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
x4: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
二、自动调整size (参数-1)
例:
view中一个参数指定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。
import torch
x1 = torch.arange(0,16)
print(x1)
#a1: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
------------------------------------------------------------------------------------------------------
x2 = x1.view(-1, 16)
x3 = x1.view(-1, 8)
x4 = x1.view(-1, 4)
x5 = x1.view(-1, 2)
x6 = x1.view(4*4, -1)
x7 = x1.view(1*4, -1)
x8 = x1.view(2*4, -1) #-1自动调整,8行有几列自动调整
print(x2)
print(x3)
print(x4)
print(x5)
print(x6)
print(x7)
print(x8)
x2: tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
x3: tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
x4: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
x5: tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
x6: tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
x7: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
x8: tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
分类器就是一个简单的nn.Linear()结构,输入输出都是一维的值,x = x.view(x.size(0), -1) 是为了将多维度的tensor展平成一维。
x = x.view(x.size(0), -1)
print(x.size(), '*'*100)
print(x, '*'*100)
x4.size(): torch.Size([1, 2048])
x4:tensor([[0.3893, 0.5719, 0.5537, ..., 0.3605, 0.4108, 0.3296]],device='cuda:0') # 拉平了
更多推荐
已为社区贡献3条内容
所有评论(0)