作用:对某一个维度进行长度为1的切片,并将所有切片结果返回。

举个例子就知道了:

import torch
x = torch.tensor([[1, 2],[ 3,4]])
torch.unbind(x,0)#第0个维度上进行长度为1的切片。

结果:

(tensor([1, 2]), tensor([3, 4]))

应用减少代码量:

m,n=torch.unbind(x,0)
print(m)
print(n)
m,n=x[0],x[1]#虽然这个也可以,但是如果x的第一个维度很大,这个就很繁琐。
print(m)
print(n)

结果:

tensor([1, 2])
tensor([3, 4])
tensor([1, 2])
tensor([3, 4])

补充另外一种torch.unbind()的等价形式:

x.unbind(0)

结果:

(tensor([1, 2]), tensor([3, 4]))
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐