【神经网络笔记】——多分类交叉熵损失函数公式及代码实现
背景mse均方误差、mae绝对值平均误差用于拟合回归,公式已经熟悉了,但交叉熵的每次都只是应用,没有了解公式,这对于自己写交叉熵损失函数以及分析损失函数不利。公式详解C是损失值;n是求平均用的,所以是样本数量,也就是batchsize;x是预测向量维度,因为需要在输出的特征向量维度上一个个计算并求和;y是onehot编码后的真实值 对应x维度上的标签,是1或0;a是onehot格式输出的预测标签,
背景
mse均方误差、mae绝对值平均误差用于拟合回归,公式已经熟悉了,但交叉熵的每次都只是应用,没有了解公式,这对于自己写交叉熵损失函数以及分析损失函数不利。
公式详解
C是损失值;
n是求平均用的,所以是样本数量,也就是batchsize;
x是预测向量维度,因为需要在输出的特征向量维度上一个个计算并求和;
y是onehot编码后的真实值 对应x维度上的标签,是1或0;
a是onehot格式输出的预测标签,是0~1的值,a经过了softmax激活,所以a的和值为1
对于某个维度
x
i
x_i
xi,
y
=
1
y=1
y=1时a越大越好,相反a越小越好,C值为两者和的负值,所以越好→C↓ 所以可以最优化C(神经网络需要最小化损失函数)
公式计算举例
C ( ( 0.8 , 0.1 , 0.1 ) , ( 1 , 0 , 0 ) ) = − 1 / 1 ∗ ( 1 l n ( 0.8 ) + 0 + 0 + 1 l n ( 0.9 ) + 0 + 1 l n ( 0.9 ) ) C((0.8,0.1,0.1),(1,0,0))=-1/1*(1ln(0.8)+0+0+1ln(0.9)+0+1ln(0.9)) C((0.8,0.1,0.1),(1,0,0))=−1/1∗(1ln(0.8)+0+0+1ln(0.9)+0+1ln(0.9))
公式编程实现
计算举例是为了理解计算过程,最终还是要落到编程实现上:
def cross_entropy(y_true,y_pred):
C=0
# one-hot encoding
for col in range(y_true.shape[-1]):
y_pred[col] = y_pred[col] if y_pred[col] < 1 else 0.99999
y_pred[col] = y_pred[col] if y_pred[col] > 0 else 0.00001
C+=y_true[col]*np.log(y_pred[col])+(1-y_true[col])*np.log(1-y_pred[col])
return -C
# 没有考虑样本个数 默认=1
num_classes = 3
label=1#设定是哪个类别 真实值
y_true = np.zeros((num_classes))
# y_pred = np.zeros((num_classes))
# preset
y_true[label]=1
y_pred = np.array([0.0,1.0,0.0])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
y_pred = np.array([0.1,0.8,0.1])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
y_pred = np.array([0.2,0.6,0.2])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
y_pred = np.array([0.3,0.4,0.3])
C = cross_entropy(y_true,y_pred)
print(y_true,y_pred,"loss:",C)
执行结果:
[0. 1. 0.] [1.0000e-05 9.9999e-01 1.0000e-05] loss: 3.0000150000863473e-05
[0. 1. 0.] [0.1 0.8 0.1] loss: 0.43386458262986227
[0. 1. 0.] [0.2 0.6 0.2] loss: 0.9571127263944101
[0. 1. 0.] [0.3 0.4 0.3] loss: 1.62964061975162
Process finished with exit code 0
结论
- 分类任务神经网络的输出层往往是经过了softmax激活,所以最后一层输出的预测向量各个维度的值均为0~1范围内的数。
- python的引用类型,如array,在函数传参后是传的引用,所以函数内的修改会影响到实际的值,通过控制台输出的第一条信息即可知。
- 计算过程可能出现溢出NaN的报错,所以需要进行近似处理。
更多推荐
所有评论(0)