torch.Tensor()是一个类,默认是torch.FloatTensor()的简称,创建的为float32位的数据类型;
torch.tensor()是一个函数,是对张量数据的拷贝,根据传入data的类型来创建Tensor;

a=torch.Tensor([1,2])
type(a)#<class 'torch.Tensor'>
a.type()#'torch.FloatTensor'

a = torch.tensor([1, 2])
type(a)#<class 'torch.Tensor'>
a.type()#'torch.LongTensor'

a=torch.Tensor([1,2])
a.dtype#torch.float32

a = torch.tensor([1, 2])
a.dtype#torch.int64
a = torch.tensor([1., 2.])
a.dtype#torch.float32
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐