BCEWithLogitsLoss 函数介绍
概述
BCEWithLogitsLoss
是 PyTorch 中的一个损失函数,专门用于二分类和多标签分类任务。它将 Sigmoid
激活函数与二元交叉熵损失(BCELoss)结合在一起,设计上旨在提高数值稳定性。通过将这两个操作合并为一个层,BCEWithLogitsLoss
能够有效避免分开计算时可能出现的数值上溢和下溢问题。
数学定义
对于给定的输入 (x) 和目标 (y),BCEWithLogitsLoss
的损失可以通过以下公式定义:
l n = − w n [ y n ⋅ log σ ( x n ) + ( 1 − y n ) ⋅ log ( 1 − σ ( x n ) ) ] l_n = - w_n \left[ y_n \cdot \log \sigma(x_n) + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right] ln=−wn[yn⋅logσ(xn)+(1−yn)⋅log(1−σ(xn))]
其中:
- (N) 是批次大小。
- (x_n) 是模型的输出(logits)。
- (y_n) 是真实标签(0 或 1)。
- (w_n) 是样本的权重,用于加权损失。
归约方式
损失函数可以通过以下方式进行归约:
- 无归约(‘none’): 返回每个样本的损失。
- 均值(‘mean’): 默认选项,对损失进行平均。
- 求和(‘sum’): 对损失进行求和。
权重调整
在处理类别不平衡的情况下,BCEWithLogitsLoss
允许通过 pos_weight
参数为正样本设置权重。这可以帮助模型更好地关注那些相对较少的阳性样本。例如,如果某个类别的正样本数量远少于负样本,您可以为该类别设置更大的权重,以便模型在训练时更加重视这些少数类样本。
官方代码
class BCEWithLogitsLoss(_Loss):
"""
Shape:
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Target: :math:`(*)`, same shape as the input.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
shape as input.
Examples::
>>> loss = nn.BCEWithLogitsLoss()
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> output = loss(input, target)
>>> output.backward()
"""
def __init__(
self,
weight: Optional[Tensor] = None,
size_average=None,
reduce=None,
reduction: str = "mean",
pos_weight: Optional[Tensor] = None,
) -> None:
super().__init__(size_average, reduce, reduction)
self.register_buffer("weight", weight)
self.register_buffer("pos_weight", pos_weight)
self.weight: Optional[Tensor]
self.pos_weight: Optional[Tensor]
def forward(self, input: Tensor, target: Tensor) -> Tensor:
return F.binary_cross_entropy_with_logits(
input,
target,
self.weight,
pos_weight=self.pos_weight,
reduction=self.reduction,
)
使用示例
以下是如何在 PyTorch 中使用 BCEWithLogitsLoss
的示例代码:
import torch
import torch.nn as nn
# 创建目标张量和预测张量
target = torch.tensor([[1.0], [0.0], [1.0], [0.0]]) # 目标值
logits = torch.tensor([[0.8], [-1.2], [1.5], [-0.5]]) # 模型输出(logits)
# 初始化损失函数
loss_function = nn.BCEWithLogitsLoss()
# 计算损失
loss = loss_function(logits, target)
print('Loss:', loss.item())
适用场景
BCEWithLogitsLoss 函数适用于以下场景:
- 二分类问题: 适合需要对样本进行0/1标签分类的任务。
- 多标签分类问题: 每个样本可以同时属于多个类别时,利用该损失函数处理每个类别的标签。
- 类别不平衡: 通过使用 pos_weight 来加大少数类别样本的影响,使模型更加关注这些样本。
总结
BCEWithLogitsLoss 是一个强大的损失函数,为二分类和多标签分类任务提供了高效和稳定的解决方案。通过结合 Sigmoid 和交叉熵损失,它确保了数值计算的稳定性,同时能够处理类别不平衡问题。