pytorch:tensor维度理解及合并操作

这是我在做cnn时需要做semi_suprised learning时发现的问题,我需要将两个tensor合并。
例如

import torch as t
import numpy as np
a=t.tensor([[2,3,4],[3,4,5]])
b=t.tensor([1,2,3])
print(a)
print(b)

需要将a、b合并,得到

c=t.tensor([[2,3,4],[3,4,5],[1,2,3]])

问题十分简单,大佬跳过。
只需要:

c=t.cat((a,b),dim=0)

但是实际运行起来却报错维度不一致

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-b2f12a333b67> in <module>
----> 1 c=t.cat((a,b),dim=0)

RuntimeError: torch.cat(): Tensors must have same number of dimensions: got 2 and 1

但是使用c=c=torch.cat((a,a),dim=0)却可以
然后我打印了

print(a.shape)
print(b.shape)
##结果
torch.Size([2, 3])
torch.Size([3])

发现了问题,虽然看似形式一样,a确是两维,b是一维。
那怎么办呢?用reshape!
b=b.resahpe(1,3)
再用c=torch.cat((a,b),dim=0)就行了

b=b.reshape(1,3)
c=t.cat((a,b),dim=0)
c
##结果
tensor([[2, 3, 4],
        [3, 4, 5],
        [1, 2, 3]])

所以看似两次b的形式基本一样,维度却差了一维。
之后我又发现

e=b[0]
print(e)
print(e.shape)
##结果
tensor(1)
torch.Size([])

看到没有!没有维度
但是

e=e.reshape(1)
print(e)
print(e.shape)
##out
tensor([1])
torch.Size([1])

看到没有多了一个[],这又印证了我前面的叙述。
故大家合并时需要注意。

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐