logits是啥、傅里叶变换

发布于:2025-05-22 ⋅ 阅读:(23) ⋅ 点赞:(0)

什么是logtis? 


在深度学习的上下文中,logits 就是一个向量,下一步通常被投给 softmax/sigmoid 的向量。。

softmax的输出是分类任务的概率,其输入是logits层。 logits层通常产生-infinity到+ infinity的值,而softmax层将其转换为0到1的值。

举个例子:

如果一个二分类模型的输出层是一个单节点,没有激活函数,输出值是2.5。则2.5就是这个样本被预测为正类的logits。
如果一个多分类模型的输出层是5个节点,没有激活函数,输出值是[1.2, 3.1, -0.5, 4.8, 2.3]。则这个向量表示该样本属于5个类别的logits。
logits可以取任意实数值,正值表示趋向于分类到该类,负值表示趋向于不分类到该类。与之对应,probability经过softmax/sigmoid归一化到0-1之间,表示属于该类的概率。

logtis在交叉熵损失中的示例:
logtis可以看作神经网络输出的未经过归一化(softmax/sigmoid)的概率,所以将其结果用于分类任务计算loss时,如求cross_entropy的loss函数会设置from_logits参数。

因此,当from_logtis=False(默认情况)时,可以理解为输入的y_pre不是来自logtis,那么它就是已经经过softmax或者sigmoid归一化后的结果了;当from_logtis=True时,则可以理解为输入的y_pre是logits,那么需要对它进行softmax或者sigmoid处理再计算cross entropy。

如下面的两种方式计算得到的loss都是一样的:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
from_logits = True  # 标记logtis or sigmoid
y_pre = torch.FloatTensor([[5, 5], [2, 8]])  # logtis值:神经网络的输出,未概率归一化
y_true = torch.FloatTensor([[1, 0], [0, 1]])  # 真实的二分类标签
 
if from_logits == True:
    # BCEWithLogitsLoss可以直接对logtis值进行二分类的损失计算
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(y_pre, y_true)
else:
    # BCELoss需要对logtis值进行概率归一化,然后再进行二分类的损失计算
    criterion = nn.BCELoss()
    y_pre = F.sigmoid(y_pre)
    loss = criterion(y_pre, y_true)
print(loss)

傅里叶变换
连续与离散
一般人们口中所说的傅里叶变换都是指连续傅里叶变换,针对的是连续时域信号。维基百科上是这么描述连续信号的:

连续信号或称连续时间信号是指定义在实数域的信号,自变量(一般是时间)的取值连续。若信号的幅值和自变量均连续,则称为模拟信号。根据实数的性质,时间参数的连续性意味着信号的值在时间的任意点均有定义。

简单来说,对于一个sin函数的连续信号,其波形长这样:

对于计算设备的信号处理,因为采样设备的采样率是有限的。因此得到的采样信号都是离散的,所以就有了针对离散信号的离散傅里叶变换。维基百科是这么描述离散信号:

离散信号是在连续信号上采样得到的信号。与连续信号的自变量是连续的不同,离散信号是一个串行,即其自变量是“离散”的。这个串行的每一个值都可以被看作是连续信号的一个采样。由于离散信号只是采样的串行,并不能从中获得采样率,因此采样率必须另外存储。以时间为自变量的离散信号为离散时间信号。
离散信号并不等同于数字信号。数字信号不仅是离散的,而且是经过量化的。即,不仅其自变量是离散的,其值也是离散的。因此离散信号的精度可以是无限的,而数字信号的精度是有限的。而有着无限精度,亦即在值上连续的离散信号又叫抽样信号。所以离散信号包括了数字信号和抽样信号。

信号的筛选

简单概括公式所表达的实际意义后,我们再来看为什么要这么做。废话少说直接按照傅里叶变换公式所描述的做法,将信号与各已频率信号乘法后积分,写一段简单的matlab程序。最终结果如下图:

从图上可以很清楚直观的看出,当未知信号与基信号频率相同时,他们乘积的积分达到了最大。而这个积分最终的图像是不是就很像信号经过傅里叶变换后的频域图像呢!利用了乘法负负得正的特性,当基信号与未知信号频率完全相同的时候,其时域上乘积全为正数。因此,此时的积分达到最大值。通过此类操作即可得到未知信号的频率。

说人话,就是用各种频率的基信号与未知信号做对比,看看哪个频率的基信号与未知信号最像!!!,而这个对比的方法,用数学来描述就是先相乘后积分,这样描述的话,是不是听起来就没那么复杂了。

初始相位的问题

上文所讨论的对未知信号频率的筛选,全部都是建立在未知信号相位为0的情况下。但是实际使用过程中,信号的相位与频率都是未知的,若只使用上文(4)所描述的方式肯定是不行的,如下图:

可以看到若未知信号的初始相位为90°,若还使用公式4,得出的结果与实际就有误了。

此时再回头看看傅里叶变换的公式,不难发现,傅里叶变换使用的是复数信号。e − j ω t = c o s ( ω t ) − j s i n ( ω t ) e^{-j\omega t} = cos(\omega t)-jsin(\omega t)e 
−jωt
 =cos(ωt)−jsin(ωt)。也就是说傅里叶变换对未知信号使用复数信号相乘后求积分。而不是只对单个正弦信号做比较。

不难看出,傅里叶变换的结果是复数信号,其结果实部的值就是表示未知信号与cos信号的相关度,虚部值表示未知信号与sin信号的相关度。由此可以看出,傅里叶变换的结果还包含了未知信号的相位信息。

为了分析其频率信息,我们可以对结果做取模开方处理

快速傅里叶变换(FFT)

快速傅里叶算法本质上是对离散傅里叶变化的改进,本质上是利用矩阵乘法时不同基信号在奇数或偶数点的与被采样信号的乘积相同来减少计算,严格的数学推断可以参考这篇快速理解FFT算法(完整无废话)

MATLAB 仿真代码

% 设置时间范围
t = 0:0.01:2*pi;
freq1 = 5;
freq2 = 0;

% 创建GIF文件
filename = 'sin_wave_animation.gif';
fps = 120;

% 初始化乘积结果的和
sum_of_product = zeros(1, length(t));
sum_of_product_sin = zeros(1, length(t));
sum_of_product_cos = zeros(1, length(t));

for i = 1:length(t)
    % 生成两个不同频率的正弦波
    y1 = cos(freq1 * t);
    y2 = sin((freq2+i/10) * t);
    y3 = cos((freq2+i/10) * t);
    % 计算两个正弦波的乘积

    y_product_sin = y1 .* y2;
    y_product_cos = y1 .* y3;
    % 计算乘积结果的和
    sum_of_product_sin(i+1) = sum(y_product_sin);
    sum_of_product_cos(i+1) = sum(y_product_cos);
    
    sum_of_product = sqrt(power(sum_of_product_sin, 2) + power(sum_of_product_cos, 2));

    arr = 0:0.1:10;
    subplot(3, 1, 1);
    plot(arr(1:i), sum_of_product_sin(1:i), 'g'); % 只绘制当前时间点之前的数据
    xlabel('x');
    ylabel('y');
    xlim([0, 10]);
    ylim([-500, 500]);
    title('Re');

    arr = 0:0.1:10;
    subplot(3, 1, 2);
    plot(arr(1:i), sum_of_product_cos(1:i), 'g'); % 只绘制当前时间点之前的数据
    xlabel('x');
    ylabel('y');
    xlim([0, 10]);
    ylim([-500, 500]);
    title('Im');
    
    arr = 0:0.1:10;
    subplot(3, 1, 3);
    plot(arr(1:i), sum_of_product(1:i), 'g'); % 只绘制当前时间点之前的数据
    xlabel('x');
    ylabel('y');
    xlim([0, 10]);
    ylim([-500, 500]);
    title('sqrt(power(Re, 2) + power(Im, 2));');

    % 保存当前图像为GIF
    frame = getframe(gcf);
    im = frame2im(frame);
    [imind, cm] = rgb2ind(im, 256);
    if i == 1
        imwrite(imind, cm, filename, 'gif', 'Loopcount', inf, 'DelayTime', 1/fps);
    else
        imwrite(imind, cm, filename, 'gif', 'WriteMode', 'append', 'DelayTime', 1/fps);
    end
    
    if i == 100
        break;
    end

end


网站公告

今日签到

点亮在社区的每一天
去签到