1、先看torch.cat 函数的官方解释:

网址:https://pytorch.org/docs/stable/generated/torch.cat.html?highlight=torch%20cat#torch.cat请添加图片描述
它的功能是将多个tensor类型矩阵的连接。它有两个参数,第一个是tensor元组或者tensor列表;第二个是dim,如果tensor是二维的,dim=0指在行上连接,dim=1指在列上连接。但是注意这里在行上连接,是扩展行进行连接,在列上连接是扩展列连接
注意:torch.cat 进行连接的tensor的shape,除了需要连接的维度上的shape值可不同,必须拥有相同的shape,a是(2,3),b是(2,20)即torch.cat((a,b),-1)可以进行连接;torch.cat((a,b),0)不可以进行连接,因为3和20值不同

2、下面看一些例子:

    a=torch.randn(2,3)
    print(a)
    b=torch.cat((a,a,a),1)
    print(b)
    
'''
#输出结果,dim=1,可见扩展列了
tensor([[-0.1121, -0.2641,  0.4476],
        [-1.2637,  1.0789,  1.0342]])
tensor([[-0.1121, -0.2641,  0.4476, -0.1121, -0.2641,  0.4476, -0.1121, -0.2641,
          0.4476],
        [-1.2637,  1.0789,  1.0342, -1.2637,  1.0789,  1.0342, -1.2637,  1.0789,
          1.0342]])
'''
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐