多标签分类与binary_cross_entropy_with_logits
1. binary_cross_entropy_with_logits可用于多标签分类torch.nn.functional.binary_cross_entropy_with_logits等价于torch.nn.BCEWithLogitsLosstorch.nn.BCELoss+torch.nn.Sigmoid 等价于torch.nn.BCEWithLogitsLoss在pytorch中torc
1. binary_cross_entropy_with_logits可用于多标签分类
torch.nn.functional.binary_cross_entropy_with_logits等价于torch. nn.BCEWithLogitsLoss
torch.nn.BCELoss+torch.nn.Sigmoid 等价于 torch. nn.BCEWithLogitsLoss
在pytorch中torch.nn.functional.binary_cross_entropy_with_logits和tensorflow中tf.nn.sigmoid_cross_entropy_with_logits,都是二值交叉熵,二者等价。
接受任意形状的输入,target要求与输入形状一致。注意:target的值必须在[0,N-1]之间,其中N为类别数,否则会出现莫名其妙的错误,比如loss为负数。
二值交叉熵的Loss如下:
其中 可以解释为:预测这个样本为第i个类别的损失
解释为类别的权重,重视某个类别,则加大该类别权重。
from torch import nn
from torch.autograd import Variable
bce_criterion = nn.BCEWithLogitsLoss(weight = None, reduce = False)
y = Variable(torch.tensor([[1,0,0],[0,1,0],[0,0,1],[1,1,0],[0,1,0]],dtype=torch.float64))
logits = Variable(torch.tensor([[12,3,2],[3,10,1],[1,2,5],[4,6.5,1.2],[3,6,1]],dtype=torch.float64))
bce_criterion(logits, y)
binary_cross_entropy_with_logits中的target(标签)的one_hot编码中每一维可以出现多个1,而softmax_cross_entropy_with_logits 中的target的one_hot编码中每一维只能出现一个1
2. softmax_cross_entropy_with_logits
binary_cross_entropy_with_logits是二分类的交叉熵,实际是多分类softmax_cross_entropy的一种特殊情况
from torch import nn
from torch.autograd import Variable
bce_criterion = nn.BCEWithLogitsLoss(weight = None, reduce = False)
y = Variable(torch.tensor([[1,0,0],[0,1,0],[0,0,1],[0,1,0],[0,1,0]],dtype=torch.float64))
logits = Variable(torch.tensor([[12,3,2],[3,10,1],[1,2,5],[4,6.5,1.2],[3,6,1]],dtype=torch.float64))
bce_criterion(logits, y)
target中one_hot编码后每一行只能出现一个1
准确率评价参考:(69条消息) 多标签分类中的损失函数与评估指标_小Aer的博客-CSDN博客_多标签分类损失函数
参考:binary_cross_entropy_with_logits-API文档-PaddlePaddle深度学习平台
更多推荐
所有评论(0)