Pytorch的torch.nn.functional.cross_entropy的ignore_index细解

发布于:2024-05-14 ⋅ 阅读:(101) ⋅ 点赞:(0)

作用
ignore_index用于忽略ground-truth中某些不需要参与计算的类。假设有两类{0:背景,1:前景},若想在计算交叉熵时忽略背景(0)类,则可令ignore_index=0(同理忽略前景计算可设ignore_index=1)。

代码示例

import torch
import torch.nn.functional as F
pred = torch.Tensor(
    [
        [0.9, 0.1],
        [0.8, 0.2],
        [0.7, 0.3]
    ]
)  # shape=(N,C)=(3,2),N为样本数,C为类数
label = torch.LongTensor([1, 0, 1])  # shape=(N)=(3),3个样本的label分别为1,0,1
out = F.cross_entropy(pred, label, ignore_index=0)  # 忽略0类
print(out)


输出

tensor(1.0421)


验证
pytorch的CrossEntropy使用公式:

计算:
loss=1/2×{[−0.1+ln(e ^{0.9}+e ^{0.1} )]+[−0.3+ln(e ^{0.7}+e ^{0.3})]}= 1/2×(1.1711+0.9130)=1.0421 ​

ignore_index表示计算交叉熵时,自动忽略的标签值,example:

import torch
import torch.nn.functional as F
pred = []
pred.append([0.9, 0.1])
pred.append([0.8, 0.2])
pred = torch.Tensor(pred).view(-1,  2)

label = torch.LongTensor([[1], [-1]])  # 这里输出类别为0或1,-1表示不参与计算loss。且计算平均loss的时候,reduction只计算实际参与计算的个数,这里相当于batchsize=2,但其中第index=1行为-1不参与计算loss。

# out = F.cross_entropy(pred.view(-1, 2), label.view(-1, )) 
out = F.cross_entropy(pred.view(-1, 2), label.view(-1, ), ignore_index=-1) 
print(out)

输出结果:

tensor(1.1711)

再比如:

例如我的pred是(b,2,w,h),而label索引是(b,1,w,h)的矩阵,其中只有0,1值,0值代表从pred的第0个通道选择像素值,1值代表从pred的第1个通道选择像素值。

而此时我发现因为程序的错误,label矩阵中混入了一些-1值,这样正常的话是会报错的,因为pred矩阵没有-1通道。此时最简单的一个方法就是

loss = nn.CrossEntropyLoss(ignore_index=-1) 

上述操作就是相当于忽略-1标签值为-1的位置的对应像素值就不参与计算梯度了

torch.nn.CrossEntropyLoss 同理。


网站公告

今日签到

点亮在社区的每一天
去签到