torch.split,用来划分tensor,可以从数量上划分,还有维度上划分。
torch.split(tensor,split_szie,dim)split_size有整数,也有列表,dim默认为0,自己也可以修改。
代码示例:

import torch
a=torch.tensor([[[1,2,3],[4,5,6]],
                [[7,8,9],[10,11,12]]])
print("a的shape:",a.shape)
#在第0维上进行split
b=torch.split(a,1)
print("b:",b)
#在第1维上进行split
c=torch.split(a,[1,1],1)
print("c:",c)

输出:

a的shape: torch.Size([2, 2, 3])
b: (tensor([[[1, 2, 3],
         [4, 5, 6]]]),
          tensor([[[ 7,  8,  9],
         [10, 11, 12]]]))
c: (tensor([[[1, 2, 3]],

        [[7, 8, 9]]]), 
        tensor([[[ 4,  5,  6]],

        [[10, 11, 12]]]))
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐