论文阅读笔记:Dataset Condensation with Gradient Matching
核心思想一句话总结:
本文提出了一种创新的数据集压缩方法——数据集凝缩(Dataset Condensation,DC),其核心思想是通过梯度匹配(Gradient Matching),将一个大型数据集
T
浓缩成一个极小的、信息量丰富的合成数据集S
。在S
上从头训练的模型,其性能可以逼近在T
上训练的模型,从而极大地节省了存储和训练成本。
1. 解决了什么问题?(Motivation)
- 问题:现代深度学习依赖于大规模数据集,导致存储成本、数据传输宽带和模型训练时间急剧增加。
- 目标:创建一个微型合成数据集
S
,它能作为原始大型数据集T
的高效替代品,用于从零开始训练神经网络
2. 关键方法与创新点 (Key Method & Innovation)
2.1 核心思路的演进:从参数匹配到梯度匹配
- 参数匹配 (Parameter Matching) - 一个被否定的思路
- 想法:直接让
S
训练收敛后的模型参数 θ S \theta_S θS与用T
训练收敛后的 θ T \theta_T θT尽可能接近。 - 缺陷:
- 优化路径复杂:深度网络的参数空间非凸,直接走向目标 θ T \theta_T θT极易陷入局部最优。
- 计算成本高:需要嵌套的双层优化,内循环必须将模型训练至收敛,计算上不可行。
- 想法:直接让
- 梯度匹配 (Gradient Matching) - 本文的核心创新
- 想法:放弃匹配静态的”终点“,转而匹配动态的”过程“。即,确保在训练每一步,模型在合成数据
S
上产生的梯度 ∇ L s ∇L_s ∇Ls在真实数据T
上产生的梯度 ∇ L T ∇L_T ∇LT方向一致。 - 优势:
- 计算高效:通过一个巧妙的近似,极大提高了效率和可扩展性。
- 优化路径清晰:每一步都有明确的监督信号(梯度差异),引导
S
的优化,避免了在复杂空间中盲目搜索。 - 对齐学习动态:保证了模型在
S
上的学习方式与T
上一致,结果更鲁棒。
- 想法:放弃匹配静态的”终点“,转而匹配动态的”过程“。即,确保在训练每一步,模型在合成数据
2.2 算法实现细节 (Implementation Details)
- 课程学习 (Curriculum Learning)
- 为了让合成数据
S
具有泛化性,算法采用了一个”课程学习“的框架。在整个凝缩过程中,会周期性地重新随机初始化网络参数 θ \theta θ。 - 这确保了
S
不会过拟合到某一个特定的网络初始化,而是对多种随机起点都有效。
- 为了让合成数据
- 梯度匹配损失函数(Gradient Matching Loss)
- 使用**余弦距离(1-Cosine Similarity)**来衡量两个梯度的差异。这更关注梯度的方向而非大小,与梯度下降的本质契合。
- 按输出节点分组计算:并非所有层的梯度粗暴地展平,而是按输出神经元分组计算余弦距离,更好地保留了网络结构信息。
- 重要的工程技巧(Practical Tricks)
- BatchNorm层预热与冻结:由于合成数据批次极小,为了避免BN层统计量不稳定,每次迭代前都先用一个较大的真实数据批次来计算并”冻结“BN层的均值和方差。
- 按类别独立匹配:在计算梯度时,按类别独立进行,即用”猫“的合成数据区匹配”猫“的真实数据梯度。这降低了学习难度和内存消耗。
3. 实验结果与贡献 (Experiments & Contributions)
- 性能优越:在CIFAR-10, CIFAR-100, SVHN等数据集上,仅用极少量合成样本(如IPC=1或10),就能训练出性能远超当时其他数据压缩方法的模型。
- 开创性贡献:
- 首次提出了梯度匹配这一高效且可扩展的数据集凝缩范式,为后续大量的研究(如DSA, MTT, FTD等)奠定了基础。
- 成功将数据集凝缩技术应用到了大型网络上,证明了其可行性。
- 展示了其在持续学习和神经架构搜索 (NAS) 等资源受限场景下的巨大潜力。
4.个人思考与启发
- ”过程“比”结果”更重要:这篇论文最精妙的哲学在于,它揭示了在复杂优化问题中,对齐“过程”(梯度)比直接追求“结果”(参数)更有效、更可行。这一思想在很多其他领域也具有启发性。
- 理论与实践的结合:论文不仅提出了一个优雅的理论框架,还通过BN层处理等工程技巧解决了实际应用中的痛点。
主要代码
''' training '''
# 为合成图像image_syn创建一个优化器
# 我们只优化image_syn这个张量,所有优化器只传入它。
# 这里的优化器是SGD,意味着我们会用梯度下降法来更新图像的像素值。
optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
# 清空优化器的梯度缓存
optimizer_img.zero_grad()
# 定义用于计算分类损失的损失函数,这里是标准的交叉熵损失。
criterion = nn.CrossEntropyLoss().to(args.device)
print('%s training begins'%get_time())
# 主迭代循环开始
# 这个循环是整个数据集凝缩过程的核心,总共进行Iteration+1次。
for it in range(args.Iteration+1):
# 评估合成数据(在特定迭代点触发)
''' Evaluate synthetic data '''
if it in eval_it_pool:
for model_eval in model_eval_pool:
# 遍历model_eval_pool中的每一个模型架构,用于评估。
# 这运行我们测试合成数据集在不同模型上的泛化能力。
print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
# 设置评估时的数据增强策略
if args.dsa:
# 如果是DSA方法,使用其特定的增强策略。
args.epoch_eval_train = 1000
args.dc_aug_param = None
print('DSA augmentation strategy: \n', args.dsa_strategy)
print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
else:
# 如果是DC方法,调用 get_daparam 获取专为DC设计的增强参数。
# 注意:这些增强只在评估时使用,在生成合成数据时不用。
args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.
print('DC augmentation parameters: \n', args.dc_aug_param)
# 如果在评估时使用了任何数据增强,就需要更多的训练轮数来让模型充分学习。
if args.dsa or args.dc_aug_param['strategy'] != 'none':
args.epoch_eval_train = 1000 # Training with data augmentation needs more epochs.
else:
args.epoch_eval_train = 300
# --- 3.2 执行评估 ---
# 创建一个空列表,用于存储多次评估的准确率
accs = []
# 为了结果的稳定性,我们会用当前的合成数据训练num_eval个独立,随机初始化的模型。
for it_eval in range(args.num_eval):
# 每一次都创建一个全新的、随机初始化的评估网络。
net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
# 深拷贝当前的合成数据和标签,以防止在评估函数中被意外修改。
# detach()是为了确保我们只复制数据,不带计算图。
image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
# 调用核心评估函数 evaluate_synset。
# 这个函数会:
# 1. 拿 image_syn_eval 从头开始训练 net_eval。
# 2. 在训练结束后,用训练好的 net_eval 在真实的测试集 testloader 上进行测试。
# 3. 返回在测试集上的准确率 acc_test。
_, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
# 打印这次评估的平均准确率和标准差。
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))
# 如果这是最后一次迭代,将这次评估的所有准确率结果记录到总的实验结果字典中。
if it == args.Iteration: # record the final results
accs_all_exps[model_eval] += accs
# 可视化并保存合成图像
''' visualize and save '''
save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
# 深拷贝合成图像,并移到CPU上进行处理。
image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
# 对图像进行反归一化,以便人眼观察
# 训练时图像通常是归一化的。
# 反归一化公式:pixel = pixel * std + mean
for ch in range(channel):
image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
# 将像素值裁剪到[0,1]范围内,防止因浮点数误差导致显示异常。
image_syn_vis[image_syn_vis<0] = 0.0
image_syn_vis[image_syn_vis>1] = 1.0
# 使用torchvision.utils.save_image将合成图像保存为一张网格图。
# nrow=args.ipc表示每行显示ipc张图像。
save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.
# --- 初始化课程学习环境 ---
''' Train synthetic data '''
# 每次主迭代(it)开始,都创建一个全新的、随机初始的网络。
# 这是”课程学习“的关键:确保合成数据对不同的网络初始化方法都有效,而不是过拟合到某一个。
net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
net.train() # 将网络设置为训练模式
# 获取网络的所有可学习参数
net_parameters = list(net.parameters())
# 为这个新网络创建一个优化器,用于在内循环中更新网络参数
optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer_img for synthetic data
optimizer_net.zero_grad()
# 初始化平均损失,用于记录和打印
loss_avg = 0
# 在生成合成数据时,不使用任何数据增强,以与DC论文的设置保持一致
args.dc_aug_param = None # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in oder to be consistent with DC paper.
# --- 课程学习外循环(Outer Loop) ---
# 这个循环对应论文算法中的外循环,用于实现课程学习。
for ol in range(args.outer_loop):
# -- BatchNorm层预热与冻结(一个非常重要的工程技巧) --
''' freeze the running mu and sigma for BatchNorm layers '''
# Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
# So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
# This would make the training with BatchNorm layers easier.
# 动机:合成数据的批次非常小(例如ipc=1),如果让BN层在这么小的批次上计算均值和方差,结果会极其不稳定,导致训练困难。
# 解决方案:先用一个包含多个真实样本的”大“批次来预热BN层,计算出稳定的统计量,然后将其冻结。
BN_flag = False
BNSizePC = 16 # for batch normalization 每个类别用于BN预热的样本数
# 检查网络中是否存在BN层
for module in net.modules():
if 'BatchNorm' in module._get_name(): #BatchNorm
BN_flag = True
if BN_flag:
# 从每个类别中抽取BNSizePC个真实图像,拼接成一个大批次。
img_real = torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim=0)
# 确保网络在训练模式,以便BN层可以更新其 running_mean 和 running_var。
net.train() # for updating the mu, sigma of BatchNorm
# 进行一次前向传播,这个操作会自动更新BN层的统计量。
output_real = net(img_real) # get running mu, sigma
# 将所有BN层切换到评估模式。
# 在评估模式下,BN层会使用已经计算好的 running_mean 和 running_var,而不会再根据新的输入来更新它们。
# 这就实现了“冻结”的效果。
for module in net.modules():
if 'BatchNorm' in module._get_name(): #BatchNorm
module.eval() # fix mu and sigma of every BatchNorm layer
# --- 核心:通过梯度匹配更新合成数据 ---
''' update synthetic data '''
# 初始化当前外循环的总损失
loss = torch.tensor(0.0).to(args.device)
# 按照类别独立进行梯度匹配,这个是论文提出的另外一个技巧。
for c in range(num_classes):
# 准备真实数据和合成数据
img_real = get_images(c, args.batch_real)
lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c
# 如果使用DSA方法,对真实和合成图像应用相同的可微数据增强
if args.dsa:
seed = int(time.time() * 1000) % 100000
img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)
# --- 计算真实梯度 gw_real ---
output_real = net(img_real)
loss_real = criterion(output_real, lab_real)
# 计算损失对网络参数的梯度
gw_real = torch.autograd.grad(loss_real, net_parameters)
# clone()和detach()是为了将梯度值复制下来,并切断其与计算图的联系,
# 因为我们只需要它的数值作为匹配目标,不希望梯度回流真实数据。
gw_real = list((_.detach().clone() for _ in gw_real))
# -- 计算合成梯度gw_syn --
output_syn = net(img_syn)
loss_syn = criterion(output_syn, lab_syn)
# 关键所在:create_graph=True
# 这个参数告诉pytorch,在计算gw_syn时,要保留其计算图。
# 这意味着gw_syn本身也成为了一个计算图中的节点,它依赖于iamge_syn.
# 因此,后续对gw_syn的损失进行反向传播时,梯度可以一直流回image_syn。
gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)
# 计算真实梯度和合成梯度之间的匹配损失,余弦相似度
loss += match_loss(gw_syn, gw_real, args)
# 更新合成图像
optimizer_img.zero_grad() # 清空image_syn的梯度缓存
loss.backward() # 反向传播,计算匹配损失对image_syn对image_syn的梯度
optimizer_img.step() # 根据梯度更新image_syn的像素值
loss_avg += loss.item() # 累加损失用于打印
# 如果是最后一个外循环,就不需要再更新网络了,直接跳出。
if ol == args.outer_loop - 1:
break
# --- 2.3 内循环:用更新后的合成数据训练网络 ---
''' update network '''
# 第二步:现在轮到网络来适应更新后的合成数据了。
image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
# 对网络进行inner_loop次的训练更新。
for il in range(args.inner_loop):
epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)
# 记录和保存
# 计算并打印平均损失
loss_avg /= (num_classes*args.outer_loop)
if it%10 == 0:
print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))
# 如果是最后一次主迭代,保存所有结果
if it == args.Iteration: # only record the final results
data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))
算法逻辑总结
“你追我赶”的双重优化过程:
- 课程学习 (Outer Loop):
- 每一次外循环,都像是新学期开学,我们找来一个“新生”(一个随机初始化的
net
)。 - 这个“新生”的存在,是为了确保我们的“教材”(合成数据
image_syn
)是普适的,对任何基础的学生都有效。
- 每一次外循环,都像是新学期开学,我们找来一个“新生”(一个随机初始化的
- 教材编写 (Update Synthetic Data):
- 这是核心步骤。我们让“新生”
net
分别看“官方教材”(真实数据img_real
)和我们正在编写的“浓缩笔记”(合成数据img_syn
)。 - 我们记录下“新生”看完两种材料后的“学习心得”(梯度
gw_real
和gw_syn
)。 - 我们的目标是修改“浓缩笔记”
img_syn
,使得“新生”看完它之后产生的“学习心得”gw_syn
和看完“官方教材”产生的gw_real
一模一样。 create_graph=True
是实现这一点的技术关键,它允许我们对“学习心得”本身求导,从而知道该如何修改“浓y缩笔记”的每一个字(像素)。
- 这是核心步骤。我们让“新生”
- 学生自习 (Update Network):
- “浓缩笔记”
image_syn
更新完毕后,我们让“新生”net
对着这本新版的笔记自习几遍(inner_loop
次)。 - 这会让“新生”对当前的“浓缩笔记”有更深的理解,为下一轮的“教材编写”做好准备。
- “浓缩笔记”
这个“编写教材 -> 学生自习 -> 换个新生再来一遍”的过程不断重复,最终使得“浓缩笔记” image_syn
变得越来越精华,能够高效地替代“官方教材” T
。