torch.chunk:用来将tensor分成很多个块,简而言之我理解的就是切分吧,可以在不同维度上切分。
torch.chunk(tensor,chunk数,维度)
代码示例:

import torch
a=torch.tensor([[[1,2],[3,4]],
               [[5,6],[7,8]]])
b=torch.chunk(a,2,1)
print(a)
print(b)

输出:

tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])
(tensor([[[1, 2]],
        [[5, 6]]]), 
 tensor([[[3, 4]],
        [[7, 8]]]))
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐