torch.chunk:用来将tensor分成很多个块,简而言之我理解的就是切分吧,可以在不同维度上切分。
torch.chunk(tensor,chunk数,维度)
代码示例:

import torch
a=torch.tensor([[[1,2],[3,4]],
               [[5,6],[7,8]]])
b=torch.chunk(a,2,1)
print(a)
print(b)

输出:

tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])
(tensor([[[1, 2]],
        [[5, 6]]]), 
 tensor([[[3, 4]],
        [[7, 8]]]))
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐