我自己的原文哦~ https://blog.51cto.com/whaosoft/12750420
#傅里叶特征 (Fourier Feature)与核回归
位置编码背后的理论解释
本文探讨了位置编码背后的理论基础,特别是傅里叶特征(Fourier Feature)与核回归(Kernel Regression)的联系,并解释了如何通过这些理论来增强神经网络对高频信息的学习能力。
最近我在看位置编码最新技术时,看到了一个叫做 "NTK-aware" 的词。我想:「"NTK"是什么?Next ToKen (下一个词元)吗?为什么要用这么时髦的缩写?」看着看着,我才发现不对劲。原来,NTK 是神经网络理论里的一个概念,它从 kernel regression 的角度解释了神经网络的学习方法。基于 NTK 理论,有人解释了位置编码的理论原理并将其归纳为一种特殊的 Fourier Feature (傅里叶特征)。这么多专有名词一下就把我绕晕了,我花了几天才把它们之间的关系搞懂。
在这篇文章里,我主要基于论文_Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains_ (后文简称为「傅里叶特征论文」),介绍傅里叶特征这一概念。为了讲清这些理论的发展脉络,我会稍微讲一下 NTK 等理论概念。介绍完傅里叶特征后,我还会讲解它在其他方法中的应用。希望读完本文后,读者能够以这篇论文为基点,建立一个有关位置编码原理的知识网络,以从更深的层次来思考新的科研方向。
用 MLP 表示连续数据
我们先从一个具体的任务入手,直观体会傅里叶特征能够做些什么事。
我们知道,神经网络,哪怕是最简单的多层感知机(MLP),都有着很强的泛化能力:训练完毕后,对于训练集里完全没见过的输入,网络也能给出很正确的输出。特别地,如果新输入恰好和训练集的某个输入很近,那么它的输出也会和对应的训练集输出很近;随着新输出与训练集输入的距离不断增加,新输出也会逐渐变得不同。这反映了神经网络的连续性:如果输入的变化是连续的,那么输出的变化也是连续的。
基于神经网络的这一特性,有人想到:我们能不能用神经网络来表示连续数据呢?比如我想表达一张处处连续的图像,于是我令神经网络的输入是(x, y) 表示的二维坐标,输出是 RGB 颜色。之后,我在单张图像上过拟合这个 MLP。这样,学会表示这张图像后,哪怕输入坐标是分数而不是整数,神经网络也能给出一个颜色输出。
这种连续数据有什么好处呢?我们知道,计算机都是以离散的形式来存储数据的。比如,我们会把图像拆成一个个像素,每个像素存在一块内存里。对于图像这种二维数据,计算机的存储空间还勉强够用。而如果想用密集的离散数据表达更复杂的数据,比如 3D 物体,计算机的容量就捉襟见肘了。但如果用一个 MLP 来表达 3D 物体的话,我们只需要存储 MLP 的参数,就能获取 3D 物体在任何位置的信息了。
这就是经典工作神经辐射场 (Neural Radiance Field, NeRF) 的设计初衷。NeRF 用一个 MLP 拟合 3D 物体的属性,其输入输出如下图所示。我们可以用 MLP 学习每个 3D 坐标的每个 2D 视角处的属性(这篇文章用的属性是颜色和密度)。根据这些信息,利用某些渲染算法,我们就能重建完整的 3D 物体。
上述过程看起来好像很简单直接。但在 NeRF 中,有一个重要的实现细节:必须给输入加上位置编码,MLP 才能很好地过拟合连续数据。这是为什么呢?让我们先用实验复现一下这个现象。
MLP 拟合连续图像实验
为了快速复现和位置编码相关的问题,我们简单地用一个 MLP 来表示图像:MLP 的输入是 2D 坐标,输出是此处的三通道 RGB 颜色。我为这篇博文创建一个 GitHub 文件夹 https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/FourierFeature ,该实验的 Notebook 代码在文件夹的image_mlp.ipynb
中,欢迎大家 clone 项目并动手尝试。
一开始,我们先导入库并可视化要拟合的图片。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from einops import rearrange
def viz_image(pt_img: torch.Tensor):
pil_img = to_pil_image(pt_img)
display(pil_img)
input_image = read_image('misuzu.png', ImageReadMode.RGB)
input_image = input_image.to(torch.float32) / 255
input_image = input_image.unsqueeze(0)
input_image = F.interpolate(input_image, (256, 256), mode='bilinear')
viz_image(input_image[0])
我们再定义一个 MLP 类。稍后我们会并行地传入二维坐标。具体来说,我们会将输入定义为一个[1, 2, H, W]
形状的数据,其中通道数 2 表示(i, j)
格式的坐标。由于输入是以图像的形式并行输入的,我们可以用 的 2D 卷积来表示二维数据上的并行 MLP。所以在下面这个 MLP 里,我们只用到 卷积、激活函数、归一化三种层。按照傅里叶特征论文的官方示例,网络最后要用一个 Sigmoid 激活函数调整输出的范围。
class MLP(nn.Module):
def __init__(self, in_c, out_c=3, hiden_states=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Conv2d(in_c, hiden_states, 1), nn.ReLU(), nn.BatchNorm2d(hiden_states),
nn.Conv2d(hiden_states, hiden_states, 1), nn.ReLU(), nn.BatchNorm2d(hiden_states),
nn.Conv2d(hiden_states, hiden_states, 1), nn.ReLU(), nn.BatchNorm2d(hiden_states),
nn.Conv2d(hiden_states, out_c, 1), nn.Sigmoid()
)
def forward(self, x):
return self.mlp(x)
之后我们来定义训练数据。在一般的任务中,输入输出都是从训练集获取的。而在这个任务中,输入是二维坐标,输出是图像的颜色值。输出图像input_image
我们刚刚已经读取完毕了,现在只需要构建输入坐标即可。我们可以用下面的代码构建一个[1, 2, H, W]
形状的二维网格,grid[0, :, i, j]
处的数据是其坐标(i, j)
本身。当然,由于神经网络的输入一般要做归一化,所以我们会把原本0~H
和0~W
里的高宽坐标缩放都到0~1
。最终grid[0, :, i, j]==(i/H, j/W)
。
H, W = input_image.shape[2:]
h_coord = torch.linspace(0, 1, H)
w_coord = torch.linspace(0, 1, W)
grid = torch.stack(torch.meshgrid([h_coord, w_coord]), -1).permute(2, 0, 1).unsqueeze(0)
准备好一切后,我们就可以开始训练了。我们初始化模型model
和优化器optimizer
,和往常一样训练这个 MLP。如前所述,这个任务的输入输出非常直接,输入就是坐标网格grid
,目标输出就是图片input_image
。每训练一段时间,我们就把当前 MLP 拟合出的图片和误差打印出来。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP(2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
n_loops = 400
input_image = input_image.to(device)
grid = grid.to(device)
for epoch in tqdm(range(n_loops)):
output = model(grid)
loss = F.l1_loss(output, input_image)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 100 == 0 or epoch == n_loops - 1:
viz_image(output[0])
print(loss.item())
运行代码,大致能得到如下输出。可以看到,从一开始,图像就非常模糊。
不过,如果我们在把坐标输入进网络前先将其转换成位置编码——一种特殊的傅里叶特征,那么 MLP 就能清晰地拟合出原图片。这里我们暂时不去关注这段代码的实现细节。
class FourierFeature(nn.Module):
def __init__(self, in_c, out_c, scale):
super().__init__()
fourier_basis = torch.randn(in_c, out_c // 2) * scale
self.register_buffer('_fourier_basis', fourier_basis)
def forward(self, x):
N, C, H, W = x.shape
x = rearrange(x, 'n c h w -> (n h w) c')
x = x @ self._fourier_basis
x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W)
x = 2 * torch.pi * x
x = torch.cat([torch.sin(x), torch.cos(x)], dim=1)
return x
feature_length = 256
model = MLP(feature_length).to(device)
fourier_feature = FourierFeature(2, feature_length, 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
n_loops = 400
for epoch in tqdm(range(n_loops)):
x = fourier_feature(grid)
output = model(x)
loss = F.l1_loss(output, input_image)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 100 == 0 or epoch == n_loops - 1:
viz_image(output[0])
print(loss.item())
prev_output = output
简单地对比一下,此前方法的主要问题是 MLP 无法拟合高频的信息(如图块边缘),只能生成模糊的图像。而使用位置编码后,MLP 从一开始就能较好地表示高频信息。可见,问题的关键在于如何让 MLP 更好地拟合数据的高频信息。
接下来,我们来从一个比较偏理论的角度看一看论文是怎么分析位置编码在拟合高频信息中的作用的。
核回归
傅里叶特征论文使用了神经正切核(Nerual Tangent Kernel, NTK)来分析 MLP 的学习规律,而 NTK 又是一种特殊的核回归 (Kernel Regression) 方法。在这一节里,我会通过代码来较为仔细地介绍核回归。下一节我会简单介绍 NTK。
和神经网络类似,核回归也是一种数学模型。给定训练集里的输入和输出,我们建立这样一个模型,用来拟合训练集表示的未知函数。相比之下,核回归的形式更加简单,我们有更多的数学工具来分析其性质。
核回归的设计思想来源于我们对于待拟合函数性质的观察:正如我们在前文的分析一样, 要用模型拟合一个函数时, 该模型在训练数据附近最好是连续变化的。离训练集输入越近, 输出就要和其对应输出越近。基于这种想法,核回归直接利用和所有数据的相似度来建立模型:假设训练数据为 , 我们定义了一个计算两个输入相似度指标 , 那么任意输入 的输出为:
也就是说,对于一个新输入 ,我们算它和所有输入 的相似度 ,并把相似度归一化。最后的输出 是现有 的相似度加权和。
这样看来,只要有了相似度指标,最终模型的形式也就决定下来了。我们把这个相似度指标称为「核」。至于为什么要把它叫做核,是因为这个相似度指标必须满足一些性质,比如非负、对称。但我们这里不用管那么多,只需要知道核是一种衡量距离的指标,决定了核就决定了核回归的形式。
我们来通过一个简单的一维函数拟合实验来进一步熟悉核回归。该实验代码在项目文件夹下的kernel_regression.ipynb
中。
先导入库。
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
再创建一个简单的非线性函数,做为我们的拟合目标。这个函数就是一个简单的周期为 2 的正弦函数乘上线性函数 。我们可以简单可视化一下函数在 之间的图像。
def func(x):
return np.sin(np.pi * x) * (1 - x)
xs = np.linspace(-1, 1, 100)
ys = func(xs)
plt.plot(xs, ys)
plt.show()
基于这个函数,我们等间距地选一些点做为训练数据。
sample_x = np.linspace(-1, 1, 10)
sample_y = func(sample_x)
plt.scatter(sample_x, sample_y)
plt.show()
有了数据后,我们来用核回归根据数据拟合这个函数。在决定核回归时,最重要的是决定核的形式。这里我们用正态分布的概率密度函数来表示核,该核唯一的超参数是标准差,需要我们根据拟合结果手动调整。标准差为1
的标准正态分布核的图像如下所示。由于最后要做归一化,正态分布密度函数的系数被省略掉了。
def kernel_func(x_ref, x_input, sigma=1):
return np.exp(-(x_input-x_ref)**2 / (2 * sigma**2))
xs = np.linspace(-1, 1, 100)
ys = kernel_func(0, xs)
plt.plot(xs, ys)
plt.show()
可以从图像中看出,离某输入越近(假设该输入是0
),那么相似度就越高。这符合我们对于相似度函数的要求。
有了核函数后,我们就直接得到了模型。根据核回归模型计算结果的函数为kernel_regression
。函数参数xs, ys
表示训练数据,x_input
表示测试时用的输入坐标,sigma
是核回归的超参数。
假设有n
个训练样本,有m
个测试输入,那么我们要计算每个测试输入对每个训练输入的n * m
个相似度,这些相似度会存到矩阵weight
里。为此,我们需要对xs
和x_input
做一些形状变换,再用上面定义的核函数kernel_func
求出每对相似度。有了相似度后,我们根据公式计算点乘结果weight_dot
及归一化系数weight_sum
,并最终计算出核回归的结果res
。
基于这个函数,我们可以将测试输入定义成[-1, 1]
上一些更密集的坐标,并用上面定义好的 10 个样本做为训练集,得到核回归的结果。
def kernel_regression(xs, ys, x_input, sigma=1):
# xs: [n, ]
# ys: [n, ]
# x_input: [m, ]
N = xs.shape[0]
xs = np.expand_dims(xs, 1)
ys = np.expand_dims(ys, 1)
x_input = np.expand_dims(x_input, 0)
x_input = np.repeat(x_input, N, 0)
weight = kernel_func(xs, x_input, sigma) # [n, m]
weight_sum = np.sum(weight, 0)
weight_dot = weight.T @ ys
weight_dot = np.squeeze(weight_dot, 1)
res = weight_dot / weight_sum
return res
sigma = 1
xs = np.linspace(-1, 1, 100)
ys = kernel_regression(sample_x, sample_y, xs, sigma)
plt.title(f'sigma = {sigma}')
plt.plot(xs, ys)
plt.show()
我们可以通过修改sigma
来得到不同的拟合效果。以下是我的一些结果:
可以看出,标准差越小,模型倾向于过拟合;随着标准差变大,曲线会逐渐平缓。我们需要不断调整超参数,在过拟合和欠拟合之间找到一个平衡。这种现象很容易解释:正态分布核函数的标准差越小,意味着每个训练数据的影响范围较小,那么测试样本更容易受到少数样本的影响;标准差增大之后,各个训练样本的影响开始共同起作用,我们拟合出的函数也越来越靠近正确的函数;但如果标准差过大,每个训练样本的影响都差不多,那么模型就什么都拟合不了了。
从实验结果中,我们能大致感受到核回归和低通滤波很像,都是将已知数据的平均效果施加在未知数据上。因此,在分析核回归的时候,往往会从频域分析核函数。如果核函数所代表低通滤波器的带宽 (bandwidth)越大,那么剩下的高频信息就更多,核回归也更容易拟合高频信息较多的数据。
神经正切核
那么,核回归是怎么和神经网络关联起来的呢?有研究表明,在一些特殊条件下,MLP 的最终优化结果可以用一个简单的核回归来表示。这不仅意味着我们可以神奇地提前预测梯度下降的结果,还可以根据核回归的性质来分析神经网络的部分原理。这种能表示神经网络学习结果的核被称为神经正切核(NTK)。
这些特殊条件包括 MLP 无限宽、SGD 学习率的学习率趋近 0 等。由于这些条件和实际神经网络的配置相差较远,我们难以直接用核回归预测复杂神经网络的结果。不过,我们依然可以基于这些理论来分析和神经网络相关的问题。傅里叶特征的分析就是建立在 NTK 上的。
NTK 的形式为
其中, 是参数为 的神经网络, 为内积运算。简单来看, 这个式子是说神经网络的核回归中,任意两个向量间的相似度等于网络对参数的偏导的内积的期望。基于 NTK,我们可以分析出很多神经网络的性质, 比如出乎意料地, 神经网络的结果和随机初始化的参数无关, 仅和网络结构和训练数据有关。
在学习傅里叶特征时, 我们不需要仔细研究这些这些理论, 而只需要知道一个结论: 一般上述 NTK 可以写成标量函数 , 也就是可以先算内积再求偏导。这意味用核回归表示神经网络时, 真正要关心的是输入间的内积。别看 NTK 看起来那么复杂, 傅里叶特征论文其实主要就用到了这一个性质。
为了从理论上讲清为什么 MLP 难以拟合高频,作者还提及了很多有关 NTK 的分析,包括一种叫做谱偏差(spectral bias)的现象:神经网络更容易学习到数据中的低频特征。可能作者默认读者已经熟悉了相关的理论背景,这部分论述经常会出现逻辑跳跃,很难读懂。当然,不懂这些理论不影响理解傅里叶特征。我建议不要去仔细阅读这篇文章有关谱偏差的那一部分。
正如我们在前文的核回归实验里观察到的,核回归模型能否学到高频取决于核函数的频域特征。因此,这部分分析和 NTK 的频域有关。对这部分内容感兴趣的话可以去阅读之前有关谱偏差的论文。
傅里叶特征的平移不变性
在上两节中,我们花了不少功夫去认识谱回归和 NTK。总结下来,其实我们只需要搞懂两件事:
- 神经网络最终的收敛效果可以由简单的核回归决定。而核回归重点是定义两个输入之间的相似度指标(核函数)。
- 表示神经网络的核回归相似度指标是 NTK,它其实又只取决于两个输入的内积。
根据这一性质,我们可以部分解释为什么在文章开头那个 MLP 拟合连续图像的实验中,位置编码可以提升 MLP 拟合高频信息的能力了。这和位置输入的特性有关。
当 MLP 的输入表示位置时, 我们希望模型对输入位置具有平移不变性。比如我们现在有一条三个样本组成的句子 。当我们同时改变句子的位置信息时, 比如将句子的位置改成 时, 网络能学出完全一样的东西。但显然不对输入位置做任何处理的话, 和 对神经网络来说是完全不同的意思。
而使用位置编码的话,情况就完全不同了。假如输入数据是二维坐标 ,我们可以用下面的式子建立一个维度为 的位置编码:
其中 是系数, 是一个投影矩阵, 用于把原来 2 D 的位置变成一个更长的位置编码。当然, 由于位置编码中既要有 也要有 , 所以最终的位置编码长度为 。
根据我们之前的分析,NTK 只取决于输入间的内积。算上位置编码后,一对输入位置 的内积为:
而根据三角函数和角公式可知:
这样,上面那个内积恰好可以写成:
上式完全由位置间的相对距离决定。上式决定了 NTK,NTK 又决定了神经网络的学习结果。所以,神经网络的收敛结果其实完全取决于输入间的相对距离,而不取决于它们的绝对距离。也因此,位置编码使得 MLP 对于输入位置有了平移不变性。
加入位置编码后,虽然 MLP 满足了平移不变性,但这并不代表 MLP 学习高频信息的能力就变强了。平移不变性能给我们带来什么好处呢?作者指出,当满足了平移不变性后,我们就能手动调整 NTK 的带宽了。回想一下我们上面做的核回归实验,如果我们能够调整核的带宽,就能决定函数是更加高频(尖锐)还是更加低频(平滑)。这里也是同理,如果我们能够调大 NTK 的带宽,让它保留更多高频信息,那么 MLP 也就能学到更多的高频信息。
作者在此处用信号处理的知识来分析平移不变性的好处,比如讲了新的 NTK 就像一个重建卷积核 (reconstruction filter),整个 MLP 就像是在做卷积。还是由于作者省略了很多推导细节,这部分逻辑很难读懂。我建议大家直接记住推理的结论:平移不变性使得我们能够调整 NTK 的带宽,从而调整 MLP 学习高频的能力。
那我们该怎么调整 NTK 的带宽呢?现在的新 NTK 由下面的式子决定:
为了方便分析, 我们假设 和 都是一维实数。那么, 如果我们令 的话:
这个式子能令你想到什么? 没错, 就是傅里叶变换。 较大的项就表示 NTK 的高频分量。我们可以通过修改前面的系数 来手动调整 NTK 的频域特征。我们能看到,位置编码其实就是在模拟傅里叶变换,所以作者把位置编码总结为傅里叶特征。
作者通过实验证明我们可以手动修改 NTK 的频谱。实验中, 作者令 。 表示位置编码只有第一项: 。不同 时 NTK 的空域和频域示意图如下所示。可以看出, 令 时, 即傅里叶特征所有项的系数都为 1 时, NTK 的高频分量不会衰减。这也意味着 MLP 学高频信息和低频信息的能力差不多。
随机傅里叶特征
现在我们已经知道傅里叶特征的公式是什么, 并知道如何设置其中的参数 了。现在, 还有一件事我们没有决定:该如何设置傅里叶特征的长度 呢?
既然我们说傅里叶特征就是把输入的位置做了一次傅里叶变换, 那么一般来讲, 傅里叶特征的长度应该和原图像的像素数一样。比如我们要表示一个 的图像, 那么我们就需要令 表示不同方向上的频率: 。但这样的话, 神经网络的参数就太多了。可不可以令 更小一点呢?
根据之前的研究Random features for large-scale kernel machines 表明, 我们不需要密集地采样傅里叶特征, 只需要稀疏地采样就行了。具体来说, 我们可以从某个分布随机采样 个频率 来, 这样的学习结果和密集采样差不多。当然, 根据前面的分析, 我们还是令所有系数 。在实验中, 作者发现, 从哪种分布里采样都无所谓, 关键是 的采样分布的标准差, 因为这个标准差决定了傅里叶特征的带宽, 也决定了网络拟合高频信息的能力。实验的结果如下:
我们可以不管图片里 是啥意思, 只需要知道 是三组不同的实验就行。虚线是密集采样傅里叶特征的误差,它的结果反映了一个「较好」的误差值。令人惊讶的是,不管从哪种分布里采样 , 最后学出来的网络误差都差不多。问题的关键在于采样分布的标准差。把标准差调得够好的话, 模型的误差甚至低于密集采样的误差。
也就是说,虽然我们花半天分析了位置编码和傅里叶变换的关系,但我们没必要照着傅里叶变换那样密集地采样频率,只需要随机选一些频率即可。当然,这个结论只对 MLP 拟合连续数据的任务有效,和 Transformer 里的位置编码无关。
代码实现随机傅里叶特征
现在,我们可以回到博文开头的代码,看一下随机傅里叶特征是怎么实现的。
class FourierFeature(nn.Module):
def __init__(self, in_c, out_c, scale):
super().__init__()
fourier_basis = torch.randn(in_c, out_c // 2) * scale
self.register_buffer('_fourier_basis', fourier_basis)
def forward(self, x):
N, C, H, W = x.shape
x = rearrange(x, 'n c h w -> (n h w) c')
x = x @ self._fourier_basis
x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W)
x = 2 * torch.pi * x
x = torch.cat([torch.sin(x), torch.cos(x)], dim=1)
return x
feature_length = 256
model = MLP(feature_length).to(device)
fourier_feature = FourierFeature(2, feature_length, 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
n_loops = 400
for epoch in tqdm(range(n_loops)):
x = fourier_feature(grid)
output = model(x)
loss = F.l1_loss(output, input_image)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 100 == 0 or epoch == n_loops - 1:
viz_image(output[0])
print(loss.item())
prev_output = output
傅里叶特征通过类FourierFeature
实现。其代码如下:
class FourierFeature(nn.Module):
def __init__(self, in_c, out_c, scale):
super().__init__()
fourier_basis = torch.randn(in_c, out_c // 2) * scale
self.register_buffer('_fourier_basis', fourier_basis)
def forward(self, x):
N, C, H, W = x.shape
x = rearrange(x, 'n c h w -> (n h w) c')
x = x @ self._fourier_basis
x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W)
x = 2 * torch.pi * x
x = torch.cat([torch.sin(x), torch.cos(x)], dim=1)
return x
构造函数里的fourier_basis
表示随机傅里叶特征的频率,对应论文公式里的,scale
表示采样的标准差。初始化好了随机频率后,对于输入位置x
,只要按照公式将其投影到长度为out_c / 2
的向量上,再对向量的每一个分量求sin, cos
即可。按照之前的分析,我们令所有系数 为,所以不需要对输出向量乘系数。
傅里叶特征在 StyleGAN3 里的应用
傅里叶特征最经典的应用就是 NeRF 这类过拟合连续数据任务。除此之外,傅里叶特征另一次大展身手是在 StyleGAN3 中。
StyleGAN3 希望通过平滑地移动生成网络的输入来使输出图片也发生对应的移动。为此,StyleGAN3 将生成网络的输入定义为频域上的一个有限带宽图像信号:根据信号处理知识,我们能够将有限带宽信号转换成空域上无限连续的信号。也就是说,不管输入的分辨率(采样率)多低,我们都能够平滑地移动输入图片。StyleGAN3 借助随机傅里叶特征来实现这样一个频域图像。
以下代码选自 StyleGAN3 中傅里叶特征的构造函数。这个函数的关键是随机生成一些频率固定,但方向可以不同的傅里叶频率。函数先随机采样了一些频率,再将它们归一化,最后乘上指定的带宽bandwidth
,保证所有频率大小相等。
class SynthesisInput(torch.nn.Module):
def __init__(self,
w_dim, # Intermediate latent (W) dimensionality.
channels, # Number of output channels.
size, # Output spatial size: int or [width, height].
sampling_rate, # Output sampling rate.
bandwidth, # Output bandwidth.
):
super().__init__()
self.w_dim = w_dim
self.channels = channels
self.size = np.broadcast_to(np.asarray(size), [2])
self.sampling_rate = sampling_rate
self.bandwidth = bandwidth
# Draw random frequencies from uniform 2D disc.
freqs = torch.randn([self.channels, 2])
radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
freqs /= radii * radii.square().exp().pow(0.25)
freqs *= bandwidth
phases = torch.rand([self.channels]) - 0.5
而在使用这个类获取网络输入时,和刚刚的 MLP 实现一样,我们会先生成一个二维坐标表格grid
用于查询连续图片每一处的颜色值,再将其投影到各个频率上,并计算新向量的正弦函数。
这段代码中,有两块和我们自己的实现不太一样。第一,StyleGAN3 允许对输入坐标做仿射变换(平移和旋转)。仿射变换对坐标的影响最终会转化成对三角函数相位phases
和频率freqs
的影响。第二,在计算三角函数时,StyleGAN3 只用了正弦函数,没有用余弦函数。
def forward(self, ...):
...
# Transform frequencies.
phases = ...
freqs = ...
# Construct sampling grid.
theta = torch.eye(2, 3, device=w.device)
theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
# Compute Fourier features.
x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
x = x + phases.unsqueeze(1).unsqueeze(2)
x = torch.sin(x * (np.pi * 2))
x = x * amplitudes.unsqueeze(1).unsqueeze(2)
...
# Ensure correct shape.
x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
return x
我们在 MLP 拟合连续图像的实验里复现一下这两个改动。首先是二维仿射变换。给定旋转角theta
和两个方向的平移tx, ty
,我们能够构造出一个 的仿射变换矩阵。把它乘上坐标[x, y, 1]
后,就能得到仿射变换的输出。我们对输入坐标grid
做仿射变换后得到grid_ext
,再用grid_ext
跑一遍傅里叶特征和 MLP。
N, C, H, W = grid.shape
tx = 50 / H
ty = 0
theta = torch.tensor(torch.pi * 1 / 8)
affine_matrix = torch.tensor([
[torch.cos(theta), -torch.sin(theta), tx],
[torch.sin(theta), torch.cos(theta), ty],
[0, 0, 1]
]
).to(device)
grid_ext = torch.ones(N, 3, H, W).to(device)
grid_ext[:, :2] = grid.clone()
grid_ext = grid_ext.permute(0, 2, 3, 1)
grid_ext = (grid_ext @ affine_matrix.T)
grid_ext = grid_ext.permute(0, 3, 1, 2)[:, :2]
x = fourier_feature(grid_ext)
output = model(x)
viz_image(output[0])
在示例代码中,我们可以得到旋转 45 度并向下平移 50 个像素的图片。可以看到,变换成功了。这体现了连续数据的好处:我们可以在任意位置对数据采样。当然,由于这种连续数据是通过过拟合实现的,在训练集没有覆盖的坐标处无法得到有意义的颜色值。
之后,我们来尝试在傅里叶特征中只用正弦函数。我们将投影矩阵的输出通道数从out_c / 2
变成out_c
,再在forward
里只用sin
而不是同时用sin, cos
。经实验,这样改了后完全不影响重建质量,甚至由于通道数更多了,重建效果更好了。
class FourierFeature(nn.Module):
def __init__(self, in_c, out_c, scale):
super().__init__()
fourier_basis = torch.randn(in_c, out_c) * scale
self.register_buffer('_fourier_basis', fourier_basis)
def forward(self, x):
N, C, H, W = x.shape
x = rearrange(x, 'n c h w -> (n h w) c')
x = x @ self._fourier_basis
x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W)
x = 2 * torch.pi * x
x = torch.sin(x)
return x
StyleGAN3 论文并没有讲为什么只用sin
,网上也很少有人讨论傅里叶特征的实现细节。我猜傅里叶特征并不是非得和傅里叶变换完全对应,毕竟它只是用来给神经网络提供更多信息,而没有什么严格的意义。只要把输入坐标分解成不同频率后,神经网络就能很好地学习了。
只用sin
而不是同时用sin, cos
后,似乎我们之前对 NTK 平移不变的推导完全失效了。但是,根据三角函数的周期性可知,只要是把输入映射到三角函数上后,网络主要是从位置间的相对关系学东西。绝对位置对网络来说没有那么重要,不同的绝对位置只是让所有三角函数差了一个相位而已。只用sin
的神经网络似乎也对绝对位置不敏感。为了证明这一点,我把原来位于[0, 1]
间的坐标做了一个幅度为10
的平移。结果网络的误差几乎没变。
for epoch in tqdm(range(n_loops)):
x = fourier_feature(grid + 10)
output = model2(x)
loss = F.l1_loss(output, input_image)
optimizer.zero_grad()
loss.backward()
optimizer.step()
根据这些实验结果,我感觉是不是从 NTK 的角度来分析傅里叶特征完全没有必要?是不是只要从直觉上理解傅里叶特征的作用就行了?按我的理解,傅里叶特征在真正意义在于显式把网络对于不同频率的关注度建模出来,从而辅助网络学习高频细节。
总结
在这篇博文中,我们学习了傅里叶特征及其应用,并顺带了解其背后有关核回归、NTK 的有关理论知识。这些知识很杂乱,我来按逻辑顺序把它们整理一下。
为了解释为什么 NeRF 中的位置编码有效,傅里叶特征论文研究了用 MLP 拟合连续数据这一类任务中如何让 MLP 更好地学到高频信息。论文有两大主要结论:
- 通过从 NTK 理论的分析,位置编码其实是一种特殊的傅里叶特征。这种特征具有平移不变性。因此,神经网络就像是在对某个输入信号做卷积。而我们可以通过调整傅里叶特征的参数来调整卷积的带宽,也就是调整网络对于不同频率的关注程度,从而使得网络不会忽略高频信息。
- 傅里叶特征的频率不需要密集采样,只需要从任意一个分布随机稀疏采样。影响效果的关键是采样分布的标准差,它决定了傅里叶特征的带宽,也就决定了网络是否能关注到高频信息。
除了过拟合连续数据外,傅里叶特征的另一个作用是直接表示带宽有限信号,以实现在空域上的连续采样。StyleGAN3 在用傅里叶特征时,允许对输入坐标进行仿射变换,并且计算特征时只用了正弦函数而不是同时用正弦、余弦函数。这表明有关 NTK 的理论分析可能是没有必要的,主要说明问题的还是实验结果。
傅里叶特征论文仅研究了拟合连续数据这一类问题,没有讨论 Transformer 中位置编码的作用。论文中的一些结论可能无法适用。比如在大模型的位置编码中,我们还是得用密集的sin, cos 变换来表示位置编码。不过,我们可以依然借助该论文中提到的理论分析工具,来尝试分析所有位置编码的行为。
只通过文字理解可能还不太够,欢迎大家尝试我为这篇博客写的 Notebook,通过动手做实验来加深理解。https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/FourierFeature
#让模型预见分布漂移
动态系统颠覆性设计引领时域泛化新革命本研究提出了一种方法,能够在领域数据分布持续变化的动态环境中,基于随机时刻观测的数据分布,在任意时刻生成适用的神经网络。
下图展示了模型在领域数据随时间发生旋转和膨胀时的泛化表现。通过在一些随机时间点(蓝色标记点)的观测,模型可以在任意时刻生成适用的神经网络,其决策边界始终与数据分布保持协调一致。
01 摘要
在实际应用中,数据集的数据分布往往随着时间而不断变化,预测模型需要持续更新以保持准确性。时域泛化旨在预测未来数据分布,从而提前更新模型,使模型与数据同步变化。
然而,传统方法假设领域数据在固定时间间隔内收集,忽视了现实任务中数据集采集的随机性和不定时性,无法应对数据分布在连续时间上的变化。此外,传统方法也难以保证泛化过程在整个时间流中保持稳定和可控。
为此,本文提出了连续时域泛化任务,并设计了一个基于模型动态系统的时域泛化框架 Koodos,使得模型在连续时间中与数据分布的变化始终保持协调一致。Koodos 通过库普曼算子将模型的复杂非线性动态转化为可学习的连续动态系统,同时利用先验知识以确保泛化过程的稳定性和可控性。
实验表明,Koodos 显著超越现有方法,为时域泛化开辟了全新的研究方向。
02 论文信息
论文链接:
https://arxiv.org/pdf/2405.16075
开源代码:
https://github.com/Zekun-Cai/Koodos/
OpenReview:
https://openreview.net/forum?id=G24fOpC3JE
我们在代码库中提供了详细的逐步教程,涵盖了 Koodos 的实现、核心概念的解读以及可视化演示:
https://github.com/Zekun-Cai/Koodos/blob/main/Tutorial_for_Koodos.ipynb
整个教程流程紧凑,十分钟即可快使掌握 Koodos 的使用方法,力荐尝试!
03 情景导入
在实际应用中,训练数据的分布通常与测试数据不同,导致模型在训练环境之外的泛化能力受限。领域泛化(Domain Generalization, DG)作为一种重要的机器学习策略,旨在学习一个能够在未见目标领域中也保持良好表现的模型。
近年来研究人员发现,在动态环境中,领域数据(Domain Data)分布往往具有显著的时间依赖性,这促使了时域泛化(Temporal Domain Generalization, TDG)技术的快速发展。
时域泛化将多个领域视为一个时间序列而非一组独立的静态个体,利用历史领域预测未来领域,从而实现对模型参数的提前调整,显著提升了传统 DG 方法的效果。
然而,现有的时域泛化研究集中在“离散时间域”假设下,即假设领域数据在固定时间间隔(如逐周或逐年)收集。基于这一假设,概率模型被用于预测时域演变,例如通过隐变量模型生成未来数据,或利用序列模型(如 LSTM)预测未来的模型参数。
然而在现实中,领域数据的观测并不总是在离散、规律的时间点上,而是随机且稀疏地分布在连续时间轴上。例如,图 1 展示了一个典型的例子——基于推文数据进行社交媒体舆情预测。
与传统 TDG 假设的领域在时间轴上规律分布不同,实际中我们只能在特定事件(如总统辩论)发生时获得一个域,而这些事件的发生时间并不固定。同时,概念漂移(Concept Drift)在时间轴上发生,即领域数据分布随着时间不断演变:如活跃用户增加、新交互行为形成、年龄与性别分布变化等。
理想情况下,每个时态域对应的预测模型也应随时间逐渐调整,以应对这种概念漂移。最后,由于未来的域采集时间未知,我们希望可以泛化预测模型到未来任意时刻。
▲ 图1:连续时域泛化示意图。图中展示了通过推文训练分类模型进行舆情预测。其中训练域仅能在特定政治事件(如总统辩论)前后采集。我们希望通过这些不规律时间分布的训练域来捕捉分布漂移,并最终使模型能够推广到任意未来时刻。
事实上,领域分布在连续时间上的场景十分常见,例如:
事件驱动的数据采集:仅在特定事件发生时采集领域数据,事件之间没有数据。
流数据的随机观测:领域数据在数据流的任意时间点开始或结束采集,而非持续进行。
离散时态域但缺失:尽管领域数据基于离散时间点采集,但部分时间节点的领域数据缺失。
为了应对这些场景中的模型泛化,我们提出了“连续时域泛化”(Continuous Temporal Domain Generalization, CTDG)任务,其中观测和未观测的领域均分布于连续时间轴上随机的时间点。CTDG 关注于如何表征时态领域的连续动态,使得模型能够在任意时间点实现稳定、适应性的调整,从而完成泛化预测。
04 核心挑战
CTDG 任务的挑战远超传统的 TDG 方法。CTDG 不仅需要处理不规律时间分布的训练域,更重要的是,它旨在让模型泛化到任意时刻,即要求在连续时间的每个点上都能精确描述模型状态。
而 TDG 方法则仅关注未来的单步泛化:在观测点优化出当前模型状态后,只需将其外推一步即可。这使得 CTDG 区别于 TDG 任务:CTDG 的关键在于如何在连续时间轴上同步数据分布和模型参数的动态演变,而不是仅局限于未来某一特定时刻的模型表现。
具体而言,与 TDG 任务相比,CTDG 的复杂性主要来自以下几个尚未被充分探索的核心挑战:
如何建模数据动态并同步模型动态:CTDG 要求在连续时间轴上捕捉领域数据的动态,并据此同步调整模型状态。然而,数据动态本身难以直接观测,需要通过观测时间点来学习。此外,模型动态的演变过程也同样复杂。理解数据演变如何驱动模型演变构成了 CTDG 的首要挑战。
如何在高度非线性模型动态中捕捉主动态:领域数据的预测模型通常依赖过参数化(over-parametrized)的深度神经网络,模型动态因此呈现出高维、非线性的复杂特征。这导致模型的主动态嵌藏在大量潜在维度中。如何有效提取并将这些主动态映射到可学习的空间,是 CTDG 任务中的另一重大挑战。
如何确保长期泛化的稳定性和可控性:为实现未来任意时刻的泛化,CTDG 必须确保模型的长期稳定性。此外,在许多情况下,我们可能拥有数据动态的高层次先验知识。如何将这些先验知识嵌入 CTDG 的优化过程中,进而提升泛化的稳定性和可控性,是一个重要的开放性问题。
05 技术方法
5.1 问题定义
在 CTDG 中,一个域 表示在时间 采集的数据集,由实例集 组成,其中 和 分别为特征值,目标值和实例数。我们重点关注连续时间上的渐进性概念漂移,表示为领域数据的条件概率分布 随时间平滑变化。
在训练阶段,模型接收一系列在不规律时间点 上收集的观测域 ,其中每个时间点 是定义在连续时间轴 上的实数,且满足 $t_1<t_2<\ldots<t_t$ 。<="" p="">
在每个 上,模型学习到领域数据 的预测函数 ,其中 表示 时刻的模型参数。CTDG 的目标是建模参数的动态变化,以便在任意给定时刻 上预测模型参数 ,从而得到泛化模型 。
在后续部分中,我们使用简写符号 、、 和 ,分别表示在时间 上的 、 、 和 。
5.2 设计思路
我们的方法通过模型与数据的同步、动态简化表示,以及高效的联合优化展开。具体思路如下:
1. 同步数据和模型的动态:我们证明了连续时域中模型参数的连续性,而后借助神经微分方程(Neural ODE)建立模型动态系统,从而实现模型动态与数据动态的同步。
2. 表征高维动态到低维空间:我们将高维模型参数映射到一个结构化的库普曼空间(Koopman Space)中。该空间通过可学习的低维线性动态来捕捉模型的主要动态。
3. 联合优化模型与其动态:我们将单个领域的模型学习与各时间点上的连续动态进行联合优化,并设计了归纳偏置的约束接口,通过端到端优化保证泛化的稳定性和可控性。
▲ 模型设计
5.3 解决方案
Step 1. 数据动态建模与模型动态同步
分布变化的连续性假设:我们首先假设数据分布在时间上具有连续演化的特性,即条件概率分布 随时间平滑变化, 其演化规律可由一个函数 所描述的动态系统刻画。尽管真实世界中的渐进概念漂移可能较为复杂,但因概念漂移通常源于底层的连续过程(如自然、生物、物理、社会或经济因素),这一假设不失普适性。
分布变化引发的模型参数连续演化:基于上述假设,模型的函数功能空间应随数据分布变化同步调整。我们借助常微分方程来描述这一过程:
由此可推导出模型参数的演化满足:
其中, 是 对 的雅可比矩阵。
这一结果表明,如果数据分布的演化在时间上具有连续性,那么的演化过程也具有连续性,即模型参数会随数据分布的变化而平滑调整。上式为建立了一个由微分方程描述的模型动态系统。
模型动态系统学习:由于数据动态 的具体形式未知, 直接求解上述微分方程并不可行。为此, 我们引入一个由神经网络定义的连续动态系统, 用可学习的函数 描述模型参数 的变化。
通过鼓励模型动态和数据动态之间的拓扑共轭(Topological Conjugation)关系使 逼近真实动态。具体而言, 拓扑共轭要求通过泛化获得的模型参数与直接训练得到的参数保持一致。为此, 我们设定以下优化目标, 以学习 的参数 :
其中, 通过在时刻 的领域上直接训练获得, 则表示从时间 通过动态 演变至 的泛化参数:
通过这一优化过程,我们建立了模型动态与数据动态之间的同步机制。借助动态函数,我们可以在任意时刻精确求解模型的状态。
Step 2. 通过库普曼算子简化模型动态
非线性动态线性化
在实际任务中, 预测模型通常依赖于过参数化的深度神经网络, 使得模型动态 呈现为在高维空间中纠缠的非线性动态。直接对 建模不仅计算量大,且极易导致泛化不稳定。
然而, 受数据动态 的支配, 而数据动态通常是简单、可预测的。这意味着在过参数化空间中,模型的主动态(Principal Dynamics)可以在适当转换的空间内进行更易于管理的表示。
受此驱动,我们引入库普曼理论(Koopman Theory)来简化复杂的模型动态。库普曼理论在保持动态系统特征的同时将复杂的非线性动态线性化。
具体而言, 我们定义一个库普曼嵌入函数 , 将原始的高维参数空间映射到一个低维的库普曼空间中:
其中, 表示库普曼空间中的低维表示。通过库普曼算子 , 我们可以在线性空间中刻画 的动态:
一旦获得了简化的动态表示,我们可以在库普曼空间中更新模型参数,而后将其反映射回原始参数空间:
最终,通过库普曼算子的引入,我们实现了对模型动态的简化,保证了泛化过程的稳健性。
Step 3. 联合优化与先验知识结合
模型及其动力学的联合优化:我们对多个组件同时施加约束确保模型能稳定泛化,其包含以下关键项:
- 预测准确性:通过最小化预测误差,使预测模型在每个观测时间点都能准确预测实际数据。
- 泛化准确性:通过最小化预测误差,使泛化模型在每个观测时间点都能准确预测实际数据。
- 重构一致性:确保模型参数在原始空间与库普曼空间之间的转换具有一致性。
- 动态保真性:约束库普曼空间的动态行为,使得映射后的空间符合预期的动态系统特征。
- 参数一致性:确保泛化模型参数映射回原始空间后与预测模型参数保持一致。
利用库普曼算子评估和控制泛化过程:引入库普曼理论的另一优势在于,我们可以通过库普曼算子的谱特性来评估模型的长期稳定性。此外,还可以在库普曼算子中施加约束来控制模型的动态行为。
1. 系统稳定性评估
通过观察库普曼算子的特征值,可以判断系统是否稳定:
- 若所有特征值实部为负,系统会稳定地趋向于一个平衡状态。
- 若存在特征值实部为正,系统将变得不稳定,模型在未来可能会崩塌。
- 若特征值实部为零,系统可能表现出周期性行为。通过分析这些特征值的分布,我们可以预测系统的长期行为,识别模型在未来是否可能出现崩溃的风险。
2. 泛化过程约束
我们可以通过对库普曼算子施加显式约束来调控模型的动态行为。例如:
- 周期性约束:当数据动态为周期性时,可将库普曼算子设为反对称矩阵,使其特征值为纯虚数,从而使模型表现出周期性行为。
- 低秩近似:将表示为低秩矩阵,有助于控制模型的自由度,避免过拟合到次要信息。
通过这些手段,我们不仅提高了泛化的长期稳定性,还增强了模型在特定任务中的可控性。
06 实验
6.1 实验设置
为验证算法效果,我们使用了合成数据集和多种真实世界场景的数据集:
合成数据集:包括 Rotated 2-Moons 和 Rotated MNIST 数据集,通过在连续时间区间内随机生成时间戳,并对 Moons 和 MNIST 数据按时间戳逐步旋转生成连续时域。
真实世界数据集:
- 事件驱动数据集 Cyclone:基于热带气旋的卫星图像预测风力强度,气旋发生日期对应连续时域。
- 流数据集 Twitter 和 House:分别从任意时间段抽取推文和房价数据流构成一个领域,多次随机抽取形成连续时域。
- 不规则离散数据集 Yearbook:人像图片预测性别,从 84 年中随机抽取 40 年数据作为连续时域。
6.2 实验结果与分析
定量分析
我们首先对比了 Koodos 方法与各基线方法的定量性能。表 1 显示,Koodos 方法在所有数据集上展现了显著的性能提升。
在合成数据集上,Koodos 能够轻松应对持续的概念漂移,而所有基线方法在这种场景下全部失效。
在真实世界数据集上,尽管某些基线方法(如 CIDA、DRAIN 和 DeepODE)在少数场景中略有表现,但其相较于简单方法(如 Offline)的改进非常有限。相比之下,Koodos 显著优于所有现有方法,彰显出在时域泛化任务中考虑分布连续变化的关键作用。
▲ 实验结果
定性分析
决策边界:为直观展示泛化效果,我们在 Rotated 2-Moons 数据集上进行了决策边界的可视化。该任务具有极高难度:模型需在 0 到 35 秒左右的 35 个连续时域上训练,随后泛化到不规律分布在 35 到 50 秒的 15 个测试域。而现有方法通常只能泛化至未来的一个时域(T+1),且难以处理不规律的时间分布。图 3 从 15 个测试域中选取了 7 个进行可视化。结果清晰地表明,基线方法在应对连续时域的动态变化时表现不足。随着时间推进,决策边界逐渐偏离理想状态。尤其是最新的 DRAIN 方法(ICLR23)在多步泛化任务中明显失效。
相比之下,Koodos 在所有测试域上展现出卓越的泛化能力,始终保持清晰、准确的决策边界,与实际数据分布变化高度同步。这一效果突显了 Koodos 在时域泛化任务中的革命性优势。
▲ 图3:2-Moons 数据集决策边界的可视化(紫色和黄色表示数据区域,红线表示决策边界)。从上到下比较了两种基线方法和 Koodos;从左到右显示了部分测试域(15 选 7,所有测试域的分布在时间轴上用红点标记)。
模型演变轨迹:为更深入地分析模型的泛化能力,我们通过 t-SNE 降维,将不同方法的模型参数的演变过程(Model Evolution Trajectory)在隐空间中可视化(图 4)。
可以看出,Koodos 的轨迹呈现出平滑而有规律的螺旋式上升路径,从训练域平滑延伸至测试域。这一轨迹表明,Koodos 能够在隐空间中有效捕捉数据分布的连续变化,并随时间自然地扩展泛化。
相比之下,基线模型的轨迹在隐空间中缺乏清晰结构,随着时间推移,逐渐出现明显的偏离,未能形成一致的动态模式。
▲ 图4:模型状态在隐空间中的时空轨迹。Koodos 展现出与数据动态和谐同步的模型动态。
时域泛化的分析与控制:在 Koodos 模型中,库普曼算子为分析模型动态提供了有效手段。我们对 Koodos 在 2-Moons 数据集上分析表明,库普曼算子的特征值在复平面上分布在稳定区和不稳定区,这意味着 Koodos 在中短期内能稳定泛化,但在极长时间的预测上将会逐渐失去稳定性,偏离预期路径(图 5b)。
为提升模型的稳定性,我们通过将库普曼算子配置为反对称矩阵(即Koodos版本),确保所有特征值为纯虚数,使模型具有周期性稳定特性。在这一配置下,Koodos展现出高度一致的轨迹,即使在长时间外推过程中依然保持稳定和准确,证明了引入先验知识对增强模型稳健性的效果(图 5c)。
,时长00:23
▲ 图5:非受控和受控条件下的极长期泛化预测模型轨迹。a:部分训练域数据;b:不受控,模型最终偏离预期;c:受控,模型始终稳定且准确。
▲ 图5:非受控和受控条件下的极长期泛化预测模型轨迹。a:部分训练域数据;b:不受控,模型最终偏离预期;c:受控,模型始终稳定且准确。
07 结论
我们设计了一种基于模型连续动态系统的时域泛化方法,能够在数据域随时间逐渐演变的环境中,实现泛化模型的稳定性与可控性。未来,我们计划从多个方向进一步拓展这一技术的应用:
生成式模型扩展:时域泛化与生成式模型任务有天然的关联,Koodos 所具备的泛化能力能够为神经网络生成技术带来新的可能。
非时态泛化任务:Koodos 的应用并不局限于时域泛化,它也可以适用于其他分布变化的任务中。我们计划探索其在非时态领域的应用。
大模型集成:我们将探索时域泛化在大模型中的集成,帮助 LLM 在复杂多变的分布中保持鲁棒性和稳定性。
我们对时域泛化任务在未来的广阔应用前景充满期待。如有任何问题或合作意向,欢迎联系我们!
邮箱: caizekun@csis.u-tokyo.ac.jp
GitHub: https://github.com/Zekun-Cai/Koodos/
Paper: https://arxiv.org/pdf/2405.16075
#Scaling Laws for Precision 解读
本文探讨了模型量化对性能的影响,并提供了关于训练时量化和后训练量化的实用建议。文章强调了在不同训练精度下,如何平衡模型性能和量化损失,以及在实际应用中选择合适的量化策略的重要性。
来自链接 https://zhuanlan.zhihu.com/p/6848989432
原文
前置知识:
scaling law:
- Training Compute-Optimal Large Language Models(Chinchilla scaling law)
个人讨厌晦涩难懂+无法应用于实际场景的"装逼结论",因此先按照自己的理解帮大家rephrase一下论文的主要发现(in plain language):
首先,这是一篇研究精度(precision)、参数量(parameters)和训练数据量(tokens)之间关系的重要论文。
1. 关于后训练量化(Post-Training Quantization, PTQ):1.1 基本概念
- 指的是pretrain以较高精度(bf16)进行,结束后再量化到更低精度(如int4)
1.2 结论1
模型预训练的trained_token/parameter比率越高,预训练结束后,使用PTQ带来的性能下降就越大。这里作者没写明白有误导性!!!实际上这个结论指的是:
- 我们都知道PTQ一定会带来性能下降(PTQ后,valid loss相比pretrain之后会上升),这个下降可以用
- 论文提出了预测这个下降值的公式:
- 其中:
- 训练数据量D越大,PTQ带来的损失越大(正相关)
- 参数量N越大,PTQ带来的损失越小(负相关)
- 量化后的精度Ppost越低,损失增加越多(负指数关系)
- N: 参数量
- D: 训练token数
- : PTQ后的精度
- γγγ: 拟合常数
- 这个公式告诉我们:
- 注意,δPTQ还有一种完整形式(section 5) 同时考虑了训练精度和推理精度(继续往后看):
- 那么如果你必须进行PTQ,那么对于同样参数量大小的模型,被训更多token的模型的 δPTQ 会比喂更少数据的模型要大。但最终loss的绝对数量是多少并不一定,因为即便 δPTQ 这个正数会让loss上升(性能下降),但模型终归被训了更多数据,这么一抵消可能loss还是会下降。相当于两只无形的手(数据量的上升带来的loss下降、PTQ带来的loss上升)在掰手腕;给定模型参数量和固定的精度,具体谁能掰过谁会有一个打平手的cutoff数据量。
- 举例子,如果你要固定70B模型参数量并pretrain时候采用bf16,并且pretrain后要PTQ到int4。那么采用两种数据量:
- a) 用10B token训出来模型
- b) 5B token训出来的模型
- 那么一定是a)情况的 δPTQ 更大,但最终PTQ结束之后的loss的数值是多少就不一定了。
- 因此作者也在原文中提到了**there exists an amount of pretraining data beyond which additional data is actively harmful to performance at inference-time (see top-left, Figure 2),也就是给定你要进行PTQ,那么对于你的实验设置,总有一个cutoff的数据量,称之为临界的数据量 Dcrit ,超过这个量后继续训练会导致PTQ后性能下降。这个临界点并不是说超过后训练数据就“有害”,而是说在进行PTQ后,性能的提升可能会被性能的下降所抵消。因此,在实际应用中,需要权衡训练数据量与模型量化后的性能。
- 论文给出了计算这个临界点的公式:
其他结论
- 在某些情况下,过度训练(more tokens)反而会让PTQ后的模型性能变差
- 更大的模型在相同的token/parameter比率下,对PTQ更鲁棒
- 对于固定大小的数据集,增加模型参数量可以提高PTQ的鲁棒性
- 这种规律在不同的PTQ方法中都存在(论文验证了GPTQ、AWQ和RTN三种方法)
训练精度的影响
- 训练时使用较低精度的模型在PTQ时性能下降较小
- 如果你知道模型最终需要被量化到很低的精度(比如4bit),那么在训练时就使用相对较低的精度(比如8bit)可能比使用高精度(比如16bit)更好,因为这样可以让模型在训练过程中就适应量化噪声。
- 实话说这个结论初看有点脱裤子放屁,因为太符合直觉了(bushi)。用脚想想就知道【训练用int8然后量化到int4】肯定比【训练用bf16然后量化到int4】要好,原文section 5:models trained in lower precision are more robust to post-train quantization in the sense of incurring lower degradation.
- 这也解释了为什么一些较新的大语言模型倾向于使用BF16而不是FP32来训练,因为这不仅可以节省计算资源,还可能让模型在后续量化时表现更好
1.3 PTQ造成loss degradation的深入分析1.3.1 两个竞争效应(section 5)
在分析PTQ对模型性能的影响时,论文发现了两个相互竞争的效应:
- Robustification效应
- 低精度训练会让模型更适应量化噪声
- 这使得模型在后续PTQ时更加鲁棒
- 可以理解为模型学会了如何在噪声环境中运作
- Overtraining效应
- 低精度训练会降低模型的有效参数量(),这意味着模型在相同的数据量下“看起来”参数量更少,从而在PTQ时对参数量化的敏感性增加
- 因为 和 成正比, 较低的Neff理论上会导致更大的性能下降: (section 5这边第一次读还以为写错了)。作者说的 实际上应该参考公式 9 变为 ,随着 的增加, 确实增加, 也就是成正比。说明白点就是低精度训练会下降Neff, 也就是一个 模型的可能有效的参数只有 10 B , 然后 变大, 然后根据section 3 的公式就会造成更大的degradation)
- 这个效应与Robustification效应相反
在实践中,Robustification效应通常占主导,这就是为什么低精度训练的模型在PTQ时表现更好。
1.3.2 精度阈值效应
一个重要发现是,当精度低于5-bit时,PTQ带来的性能下降会急剧增加:
- 在高精度区间(如8-bit以上),D/N比率的增加对性能的影响相对温和
- 在5-bit以下,即使很小的D/N比率增加也可能导致显著的性能下降
- 这个发现对实践中选择量化精度有重要指导意义-- 在实际应用中,应避免将模型量化到低于5-bit的精度,除非有特定的需求和相应的优化技术支持
1.3.3 理论解释
论文在附录中提供了两个可能的理论解释:
Sharpness假说
- 模型在训练过程中会逐渐变得更"sharp"-- 随着训练的进行,模型的损失函数变得更加“尖锐”(sharp),即梯度和Hessian矩阵的特征值增加,这导致模型对参数扰动更加敏感。因此,PTQ带来的参数量化噪声会对尖锐的损失函数产生更大的影响。
- Sharp的模型对参数扰动更敏感
- 这种敏感性会随着训练的进行而增加
- 这解释了为什么过度训练可能导致更大的PTQ降质
分层学习假说
- 模型通过分层方式学习特征-- 模型通过逐步学习更复杂的特征,这些特征依赖于之前学习的基础特征。量化噪声影响基础特征,会级联地影响到更高层次的复杂特征,从而导致整体性能的下降。
- 早期学习基础特征,后期学习复杂特征
- 复杂特征依赖于基础特征的准确性
- 当基础特征受到量化噪声影响时,会对依赖它们的复杂特征造成级联效应
- 这解释了为什么训练时间越长,模型对量化越敏感
2. 关于训练时量化(Training-time Quantization)
2.1 基本概念
论文中将训练时量化分为两种情况:
- 仅量化权重(Quantization-Aware Training, QAT):只将模型的权重量化到低精度,其他部分保持高精度,以适应推理阶段的低精度环境。
- 全面量化(Low-precision Training):同时量化模型的权重、激活值和注意力计算(即键-值缓存),以减少计算资源需求。
注意:这里的权重指模型中所有线性层(Linear layers)的权重矩阵,包括:
- Transformer 中的所有投影矩阵(例如 query、key、value 的投影权重);
- 嵌入层(Embedding layers)权重矩阵;
- 最终输出层的权重矩阵。
但在论文的实验中未对嵌入层(Embedding layer)进行量化。
量化实现细节:
- 论文遵循了 FP8 训练的标准规范(Micikevicius et al., 2022);
- 权重采用 按通道(per-channel) 量化;
- 激活值采用 按张量(per-tensor) 量化;
- 对于后训练量化(PTQ),主要针对模型权重进行量化。
2.2 核心发现
权重、激活值和注意力的量化效果是独立且可乘的,这一点非常关键。
论文提出了“有效参数量 Neff effective parameter count)”的概念。简而言之, Neff 代表了模型在低精度下的“真实有效”参数量。在低精度训练时,模型的实际参数量 N会被折减为较低的 Neff ,这有助于评估模型在低精度量化下的性能损失。
基本形式:
完整形式(全面量化):
其中:
- N:模型的实际参数量;
- Pw :权重精度;
- Pa:激活值精度;
- Pkv :注意力计算精度;
- γw、γa、γkv :各部分的敏感度系数,反映了模型对不同量化精度的适应性。
举个例子,在相同的计算预算下,有两种方案:
- a) 使用 16-bit 精度训练较小的模型;
- b) 使用 8-bit 精度训练较大的模型(参数量约为前者的 2 倍)。
根据论文的 Neff 分析,第二种方案通常更优,因为:
- 增加的参数量带来的性能提升超过了精度降低造成的损失;
- 8-bit 精度已接近论文中发现的计算最优精度(7-8 bits);
- 低精度训练可以在相同的计算预算下处理更多的数据。
最优训练精度的计算:论文发现,在一般情况下,最优的训练精度为 7-8 bits。这意味着当前常用的 16-bit(BF16)训练精度其实存在冗余。但如果追求极低精度(例如 4-bit 以下),则需要不成比例地增加模型大小才能维持性能。
但是,如果模型大小被固定(例如受限于硬件资源),情况会有所不同:
- 此时,最优训练精度会随着训练数据量的增加而提高。具体来说,最优精度与训练数据量和参数量的比值成对数关系,即:
最优精度训练数据量参数量最优精度∝log(训练数据量参数量)(见论文 Section 4.3.3)
2.3 训练成本分析
训练成本的计算公式如下:
其中:
- C:计算成本;
- N :模型参数量;
- D :训练 token 数;
- P :训练精度;
- 6/16:标准化系数(基于 Chinchilla 成本模型)。
这意味着什么? 举个例子:假设你的计算预算是固定的,希望训练一个模型,有两种选择:
- 使用 16-bit 精度训练一个 35B 参数量的模型;
- 使用 8-bit 精度训练一个 70B 参数量的模型。
根据论文的发现,第二种方案可能更优,因为增加的参数量带来的性能提升超过了精度降低带来的损失。
2.4 实践建议
如果计算预算有限:
- 优先选择 7-8 bit 的训练精度,并利用节省下来的资源增加模型参数量;
- 避免使用低于 4-bit 的训练精度,因为这需要大幅增加模型大小才能维持性能(见论文 Section 4.3.2)。
如果模型大小受限:
- 在需要处理更大量数据时,提高训练精度;
- 例如,当 token/parameter 比率超过 1000 时,建议使用 8-bit 以上的精度;
- 在高 token/parameter 比率下,避免使用低于 6-bit 的训练精度(见论文 Section 4.3.3)。
各部分的精度选择:
- 权重(Weights)在极低精度(3-bit)下仍能保持稳定;
- 激活值(Activations)和注意力计算(KV-cache)在低于 3-bit 时可能会出现不稳定;
- 这种差异可能与量化方式有关(权重采用按通道量化,激活值采用按张量量化),而不一定是固有特性。
3.限制与未来研究方向
3.1 固定的模型架构
这篇论文采用了固定的Transformer++架构,以便在一个可控的环境中分析精度、参数量和数据量之间的关系。然而,在实际应用中,低精度训练通常会伴随着模型架构的调整。例如,一些先进的低精度训练方法可能会引入特殊的正则化技术或优化策略,以减轻低精度带来的负面影响。因此,论文的结论主要适用于固定架构的情况,尚未在经过优化的低精度架构中进行验证。
3.2 计算成本与系统开销
虽然理论上,降低训练精度(比如从16-bit降到8-bit)可以按比例减少计算需求,但在实际操作中,由于系统开销和硬件实现的限制,精度降低所带来的性能提升通常低于理论预期。例如,某些硬件可能无法高效支持极低精度(如4-bit以下)的计算,导致实际的加速效果有限。此外,不同精度下的数据移动和存储优化表现也可能有所不同,这进一步影响了低精度训练的实际效率。
3.3 仅关注验证损失,缺乏下游任务评估
论文主要关注于训练过程中的验证损失(validation loss)作为性能评估指标,而没有对下游任务的具体表现进行评估。尽管验证损失是衡量模型性能的重要指标,但不同任务对模型精度和量化的敏感性可能存在差异。
3.4 实验规模的限制
虽然论文中训练了多达17亿(17B)参数的模型,并使用了高达26B tokens的数据集,但这些规模相对较小,与当前最先进的大规模语言模型(如数百亿甚至千亿参数级别)相比仍有差距。因此,论文的scaling law在更大规模模型上的适用性尚未得到验证。
4. 量化方法的多样性
这篇论文主要关注于整数类型的量化方法,并通过GPTQ、AWQ和RTN等方法进行了验证。然而,浮点类型的量化方法(如FP8、FP4)在实际应用中也具有重要意义,尤其是在某些硬件平台上具有更好的支持和性能表现。不同量化方法在引入量化噪声和影响模型性能方面可能存在显著差异,因此,未来的研究应涵盖更多种类的量化方法,以全面理解量化对模型性能的影响。
5. 数据集和训练策略的单一性
这篇论文使用了Dolma V1.7数据集,并采用了特定的训练策略和超参数设置。不同的数据集和训练策略可能会影响模型对量化的敏感性。例如,某些数据集可能具有更高的复杂性或多样性,导致模型在低精度下表现出不同的鲁棒性。因此,未来的研究应在多样化的数据集和训练配置下进行,以验证缩放规律的普适性。
#图解OpenRLHF中基于Ray的分布式训练流程
本文详细分析了OpenRLHF中基于Ray的分布式训练流程。
本文着重分析OpenRLHF中的PPO-Ray训练架构设计,没有使用过Ray的朋友也可以通过本文快速上手,本文共分成四块:
1. 为什么用Ray
2. 使用图例抽象出整体训练流程
3. Ray核心知识速过
4. 使用图例,进一步抽象出核心代码细节,包括:
- 训练入口
- 部署PPO-Actor/Ref/Critic/RM实例
- 部署vllm_engines实例
- PPO-Actor与vllm_engines之间的通讯
- PPO-Actor/Critic训练
一、为什么要使用Ray
对于通常的rlhf框架,在训练时会在单卡上同时部署actor/ref/reward/critic四类模型,这种单一的部署方式可能存在如下问题:
- 难以突破单卡显存的限制。
- 无法实现更多的并行计算。例如在收集exp阶段,拿到(prompt, responses)结果的四类模型其实可以做并行推理;在训练阶段,拿到exp的actor和critic也可以做并行训练。但受到单卡显存等因素影响,通常的rlhf框架中使用更多的是串行。
- 无法独立优化训练和推理过程。诸如vllm之类的框架,是可以用来提升actor生成(prompt, responses)的速度的,而对于其它模型,我们也可能会视算法需要有不同的推理需求。因此我们期望能更加灵活地设计训练、推理过程
而解决以上问题,需要开发者能设计一套较为灵活的分布式计算框架,能够实现资源定制化分配、分布式调度、节点内外通信等目标,同时相关的代码不能太复杂,能够让使用者更专注于算法部分的研发。而Ray天然可以帮我们做这件事:我们只需提供自己的资源分配方案,告诉Ray我想怎么部署这些模型,不管是分开还是独立部署Ray都可以帮我们实现。而复杂的调度策略和通信等事项,就由Ray在后台去做,我们无需关心这个过程。
二、整体流程
本节我们将提供2个例子,帮助大家更好理解使用Ray可以做什么样的“定制化”部署。注意,这些例子只做讲解用,不代表它们一定是训练的最优配置。
2.1 非共同部署
这个例子展示如何完全独立部署各个模型。假设我们有3台node,每台node 8张卡。以下展示其中一种可行的部署方式:
(1)部署4类模型
在这个例子中,4类模型分开部署在node0和node1上。以Actor为例,它分布在“node0的gpu0/1 + node1的gpu0/1”上。这一点是由Ray实现的:我们自己定制化资源分配的方案,进而管控模型的分配方式
而当实际训练时,我们还可进一步引入Deepspeed zero做优化:以Actor为例,上图中的4个Actor构成zero中的数据并行组(world_size = 4),根据zero的配置,我们可以在这4张卡间做optimizer/gradients/weights的切片。
(2)部署vllm_engines
前文说过,对于Actor模型,在收集exp阶段我们可以采用vllm之类的框架加速(prompt, responses)的生成。在这个例子中:
- 1个vllm_engine维护着一个vllm实例,每个vllm实例下维护一个完整的Actor模型,这里我们还假设一个vllm实例按tp_size = 2的方法切割模型。
- 在node2中,共有4个vllm_engines(也即4个vllm实例),这种分配方式是通过Ray实现的。而每个vllm实例内的分布式推理则是由vllm自己管控。
(3)Actor与vllm_engines之间的通讯
我们称:
- vllm_engines中的actor为vllm_actor
- node0/1中的actor为ds_actor
在整个训练过程中,vllm_actor需要和ds_actor保持权重一致。我们来看这个一致性是如何维护的:
1. 初始化阶段
假设pretrain路径下存储着sft模型,当我们首次开始训练时,ds_actor和vllm_actor都直接从pretrain中加载权重,两者互不影响,独立加载。
2. 训练中
在1个step结束后,ds_actor需要把更新后的权重broadcast给vllm_actor,具体步骤如下:
- 首先,对
ds_rank0 + all_vllm_ranks
创建一个通讯组。在本例中:
- node0/gpu0上的actor是ds_rank0
- node2中所有的gpu构成all_vllm_ranks。
- 我们就是把这两者纳入一个通讯组内,这个通讯组的world_size = 9。如果我们多一台node3来做vllm_engines,那么这个通讯组的world_size = 19,以此类推。
- 若我们使用ds_zero1/2,则ds_rank0上维护的是完整的actor权重,我们把ds_rank0上的权重broadcast到每一个vllm_rank,如有设置tp,vllm会自动帮我们完整接下来的模型切割。
- 若我们使用ds_zero3,则ds_rank0上只维护部分actor权重,那么:
- ds_rank0先从ds_actor组内all gather回完整的模型权重
- 再将完整的模型权重brocast给每一个vllm_rank
3. 从检查点恢复训练(load_checkpoint)
当我们需要从检查点恢复训练时,ds_actor会负责把检查点权重broadcast给vllm_actor,方式同2。
(4)整体运作流程
结合2.1开头的图例,我们来简述一下整体运作流程。
- 首先明确一些表达。例如,
node0中的Actor0/1 + node1中的Actor0/1
属于相同的数据并行组,所以接下来我们会用它们在dp组中的rank来描述它们,也就是分别改称Actor0/1/2/3。对于其余三类模型也是同理。 - 接着进行分组:
-
Actor0 / Ref0 / RM0 / Critic0 / vllm_engine0为一组
-
Actor1 / Ref1 / RM1 / Critic1 / vllm_engine1为一组
-
Actor2 / Ref2 / RM2 / Critic2 / vllm_engine2为一组
-
Actor3 / Ref3 / RM3 / Critic3 / vllm_engine3为一组
- 你可以把每一组想象成原来的一张单卡,那么它的作用就是负责一个micro_batch的训练,这样我们就能大致想象到它们之间是如何配合运作的了。需要注意的是,在我们的例子中,这些实例都是一一对应的(各自有4个实例),但在实际操作中,根据不同用户的资源配置,不一定存在这个一一对应的关系。例如你可能用4卡部署Actor,2卡部署Critic,8个vllm_engines...以此类推。不管怎样,我们应该尽量在处理micro_bathes的各个组间均匀分配负载,在代码里相关的操作如下:
1.为每个actor分配其对应的critic/reward/ref,并启动每个分组的训练:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/launcher.py#L278-L299 2.为每个actor分配对应的vllm_engine,并使用vllm_engine进行推理:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ppo_utils/experience_maker.py#L627
2.2 共同部署
同样,我们可以按照自己的需求,选择性地在单卡上部署不同种类的模型,例如下面的例子中,actor/ref共部署,critic/remote共部署,图例如下,运作流程和2.1相似,这里不赘述:
三、Ray的核心概念
在传统的编程中,我们经常使用到2个核心概念:function和class。而在分布式系统中,我们希望可以分布式并行执行这些function和class。Ray使用装饰器@ray.remote来将function包装成Ray task,将class包装成Ray actor,包装过后的结果可以在远程并行执行。接下来我们就来细看task/actor,请大家特别关注代码中的注释
(注意,这里的actor是ray中的概念,不是rlhf-ppo中actor模型的概念)
3.1 Ray Task
import ray
ray.init()
@ray.remote
def f(x):
return x * x
# ===================================================================
# 创建driver进程,运行main
# ===================================================================
if __name__ == "__main__":
# ===================================================================
# 创建4个worker进程,可以在远端并行执行。
# 每执行1次f.remote(i),会发生以下事情:
# - 创建1个worker进程,它将在远端执行函数f(i)
# - 在driver进程上立刻返回一个引用(feature),该引用指向f(i)远程计算的结果
# ===================================================================
futures = [f.remote(i) for i in range(4)]
# ===================================================================
# 阻塞/同步操作:等待4个worker进程全部计算完毕
# ===================================================================
results = ray.get(futures))
# ===================================================================
# 确保全部计算完毕后,在driver进程上print结果
# ===================================================================
print(f"The final result is: {results}") # [0, 1, 4, 9]
3.2 Ray Actor
import ray
ray.init()
@ray.remote
class Counter(object):
def __init__(self):
self.x = 0
def inc(self):
self.x += 1
def get_value(self):
return self.x
# ===================================================================
# 创建driver进程,运行main
# ===================================================================
if __name__ == "__main__":
# ===================================================================
# 创建1个worker进程,具体做了以下事情:
# - 在远端创建Counter实例
# - 在driver端即刻返回对该实例的引用c(称为actor handle)
# - 我们可以在Ray集群的任何结点上传递和使用这个actor handle。即在任何地方,
# 我们可以通过c来invoke对应Counter实例下的各种方法
# ===================================================================
c = Counter.remote()
# ===================================================================
# 阻塞/同步:通过c来invoke远端Counter实例的get_value()方法,并确保方法执行完毕。
# 执行完毕后才能接着执行driver进程上剩下的代码操作
# ===================================================================
print(ray.get(c.get_value.remote())) # 0
# ===================================================================
# Increment the counter twice and check the value again.
# 道理同上,不赘述
# ===================================================================
c.inc.remote()
c.inc.remote()
print(ray.get(c.get_value.remote())) # 2
3.3 Ray cluster架构简图
现在我们已经通过以上例子对Ray运作原理有了一些基本感知,我们来进一步探索一个ray cluster的组成:
- 在一个ray cluster中,会有一台head node和若干worker node
- Driver process是一种特殊的worker process,它一般负责执行top-level application(例如python中的
__main__
),它负责提交想要执行的任务,但却不负责实际执行它们。理论上driver process可以运行在任何一台node内,但默认创建在head node内。 - Worker process负责实际任务的执行(执行Ray Task或Ray Actor中的方法)。
- 每台node中还有一个Raylet process,它负责管控每台node的调度器和共享资源的分配。
- Head node中的GCS将会负责维护整个ray cluster的相关服务。
四、代码细节
本章将解读更多代码实践上的重要细节。我们通过图例的方式抽象出代码运行的过程,而具体代码可参考文中给出的相关链接
4.1 训练入口
ppo_ray相关的训练入口在:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/cli/train_ppo_ray.py。
在main中我们启动了driver进程,并执行训练函数train(args),这里主要做了如下几件事:
- 在ray集群上部署Actor/Ref/Critic/RM实例
- 在ray集群上部署vllm_engines实例
- 训练Actor和Critic模型
我们依次来解读这三个步骤。同时为了在表述上消除歧义,我们接下来谈到“Actor”时,会使用Ray-Actor和PPO-Actor来做区分,从之前的介绍中可知,Ray-Actor是指部署在Ray集群中的远端class,PPO-Actor/Ref/Critic/RM都属于Ray-Actor。
4.2 部署Actor/Ref/Critic/RM实例(1)非共同部署
针对图2.1的情况,我们以PPO-Actor为例,看代码是如何将其部署到Ray集群上的。
-
PPORayActorGroup
:创建在driver进程上,可将它理解成一种部署方案,专门负责部署PPO中的4类模型。 -
PPORayActorGroup
中维护着self._actor_handlers
,它是一个List[ray.actor.ActorHandle
],列表中每个元素表示某个远端Ray-Actor的引用,而这个远端Ray-Actor可以是PPO-Actor/Ref/Critic/RM实例。如前文所说,我们可以在ray集群中的任何位置调用这个handler,来对相应的远端Ray-Actor执行操作。 - 在本例中,我们创建了4个Ray-Actor(1个master-actor,3个worker_actor)。每个Ray-Actor都运行在一个worker进程中。在创建Ray-Actor的同时,我们也会去修改worker进程的环境变量。后续当我们在这些worker进程中启动ds_zero相关的分布式配置时,ds会读取这些环境变量信息,这样我们就知道哪些Ray-Actor同时由构成ds中的数据并行组。
- 使用
PPORayActorGroup
部署模型实例的代码如下:
model = PPORayActorGroup(
# 为部署该模型的全部实例,我们想用多少台node,例如本例中为2
args.actor_num_nodes,
# 为部署该模型的全部实例,我们每台node上想用多少gpu,例如本例中为2
args.actor_num_gpus_per_node,
# Actor/Critic/Reward/ReferenceRayActor
ActorModelRayActor,
# pg可理解为,在ray cluster中锁定/预留一片资源,然后只在这片资源上部署该模型全部实例。
# (pg维护在Head Node的GCS上,参见3.3)
# 例如本例中,pg锁定的资源为node0 gpu0/1, node1 gpu0/1,
# 我们只在上面部署ActorModelRayActor全部实例
pg=pg,
# 当我们在pg指向的预留资源中分配模型实例时,再进一步指定每个实例占据一张gpu的多少部分
# 等于1说明每个实例占满一张gpu,即“非共同部署”
# 小于1说明每个实例只占部分gpu,即“共同部署”,例如PPO-Actor/Ref共同部署在一张卡上
num_gpus_per_actor=0.75 if pg else 1,
)
-
ActorModelRayActor
:创建在远端worker进程上,是Ray-Actor。它包含了设置ds_zero分布式环境、加载模型权重、数据集准备、optimizer/scheduler准备、训练等一系列操作。
PPORayActorGroup代码参见:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/launcher.py#L143根据这份代码,大家可自行去找Actor/Critic/Reward/ReferenceRayActor的相关实现。
(2)共同部署
针对图2.2的情况,我们以PPO-Actor为例,看代码是如何将其部署到Ray集群上的。
-
PPORayActorGroup
:在driver进程上创建2个PPORayActorGroup,分别管理PPO-Actor,PPO-Ref的部署 - 使用
actor_model = PPORayActorGroup(..., pg = pg, num_gpus_per_actor=0.75)
创建PPO-Actor部署方案实例;使用ref_model = PPORayActorGroup(..., pg = pg, num_gpus_per_actor=0.25)
创建PPO-Ref部署方案实例 - 这里,两个方案实例使用的pg都是同一个,即这个pg都指向“1台node,每台node 8张卡”这片预留好的资源。
- num_gpus_per_actor = 0.75/0.25是一种创建trick,虽然我们的最终目的是为了让PPO-Actor和PPO-Ref对半分一张卡,但是:
- 假设设置为0.5,当我们实际部署ActorModelRayActor时,Ray先在单卡上部署1个ActorModelRayActor实例,当它准备部署第二个ActorModelRayActor实例时,它发现由于每个实例只占0.5块卡,因此完全可以把第二个实例接着第一个实例部署,这样就导致最终无法让PPO-Actor和PPO-Ref共享一张卡
- 假设设置0.75,当我们在单卡上部署完1个ActorModelRayActor实例后,ray发现单卡剩下的空间不足以部署第2个ActorModelRayActor实例,所以就会把第二个实例部署到别的卡上,这样最终实现PPO-Actor和PPO-Ref共享一张卡
- 所以,这个设置是为了达到不同类型模型的实例共享一张卡的目的,而并非真正指模型实际占据的单卡显存空间。
- 最后,在这一步中,我们对全部ActorModelRayActor共创建8个worker进程,对全部RefenreceModelRayActor共创建8个worker进程,一共创建16个工作进程。
相关代码依然在:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/launcher.py#L143
4.3 部署vllm_engines实例
-
create_vllm_engines
:在driver端,我们通过运行该函数来创建vllm_engines,过程相似于4.2节中的介绍,信息都在图中,这里不赘述。 -
LLMRayActor
:worker端Ray-Actor,它主要是把vllm实例进行了一些包装,包装的目的是为了让ds_rank0和all vllm ranks间可以进行PPO-Actor的权重通讯(参见2.1(3)) - 在上面的例子中,我们会创建4个worker进程,用于运行管理4个vllm_engine。在每个worker进程内,vllm实例还会创建属于自己的worker进程做分布式运行。
相关代码参见:
4.4 ds_rank0与vllm_ranks之间的通讯
在2.2中,我们说过,PPO-Actor的ds_rank0需要和all_vllm_ranks进行通讯,传递最新的PPO-Actor权重,例如以下ds_rank0要把完整的权重broadcast给16个vllm_ranks:
我们分成如下几步实现这个目标:
(1)创建通信组
如上图所示,创建通信组实际包含了2步。
Step1:
代码来自:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/ppo_actor.py#L58
这段代码执行在PPO-Actor0(ds_rank0)所在的worker进程中。这个worker进程将通过handler引用,触发远端每个vllm_engine上的init_process_group操作,并将ds_rank0纳入通讯组
# Create torch group with deepspeed rank 0 and all vllm ranks
# to update vllm engine's weights after each training stage.
#
# Say we have 3 vllm engines and eache of them has 4 GPUs,
# then the torch group is:
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# |ds rank 0 | engine-0 | engine-1 | engine-2 |
#
# For ZeRO-1/2:
# 1. Broadcast parameters from rank 0 to all vllm engines
# For ZeRO-3:
# 1. AllGather paramters to rank 0
# 2. Broadcast parameters from rank 0 to all vllm engines
if self.vllm_engines is not None and torch.distributed.get_rank() == 0:
...
# world_size = num_of_all_vllm_ranks + 1 ds_rank0
world_size = vllm_num_engines * vllm_tensor_parallel_size + 1
...
# =====================================================================
# 遍历每个vllm_engines,将其下的每个vllm_rank添加进通讯组中,这里又分成两步:
# 1. engine.init_process_group.remote(...):
# 首先,触发远程vllm_engine的init_process_group方法
# 2. 远程vllm_engine是一个包装过的vllm实例,它的init_process_group
# 方法将进一步触发这个vllm实例下的各个worker进程(见4.4图例),
# 最终是在这些worker进程上执行“将每个vllm_rank"添加进ds_rank0通讯组的工作
# =====================================================================
refs = [
engine.init_process_group.remote(
# ds_rank0所在node addr
master_address,
# ds_rank0所在node port
master_port,
# 该vllm_engine的第一个rank在"ds_rank0 + all_vllm_ranks“中的global_rank,
# 该值将作为一个offset,以该值为起点,可以推算出该vllm_engine中其余vllm_rank的global_rank
i * vllm_tensor_parallel_size + 1,
world_size,
"openrlhf",
backend=backend,
)
for i, engine in enumerate(self.vllm_engines)
]
# =====================================================================
# 将ds_rank0添加进通讯组中
# =====================================================================
self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=0,
group_name="openrlhf",
)
# =====================================================================
# 确保all_vllm_ranks都已添加进通讯组中
# =====================================================================
ray.get(refs)
Step2:
代码来自:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/vllm_worker_wrap.py#L11
这段代码实际运行在每个vllm_engine(即每个包装后的vllm实例)下的worker进程内。例如tp_size=2,那么每个vllm实例下就有2个worker进程,这两个worker进程都会运行这段代码。
class WorkerWrap(Worker):
def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"):
"""Init torch process group for model weights update"""
assert torch.distributed.is_initialized(), f"default torch process group must be initialized"
assert group_name != "", f"group name must not be empty"
# =====================================================================
# torch.distributed.get_rank(): 在当前vllm_engine内部的rank,
# 例如在tp_size = 2时,这个值要么是0,要么是1
# rank_offset:当前vllm_engine中的第一个rank在“ds_rank0 + all_vllm_ranks"中的global_rank
# 两者相加:最终得到当前rank在“ds_rank0 + all_vllm_ranks"中的global_rank
# =====================================================================
rank = torch.distributed.get_rank() + rank_offset
self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=rank,
group_name=group_name,
)
...
(2)_broadcast_to_vllm
构建好通讯组,我们就可以从ds_rank0广播PPO-Actor权重到all_vllm_ranks上了,这里也分成两步。
Step1:PPO-Actor ds_rank0发送权重
代码在:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/ppo_actor.py#L146
这段代码运行在ds_rank0对应的worker进程中
def _broadcast_to_vllm(self):
# avoid OOM
torch.cuda.empty_cache()
model = self.actor.model.module
count, num_params = 0, len(list(model.named_parameters()))
for name, param in model.named_parameters():
count += 1 # empty_cache at last param
# Fire all vllm engines for broadcast
if torch.distributed.get_rank() == 0:
shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
refs = [
# 远端vllm_engine的每个rank上,初始化一个尺寸为shape的empty weight张量,
# 用于接收广播而来的权重
engine.update_weight.remote(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
for engine in self.vllm_engines
]
# For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
# ds_rank0发出权重(视是否使用zero3决定在发出前是否要做all-gather)
with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
if torch.distributed.get_rank() == 0:
torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
ray.get(refs) # 确保所有vllm_ranks接收权重完毕
Step2: 各个vllm_ranks接收权重
代码在:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/vllm_worker_wrap.py#L29
代码运行在每个vllm_engine(即每个包装后的vllm实例)下的各个worker进程中。例如tp_size = 2,那么每个vllm实例下有2个worker进程,这2个worker进程都会运行这段代码。
def update_weight(self, name, dtype, shape, empty_cache=False):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
if torch.distributed.get_rank() == 0:
print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")
assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
# 创建同尺寸空张量用于接收ds_rank0广播来的权重
weight = torch.empty(shape, dtype=dtype, device="cuda")
# 接收权重
torch.distributed.broadcast(weight, 0, group=self._model_update_group)
# 使用接收到的权重进行更新
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
4.5 PPO-Actor/Critic Training
正如2.1(4)中所说,我们将部署在ray集群上的PPO-Actor/Ref/Critic/RM实例们进行分组,每组分别负责一份micro-batch的训练,上图刻画了某个组内的训练流程。一组内的训练流程发起自PPO-Actor实例(fit方法),共分成如下步骤执行。
Step1:发送prompts,并从vllm_engine上收集(prompt, response)。
代码参见:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ppo_utils/experience_maker.py#L627
Step2:从Ref/Reward/Critic上收集并处理exps。
代码参见:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ppo_utils/experience_maker.py#L492
Step3: 确保将处理后的exps传送给Critic,并行执行Actor和Critic的训练
- 将exps传送给Critic:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ppo_utils/experience_maker.py#L470
- Actor训练:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/ppo_actor.py#L125
- Critic训练:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/ppo_actor.py#L122
我们在Actor实例所在的worker进程上出发Actor和Critic的训练。以上代码只给出了训练入口,更多细节需要顺着入口去阅读。Step4:vllm_engine权重更新。
代码参见:https://github.com/OpenRLHF/OpenRLHF/blob/bb46342711a203c457df2fbca5967fd0549557e0/openrlhf/trainer/ray/ppo_actor.py#L130
五、参考
1、OpenRLHF:https://github.com/OpenRLHF/OpenRLHF
2、Ray official architecture whitepaper: https://docs.google.com/document/d/1tBw9A4j62ruI5omIJbMxly-la5w4q_TjyJgJL_jN2fI/preview?tab=t.0#heading=h.iyrm5j2gcdoq
(建议想看ray架构的朋友,直接看这个最新的官方白皮书,不要看2018年的那篇paper了,那个比较老了)
3、Ray official document:https://docs.ray.io/en/latest/index.html
4、推荐一篇快速了解Ray应用层核心概念的blog:https://towardsdatascience.com/modern-parallel-and-distributed-python-a-quick-tutorial-on-ray-99f8d70369b8
5、Ray:https://github.com/ray-project/ray
6、vllm: https://github.com/vllm-project/vllm
#Qwen2.5思维链微调代码实操 + 多卡Lora微调完整代码
最近对于Scaling Law的讨论异常火热。包括ilya大神自己都下场演讲关于大模型数据规模碰壁的问题(参考:机器之心官网发文)。直觉上,现在大模型思维的过程更像是人对一件事情直觉的反应,而不是多步思考和迭代思考的过程。正如下图ilya的PPT中的一张图,10层神经网络可以干人在0.1秒干的事情。而现在大模型上十亿的参数也可能只是解决人经过一分钟思考的回答。像OpenAI o1或者强化对齐可能是通往AGI的方法之一。刚好趁这个机会尝试一下一直没有进行的思维链微调。下面简单介绍一下思维链技术,并且使用阿里通义千问进行CoT数据微调并且简单测试一下。
网上关于思维链微调的实操比较少,甚至对于Qwen的指令微调高质量的文章都不多,许多细节都描述的不清楚,希望这篇文章能够进一步帮助到读者微调Qwen时能够关注到一些细节。
这里感谢魔乐社区赞助了华为昇腾910卡进行微调。尝试了下国产卡做微调的效果还是非常不错!本篇教程专门做了openMind Library的适配,兼容华为昇腾910卡。
友情链接:
- 魔乐社区
- Qwen2.5模型
- SwanLab训练跟踪工具
思维链技术介绍
思维链技术(Chain of Thought,也简称为CoT),最早由Json Wei等人在Chain-of-Thought Prompting Elicits Reasoning in Large Language Models文章提出。简单来说就是通过提示词让模型能够将一个复杂的问题分步思考。比如举个文章中提到的例子(见下图),一个数学问题是:
食堂有 23 个苹果。如果他们用掉了 20 个来做午餐,又买了 6 个,现在他们有多少个苹果?
对于一个人类,他的思考步骤是:
- 食堂有23个苹果,用了20个,所以是23-20=3
- 又买了6个,所以是3+6=9
- 共有9个苹果
当然这个思维过程还能猜的更碎。比如上面的过程中第一个实际上蕴涵了“因为食堂有23个苹果,3-20=3”两个步骤。对于进行了“指令微调”的模型来说,更倾向于简短的回答入,比如直接回答“他现在有XX个苹果”,而且对于一个需要多步计算的数学题往往是错误的。CoT技术的主要目标就是通过提示词让模型一步一步来,像上面的思考步骤那样要求模型不仅回答问题,同时还将问题的生成过程写出来。
Json Wei的这篇文章的工作是在提示词上做的(文中分了few-shot和zero-shot两种方式,简单来说就是给样例和不给样例),用学术些的话来说就是“上下文学习”。这篇文章的实验部分证明了CoT确实能有效提升LLM的推理能力,尤其是数学任务。当然很多人一下就想到了,我能否用微调的方式直接将这种“一步步思考”的能力直接微调到模型中呢?实际上Json Wei大神也很快想到了,所以在紧接着下一篇Scaling Instruction-Finetuned Language Models、Google的FLAN数据集改进版FLAN PaLM中直接引入了CoT数据集,这篇Json Wei虽然不是一作,但也是参与者之一。下图展示了FLAN PaLM所使用微调数据集,和对CoT部分的消融实验,很明显CoT使得模型能够具备更强的zero-shot能力。(建议读一下原文,CoT部分在第四章节,我个人觉得文章中的消融实验不足以直接证明CoT数据对于模型的性能提升,但是目前看来大家在构建finetune数据时还是倾向于将CoT数据构建进去,起码能够显著提升模型使用CoT提示词的能力)。
Qwen2.5思维链微调实操CoT数据集准备
本文使用的是早年收集的一个CoT英文数据集,从质量判断是利用GPT4+提示词生成的,因为时间太久了实在是忘了出处无法附上链接,在这里跟原作者抱歉了orz。现在网上已经有许多高质量的CoT数据集,包括一些数据集专门区分了思考中间过程与返回结果,以及一些将思维步骤使用Special Token进行分隔的。
本文考虑到仅仅是练手项目,选用了最基础的数据集,并且使用经过指令微调的Qwen2.5-7B将英文数据集进一步翻译成了中文+整理成了标准的Instruction-Input-Output样式,以方便对Qwen模型的微调和练手学习。注意这个数据集仅仅是练手,LLM翻译的质量并不高,很容易被要翻译的内容带偏了。下图展示了部分数据集和样式:
{
"instruction": "这家宠物店一共有多少只啮齿动物?宠物店里有6个笼子装着啮齿动物。其中3个笼子里各有10只仓鼠,另外3个笼子里各有5只豚鼠。那么这家宠物店一共养了多少只啮齿动物呢?",
"input": "",
"output": "让我们一步步来思考。\n\n有3个笼子,每个笼子里有10只豚鼠。\n这意味着总共有30只豚鼠。\n\n还有3个笼子,每个笼子里有5只豚兔。\n这意味着总共有15只豚兔。\n\n现在我们可以把30和15加起来得到答案:\n\n总共有45只啮齿动物。\n因此,答案(阿拉伯数字)是45。"
}
可以使用如下链接直接下载测试数据集
- 数据集下载链接
环境安装昇腾NPU + openMind Library环境安装
国产卡是未来,这里先放昇腾NPU环境安装!
使用昇腾NPU的话推荐在魔乐社区中找模型,里面能找到完成NPU适配的模型。魔乐社区使用的是openMind Library工具包,这个包支持在Nvidia GPU和Ascend NPU上运行,使用起来和transfomers接口一致。如果说做昇腾NPU迁移的话非常推荐使用。
魔乐社区的模型分为MindSpore支持和Pytorch-NPU支持,这里主要看本地装什么环境,考虑到新手学习的话推荐使用Pytorch-NPU,和Pytorch逻辑基本一致。
驱动安装&验证
首先得确定有NPU卡和NPU相关驱动,驱动是8.0.RC3.beta1,具体可以参考软件安装-CANN商用版8.0.RC3开发文档-昇腾社区。
安装好后的验证方法是运行下面的命令,该命令作用与nvidia-smi类似,这里是查看NPU的状态和性能
npu-smi info
可以看到如下信息的话就表示驱动已经安装完成了,左侧是安装成功后运行代码后的结果,右侧是每一部分的含义。
openMind环境搭建
openMind环境安装比较简单,这边列出所需用到的全部安装命令:
# 下载PyTorch安装包
wget https://download.pytorch.org/whl/cpu/torch-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
# 下载torch_npu插件包
wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc3-pytorch2.4.0/torch_npu-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
# 安装命令
pip3 install torch-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
pip3 install torch_npu-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
# 安装openMind Library
pip install openmind[pt]
pip install transformers accelerate datasets peft # 部分场景会用到hf几个包,干脆全装了
# 安装SwanLab
pip install swanlab
Nvidia GPU + Transformers环境安装
这个流程比较简答,首先也是得确保Nvidia驱动存在,验证命令:
nvida-smi
如果没显示同样需要先安装cuda环境,这里贴上CUDA官方安装链接
网上有大量cuda安装安装教程,这里笔者就不赘述了。同样放出transformers环境安装的全部命令:
pip install torch
pip install transformers accelerate datasets peft
# 安装SwanLab
pip install swanlab
关于提示词模版构建(大坑)
这里需要强调一下,在使用Qwen2.5的Instruct模型微调时,为了保障效果建议严格按照模型自身的Instruct的提示词模版构建。HF Transformers在4.3几的版本开始支持Chat Templates。Qwen2.5关于Instruct和Chat的提示词模版被直接写到了tokenziers的设置保存中,这导致了很多人在原始代码中找不到instruct提示词格式的构造。很多教程在教微调的时候还用的是Qwen1的老提示词模版或者自己构建的提示词模版,这会严重影响使用已经微调的模型做进一步微调时的效果。建议针对模型微调时一定要仔细检查提示词模版的实现部分。尽量使用模型已经定义好的格式和结构。
可以在Qwen的HF项目中找到提示词模版,点击HF Qwen查看chat_template设置。chat_template默认使用的是一种前端模版语言jinja,并不好看懂,笔者把qwen2.5的提示词模版格式化后粘贴在下文:
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
{%- endif %}
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\n' + message.content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
可以看到超级长,因为定义了好几种情况,包括是否有system prompt。以及针对function tools怎么处理等等等等。如果读不懂(我感觉大多数搞deep learning的除了做LLM Finetune也很小有机会去学一个前端语言)我建议用大模型给你逐行解释下,这里附上jinja的官方文档
这里笔者简单提供我所使用的Qwen2.5简化版python模版(下脚本),去除了Function Calling和多轮对话的部分。并且只包含对Instruct和Inputs的处理部分,以及Assitants的生成头。这分为带inputs的版本和不带inputs的版本。我自己专门测试了使用此模版构造的提示词长度上和使用Qwen带chat_template的tokenziers完全一致。你只需要将outputs部分增加一个\n<|im_end|>\n
即可直接拼接成finetune LLM模型的targets部分。
PROMPT_DICT = {
"prompt_no_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n""",
"prompt_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n""",
}
如果你直接偷懒使用chat_template来tokenizer仅带outputs部分的数据。你会发现由于Qwen的chat template处理机制,实际上生成的outputs部分会默认带上system prompts。导致最后训练阶段会出现奇怪的内容。Qwen的tokenizers针对未增加system角色的对话输入会自动加上如下提示词
system:You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
更神奇的是,这个system prompt居然是个英文的。Qwen可是个中文模型。。。这个system prompt的出现会影响后续的模型微调效果。
可视化工具配置(SwanLab使用教程)
SwanLab可以将微调的许多关键参数自动记录下来并且能够再现可视化查看训练。能够在线或者离线保存+查看训练日志。SwanLab(有可能是唯一的)同时支持记录NVIDIA GPU和华为昇腾NPU设备的日志记录工具。最新版本已经支持对NPU的内存使用、功率、温度等进行记录。甚至还有黑夜模式,方便苦逼研究生大晚上搞科研。:)
关于SwanLab的使用方法可以参考SwanLab官方文档-快速开始
对于Huggingface Transformers或者支持华为昇腾NPU的openMind Library,可以使用SwanLab Integration轻松完成实验数据记录:
...
from swanlab.integration.huggingface import SwanLabCallback
swanlab_call = SwanLabCallback( #
"Ascend_finetune_v2",
experiment_name=os.path.basename(os.path.normpath(training_args.output_dir)),
cnotallow=asdict(data_args)
| asdict(model_args)
| asdict(training_args)
| asdict(lora_config),
public=True,
)
trainer = openmind.Trainer( # 使用hf transformers的话则是把openmind替换为transformers
model=model,
tokenizer=tokenizer,
args=training_args,
callbacks=[swanlab_call], # callback加入进去即可
**data_module,
)
...
使用后不仅能进行多图表对比,更重要的是把一大堆的huggingface transformers的训练超参数全部记录下来了,简直调参党福音。
微调代码(多卡,支持华为Ascend卡)
下面附上完整的微调代码。在项目目录下创建finetune.py
文件,并将如下代码粘贴进文件中
import copy
import os
import io
import json
import logging
from dataclasses import dataclass, field, asdict
from typing import Dict, Optional, Sequence
import torch
from torch.utils.data import Dataset
try:
import openmind as tf_module
except:
import transformers as tf_module
import transformers
from peft import LoraConfig, get_peft_model
from swanlab.integration.huggingface import SwanLabCallback
IGNORE_INDEX = -100
PROMPT_DICT = {
"prompt_no_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n""",
"prompt_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n""",
}
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(
default="./weights/Qwen/Qwen2.5-7B-Instruct"
)
@dataclass
class DataArguments:
data_path: str = field(
default="./data/cot_train_cn.jsonl",
metadata={"help": "Path to the training data."},
)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
def _tokenize_fn(strings: Sequence[str], tokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncatinotallow=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def jload(f, mode="r", jsnotallow=True):
if not isinstance(f, io.IOBase):
with open(f, mode=mode, encoding="utf-8") as f:
if jsonl:
# Parse JSON Lines
return [json.loads(line) for line in f if line.strip()]
else:
# Parse standard JSON
return json.load(f)
else:
if jsonl:
return [json.loads(line) for line in f if line.strip()]
else:
return json.load(f)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer):
super(SupervisedDataset, self).__init__()
logging.warning("Loading data...")
list_data_dict = jload(data_path)
logging.warning("Formatting inputs...")
prompt_input, prompt_no_input = (
PROMPT_DICT["prompt_input"],
PROMPT_DICT["prompt_no_input"],
)
sources = [
(
prompt_input.format_map(example)
if example.get("input", "") != ""
else prompt_no_input.format_map(example)
)
for example in list_data_dict
]
targets = [
f"{example['output']}\n{tokenizer.eos_token}\n"
for example in list_data_dict
]
logging.warning("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer)
try:
self.input_ids = data_dict["input_ids"]
except KeyError as e:
raise KeyError("input_ids is invalid") from e
try:
self.labels = data_dict["labels"]
except KeyError as e:
raise KeyError("labels is invalid") from e
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: object
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def make_supervised_data_module(tokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = SupervisedDataset(
tokenizer=tokenizer, data_path=data_args.data_path
)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
)
def train():
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model = tf_module.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
trust_remote_code=True,
)
# 定义LoRA配置
lora_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.1,
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
tokenizer = tf_module.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
trust_remote_code=True,
)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
swanlab_call = SwanLabCallback(
"Ascend_finetune_v2",
experiment_name=os.path.basename(os.path.normpath(training_args.output_dir)),
cnotallow=asdict(data_args)
| asdict(model_args)
| asdict(training_args)
| asdict(lora_config),
public=True,
)
trainer = tf_module.Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
callbacks=[swanlab_call],
**data_module,
)
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
多卡训练的话可以使用torchrun,这里附上一个启动多卡的bash脚本,在当前目录下创建finetune.sh
,并且粘贴如下脚本:
NPU_NUM=${1:-8}
EXP_NAME=$(basename "$0" .sh)
if [ -d ./output ];then
rm -rf ./output/$EXP_NAME
mkdir -p ./output/$EXP_NAME
else
mkdir -p ./output/$EXP_NAME
fi
# master_port参数需用户根据实际情况进行配置
torchrun --nproc_per_node=$NPU_NUM --master_port=20248 finetune.py \
--model_name_or_path "./weights/Qwen/Qwen2.5-7B-Instruct" \
--data_path data/cot_train_cn.jsonl \
--bf16 True \
--output_dir ./output/$EXP_NAME \
--max_steps 2000 \
--per_device_train_batch_size 2 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 3000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--seed 42 \
--logging_steps 10
开启多卡训练的方式如下:
bash finetune.sh <使用的GPU/NPU数量>
如果提示登录swanlab,可以在官网完成注册后,使用获取API KEY找到对应的登陆密钥并粘贴,这样将能够使用云上看版随时查看训练过程与结果。
微调效果(附上Gradio代码)
本来准备了Ceval的测试结果,结果不知道为什么Ascend服务器连不上了,等过段时间更新下教程文档。
这里放出使用CoT数据微调qwen-7b-instruct、qwen-0.5b-instruct和使用qwen-7b-instruct(8NPU)的loss结果。可以看到使用8个NPU能带来更好的训练loss表现和稳定性,哪怕在使用同样迭代数据量的情况下,8个NPU依然能带来更好的loss结果。可能更大的loss有助于模型稳定下降。
最后展现下使用gradio完成的官方Qwen2.5-7B-Instruct、基于Qwen2.5-7B在中文alpaca数据集上指令微调、以及cot微调后的模型回复对比。可以看到CoT微调后模型确实具备了“step by step”的回复模式。
当然许多读者注意到了官方模型也展现出了“step by step”的回答模式,这主要是因为现在较新的模型在finetune数据集甚至pretrain数据集中就会预先加入CoT数据,所以模型在进行问答、尤其是数学题问答时,会展现出“步骤分解”的现象。笔者后续会尝试在较早期的demo中更新微调的
附上启用gradio的demo测试代码:
使用pip install gradi
o安装依赖包
import gradio as gr
from openmind import AutoModelForCausalLM, pipeline
from peft import PeftModel
TOTAL_GPU_NUMS = 8
TOKENIZE_PATH = "~/weightsweights/Qwen/Qwen2.5-7B-Instruct"
MODEL_LIST = {
"office_qwen7b": "~/weights/Qwen/Qwen2.5-7B-Instruct", # 官方模型
"alpaca_qwen7b_lora": "./projects/qwen_finietune_cot/output/qwen25-7B-alpaca", # 7b+alpaca
"cot_qwen7b_lora": "./projects/qwen_finietune_cot/output/qwen25-7Bi-cot", # cot微调
}
model_names = MODEL_LIST.keys()
pipes = dict()
for i, model_name in enumerate(model_names):
save_path = MODEL_LIST[model_name]
model = AutoModelForCausalLM.from_pretrained(save_path)
if model_name[:-5] == "_lora":
model = PeftModel.from_pretrained(model, save_path)
pipe = pipeline(
"text-generation",
model=model,
tokenizer=TOKENIZE_PATH,
framework="pt",
device=f"npu:{i%TOTAL_GPU_NUMS}",
)
pipes[model_name] = pipe
def generate_response(instruct_text, input_text):
messages = [
{
"role": "system",
"content": instruct_text,
},
{
"role": "user",
"content": input_text,
},
]
outputs = [
pipes[model_name](messages, max_new_tokens=256)[-1]["content"]
for model_name in model_names
]
return tuple(outputs)
# 创建 Gradio 界面
demo = gr.Interface(
fn=generate_response, # 函数名
inputs=[
gr.Textbox(label="instruction"),
gr.Textbox(label="input"),
], # 输入文本框
outputs=[gr.Textbox(label=model_name) for model_name in model_names],
)
if __name__ == "__main__":
demo.launch()