Cross Entropy Loss详解
pytorch中的CrossEntropyLoss
究竟干了什么?
1 logits
一般多分类模型的原始输出称为logits
, 可将其理解为没有标准化的概率.
logits
不必在$0\sim 1$之间, 也不必满足每个样本各种情况概率相加等于1.
2 softmax
softmax
可以将logits标准化, 设其中一个样本的logit为:
$$
\begin{equation}
logit = [a_1, a_2, a_3, \cdots, a_c]
\end{equation}
$$
若将logits解释成没有归一化的对数概率: $$\begin{equation} a_i = \ln{b_i}, i=1\sim c \end{equation}$$ 那么对其做指数操作并归一化,就可以得到标准化的概率. 因此softmax的公式为: $$ \begin{equation} q_i = \frac{e^{a_i}}{\sum_{j=1}^{c}e^{a_j}} \end{equation} $$
3 交叉熵损失 Cross Entropy Loss
在有监督多分类问题中, one-hot
编码相当于样本的真实概率分布p
, 而模型预测的logit经过softmax得到了预测的概率分布q
, 因此可以方便地采用交叉熵作为损失函数:
$$ \begin{equation}
Loss = \sum_j^c p_j \ln \frac{p_j}{q_j}
\end{equation} $$
而one-hot
编码中只有一项为1, 不妨设$p_k=1$, 则损失函数变为:
$$ \begin{equation}
Loss = -\ln q_k
\end{equation} $$
4 反向传播
在梯度的反向传播中需要求loss对模型输出的导数, 也即: $\frac{\partial L}{\partial a_i}, i=1\sim c$.
5 torch.nn.CrossEntropyLoss 究竟干了什么
CrossEntropyLoss
输入为 logits, 输出为交叉熵损失。
也就是包含 softmax 和 交叉熵$-ln()$ 两个步骤: $$ \begin{equation} loss(x,class)=−\ln\left(\frac{e^{x[class]}}{\sum_j e^{x[j]}} \right)=-x[class]+\ln \left( \sum_j e^{x[j]} \right) \end{equation} $$
把两步操作写成一个op, 目的是为了让反向传播梯度更加稳定, 从式(6)中可以看出, softmax的梯度是$p(1-p)$, 而交叉熵的梯度是$-1/p$, 后者在$p\to 0$的时候并不稳定, 而将两者相乘可以抵消掉分母的$p$.
还有另一种计算交叉熵损失的步骤使用 LogSoftmax
+ NLLLoss
能够到达跟CrossEntropyLoss
一样的效果.
其中 LogSoftmax()(logits)
等效于 log(Softmax()(logits))
, 即通过logits先计算标准化概率, 再取对数概率$y$:
$$
y_i = \ln \left( \frac{e^{x[i]}}{\sum_j e^{x[j]}} \right)
$$
NLLLoss
的作用是根据标签选择对应的对数概率, 并添加负号:
$$
loss(y, class) = -y[class]
$$
两种方法殊途同归, 虽然使用的是不同函数, 但本质是同一个公式, 并且LogSoftmax
也将两个操作写成了一个op.
参考代码如下:
|
|