torch.argmax方法详解

torch.argmax(x, dim),其中x为张量,dim控制比较的维度,返回最大值的索引。
1.当dim=0时

import torch
x = torch.rand(2, 3,2)
print(x)
torch.argmax(x, dim=0)

当dim=0时,表示后两个维度进行比较,得到结果如下图:
在这里插入图片描述
比较过程为:输出结果的张量y的大小为去掉需比较维度dim后的大小,即3x2。然后依次确定这6个值,首先,对x[:,0,0]中的值进行比较,

在这里插入图片描述

取较大值的索引值输出结果的值,0.6718>0.6402,即y[0,0]=1;接着,对x[:,0,1]进行比较,
在这里插入图片描述
取较大值的索引值输出结果的值,即y[0,1]=0;以此类推,直到将所有的比较完成。

当dim为1,或2的情况也类似,结果如下:
2.当dim=1
在这里插入图片描述
3.当dim=2
在这里插入图片描述

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐