torch.argmax方法详解
torch.argmax方法详解
·
torch.argmax方法详解
torch.argmax(x, dim),其中x为张量,dim控制比较的维度,返回最大值的索引。
1.当dim=0时
import torch
x = torch.rand(2, 3,2)
print(x)
torch.argmax(x, dim=0)
当dim=0时,表示后两个维度进行比较,得到结果如下图:
比较过程为:输出结果的张量y的大小为去掉需比较维度dim后的大小,即3x2。然后依次确定这6个值,首先,对x[:,0,0]中的值进行比较,

取较大值的索引值输出结果的值,0.6718>0.6402,即y[0,0]=1;接着,对x[:,0,1]进行比较,
取较大值的索引值输出结果的值,即y[0,1]=0;以此类推,直到将所有的比较完成。
当dim为1,或2的情况也类似,结果如下:
2.当dim=1
3.当dim=2
更多推荐



所有评论(0)