在pytorch训练模型的时候,出现:
在这里插入图片描述
代码:

            pred = model(X)
            print(pred.argmax(1))

打印pred为:

tensor([ 17.0364,  28.3838, -27.5744,   8.5920])

因为只有一维,所以需要改为:

print(pred.argmax(0))

这样就没有问题了

当我们一次使用多个输入数据时,可能tensor就是二维的,这个时候才可以用pred.argmax(1)得到最大值的索引。

tensor([[ 1.1916e-01, -1.7842e-01,  2.4500e-01, -1.1631e-01,  4.5129e-01,
         -2.1620e-01,  2.5249e-01, -3.0434e-01,  1.0978e-01,  4.3598e-02],
        [ 4.3276e-02, -1.3183e-02,  4.3428e-02, -4.3271e-03,  9.0060e-02,
         -2.3623e-02,  6.0345e-03, -2.6779e-02,  6.1037e-02,  2.1716e-02],
        [ 8.4877e-02,  7.0410e-03,  7.7200e-02,  1.4489e-02,  1.6732e-01,
         -6.9628e-02,  5.2289e-02, -1.0901e-01,  2.8091e-02, -5.3942e-03]])
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐