pytorch:tensor维度理解及合并操作
pytorch中的tensor维度理解及合并操作
·
pytorch:tensor维度理解及合并操作
这是我在做cnn时需要做semi_suprised learning时发现的问题,我需要将两个tensor合并。
例如
import torch as t
import numpy as np
a=t.tensor([[2,3,4],[3,4,5]])
b=t.tensor([1,2,3])
print(a)
print(b)
需要将a、b合并,得到
c=t.tensor([[2,3,4],[3,4,5],[1,2,3]])
问题十分简单,大佬跳过。
只需要:
c=t.cat((a,b),dim=0)
但是实际运行起来却报错维度不一致
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-6-b2f12a333b67> in <module>
----> 1 c=t.cat((a,b),dim=0)
RuntimeError: torch.cat(): Tensors must have same number of dimensions: got 2 and 1
但是使用c=c=torch.cat((a,a),dim=0)却可以
然后我打印了
print(a.shape)
print(b.shape)
##结果
torch.Size([2, 3])
torch.Size([3])
发现了问题,虽然看似形式一样,a确是两维,b是一维。
那怎么办呢?用reshape!
b=b.resahpe(1,3)
再用c=torch.cat((a,b),dim=0)就行了
b=b.reshape(1,3)
c=t.cat((a,b),dim=0)
c
##结果
tensor([[2, 3, 4],
[3, 4, 5],
[1, 2, 3]])
所以看似两次b的形式基本一样,维度却差了一维。
之后我又发现
e=b[0]
print(e)
print(e.shape)
##结果
tensor(1)
torch.Size([])
看到没有!没有维度
但是
e=e.reshape(1)
print(e)
print(e.shape)
##out
tensor([1])
torch.Size([1])
看到没有多了一个[],这又印证了我前面的叙述。
故大家合并时需要注意。
更多推荐
已为社区贡献3条内容
所有评论(0)