4.8 常见的损失函数

发布于:2025-07-28 ⋅ 阅读:(16) ⋅ 点赞:(0)

损失函数(Loss Function)是机器学习和深度学习的核心组件,用于量化模型预测值与真实值之间的差异,指导模型参数的优化方向。以下是损失函数的系统解析:


一、损失函数的作用

  1. 评估模型性能:数值化衡量预测结果的准确性。
  2. 指导参数优化:通过梯度下降等算法,调整模型参数以最小化损失值。
  3. 任务适配性:不同任务(如分类、回归)需选择不同的损失函数。

二、常见损失函数分类

1. 回归任务损失函数
名称 公式 特点
均方误差(MSE) L=1N∑i=1N(yi−y^i)2\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2L=N1i=1N(yiy^i)2 对异常值敏感,梯度随误差增大线性增长。
平均绝对误差(MAE) L=1N∑i=1N∣yi−y^i∣\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} | y_{i}-\hat{y}_i|L=N1i=1Nyiy^i
Huber损失 L={12(y−y^)2   if ∣y−y^∣≤δδ∣y−y^∣−δ22   otherwise\mathcal{L} = \begin{cases} \frac{1}{2}(y-\hat{y})^2 \ \ \ \text{if } | y-\hat{y} | \leq \delta \\ \delta|y-\hat{y}|-\frac{\delta^{2}}{2}\ \ \ \text{otherwise} \end{cases}L={21(yy^)2   if yy^δδyy^2δ2   otherwise
2. 分类任务损失函数
名称 公式 特点
交叉熵损失(Cross-Entropy) L=−1N∑i=1N∑k=1Kyi,klog⁡(y^i,k)\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{K} y_{i,k} \log(\hat{y}_{i,k})L=N1i=1Nk=1Kyi,klog(y^i,k) 衡量概率分布差异,适用于多分类,梯度更新高效。
二分类交叉熵 L=−1N∑i=1N[yilog⁡(y^i)+(1−yi)log⁡(1−y^i)]\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]L=N1i=1N[yilog(y^i)+(1yi)log(1y^i)] 二分类特例,输出需经过Sigmoid激活。
合页损失(Hinge Loss) L=1N∑i=1Nmax⁡(0,1−yiy^i)\mathcal{L} = \frac{1}{N} \sum_{i=1}^N \max(0, 1 - y_i \hat{y}_i)L=N1i=1Nmax(0,1yiy^i) 用于支持向量机(SVM),鼓励分类边界最大化。
3. 其他任务损失函数
名称 公式 应用场景
KL散度(Kullback-Leibler Divergence) L=∑p(x)log⁡p(x)q(x)\mathcal{L} = \sum p(x) \log \frac{p(x)}{q(x)}L=p(x)logq(x)p(x) 衡量两个概率分布的差异,如生成对抗网络(GAN)。
Triplet Loss L=max⁡(d(a,p)−d(a,n)+margin,0)\mathcal{L} = \max(d(a, p) - d(a, n) + \text{margin}, 0)L=max(d(a,p)d(a,n)+margin,0) 用于度量学习(如人脸识别),优化样本间的距离关系。
Focal Loss L=−αt(1−pt)γlog⁡(pt)\mathcal{L} = -\alpha_t (1 - p_t)^\gamma \log(p_t)L=αt(1pt)γlog(pt) 解决类别不平衡问题,调整难易样本的权重(如目标检测)。

三、损失函数的数学意义

1. 交叉熵与MSE的对比
  • 交叉熵
    • 基于概率分布差异(信息论),梯度更新速度与误差成正比,适合分类任务。
    • 梯度公式:∂L∂z=y^−y\frac{\partial \mathcal{L}}{\partial z} = \hat{y} - yzL=y^yzzz为未激活的logits),更新效率高。
  • MSE
    • 梯度公式:∂L∂z=(y^−y)⋅σ′(z)\frac{\partial \mathcal{L}}{\partial z} = (\hat{y} - y) \cdot \sigma'(z)zL=(y^y)σ(z)σ\sigmaσ 为激活函数)。
    • 当使用Sigmoid激活时,梯度易饱和(σ′(z)\sigma'(z)σ(z)接近零),导致训练停滞。
2. 损失函数与模型输出
  • 分类任务:输出需经过概率化处理(如Softmax、Sigmoid),确保损失函数有意义。
  • 回归任务:输出直接为连续值,无需额外激活。

四、如何选择损失函数?

任务类型 推荐损失函数 注意事项
二分类 二分类交叉熵 输出层使用Sigmoid激活。
多分类 交叉熵损失 输出层使用Softmax激活。
回归 MSE(数据无异常) / MAE或Huber(有异常) MSE对异常值敏感,MAE梯度更新稳定但收敛慢。
类别不平衡 Focal Loss 调整参数 γ\gammaγ 控制难易样本权重。
生成模型 KL散度、Wasserstein距离 GAN中常用Wasserstein距离提升训练稳定性。

五、代码示例

1. PyTorch实现常见损失函数
import torch
import torch.nn as nn

# 均方误差
mse_loss = nn.MSELoss()
output = model(x)
loss = mse_loss(output, y_true)

# 交叉熵损失(自动包含Softmax)
ce_loss = nn.CrossEntropyLoss()
output = model(x)  # 输出未归一化的logits
loss = ce_loss(output, y_true)

# 二分类交叉熵(输出需经过Sigmoid)
bce_loss = nn.BCELoss()
output = torch.sigmoid(model(x))
loss = bce_loss(output, y_true.float())

# Focal Loss(需自定义)
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = nn.BCEWithLogitsLoss()(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss
        return F_loss
2. TensorFlow/Keras实现
import tensorflow as tf

# 均方误差
model.compile(optimizer='adam', loss='mse')

# 交叉熵损失(输出为概率)
model.compile(optimizer='adam', loss='categorical_crossentropy')

# 自定义Huber损失
def huber_loss(y_true, y_pred, delta=1.0):
    error = y_true - y_pred
    condition = tf.abs(error) < delta
    squared_loss = 0.5 * tf.square(error)
    linear_loss = delta * (tf.abs(error) - 0.5 * delta)
    return tf.reduce_mean(tf.where(condition, squared_loss, linear_loss))

model.compile(optimizer='adam', loss=huber_loss)

六、注意事项

  1. 输出层与损失函数匹配
    • 交叉熵损失要求输出为概率(需Softmax/Sigmoid),MSE要求输出为实数。
  2. 数值稳定性
    • 避免计算 log⁡(0)\log(0)log(0),可在概率值中加小常数(如 log⁡(y^+ϵ)\log(\hat{y} + \epsilon)log(y^+ϵ))。
  3. 正则化与损失函数
    • 正则化项(如L2)通常作为损失函数的一部分,共同参与梯度计算。

七、总结

损失函数是模型训练的导航仪,其选择直接影响模型性能。理解不同损失函数的数学特性、适用场景及实现细节,是构建高效模型的关键。实际应用中,需结合任务需求、数据分布和模型结构综合权衡。


网站公告

今日签到

点亮在社区的每一天
去签到