torch.topk() 函数详解
作用: 返回 列表中最大的n个值例子1:m=torch.arange(0,10)print(m.topk(3))torch.return_types.topk(values=tensor([9, 8, 7]),indices=tensor([9, 8, 7]))例子2:pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053]
作用: 返回 列表中最大的n个值
例子1:m=torch.arange(0,10)
print(m.topk(3))
torch.return_types.topk(
values=tensor([9, 8, 7]),
indices=tensor([9, 8, 7]))
例子2:pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
values, indices = pred.topk(4, dim=0, largest=True, sorted=True)
print(values)
print(indices)
tensor([[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[ 0.2021, 0.3041, 1.2590, 0.3849, 1.7255],
[-0.4451, 0.1673, 0.1383, -1.0145, 0.4053],
[-0.5816, -0.3873, -1.0215, -2.0757, -1.6311]])
tensor([[1, 1, 1, 1, 1],
[3, 3, 2, 3, 2],
[2, 2, 3, 0, 0],
[0, 0, 0, 2, 3]])
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
input -> 输入tensor
k -> 前k个
dim -> 默认为输入tensor的最后一个维度
sorted -> 是否排序
largest -> False表示返回第k个最小值
更多推荐
所有评论(0)