torch.unbind()
作用:对某一个维度进行长度为1的切片,并将所有切片结果返回。举个例子就知道了:import torchx = torch.tensor([[1, 2],[ 3,4]])torch.unbind(x,0)#第0个维度上进行长度为1的切片。结果:(tensor([1, 2]), tensor([3, 4]))应用减少代码量:m,n=torch.unbind(x,0)print(m)print(n)m,
·
作用:对某一个维度进行长度为1的切片,并将所有切片结果返回。
举个例子就知道了:
import torch
x = torch.tensor([[1, 2],[ 3,4]])
torch.unbind(x,0)#第0个维度上进行长度为1的切片。
结果:
(tensor([1, 2]), tensor([3, 4]))
应用减少代码量:
m,n=torch.unbind(x,0)
print(m)
print(n)
m,n=x[0],x[1]#虽然这个也可以,但是如果x的第一个维度很大,这个就很繁琐。
print(m)
print(n)
结果:
tensor([1, 2])
tensor([3, 4])
tensor([1, 2])
tensor([3, 4])
补充另外一种torch.unbind()的等价形式:
x.unbind(0)
结果:
(tensor([1, 2]), tensor([3, 4]))
更多推荐


所有评论(0)