torch.split用法
torch.split,用来划分tensor,可以从数量上划分,还有维度上划分。torch.split(tensor,split_szie,dim),split_size有整数,也有列表,dim默认为0,自己也可以修改。代码示例:import torcha=torch.tensor([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])print("a的shape:",
·
torch.split,用来划分tensor,可以从数量上划分,还有维度上划分。
torch.split(tensor,split_szie,dim),split_size有整数,也有列表,dim默认为0,自己也可以修改。
代码示例:
import torch
a=torch.tensor([[[1,2,3],[4,5,6]],
[[7,8,9],[10,11,12]]])
print("a的shape:",a.shape)
#在第0维上进行split
b=torch.split(a,1)
print("b:",b)
#在第1维上进行split
c=torch.split(a,[1,1],1)
print("c:",c)
输出:
a的shape: torch.Size([2, 2, 3])
b: (tensor([[[1, 2, 3],
[4, 5, 6]]]),
tensor([[[ 7, 8, 9],
[10, 11, 12]]]))
c: (tensor([[[1, 2, 3]],
[[7, 8, 9]]]),
tensor([[[ 4, 5, 6]],
[[10, 11, 12]]]))
更多推荐
已为社区贡献11条内容
所有评论(0)