On Metric Learning for Audio-Text Cross-Modal Retrieval
原文摘要
任务定义:
- 音频-文本检索(audio-text retrieval)旨在给定一个模态(音频或文本)的查询时,从候选池中检索出另一个模态的目标内容(音频片段或文本描述)。
任务挑战:
- 该任务具有挑战性,因为需要同时满足两个要求:
- 学习两种模态的鲁棒特征表示
- 捕获两种模态之间的细粒度对齐关系
- 该任务具有挑战性,因为需要同时满足两个要求:
现有方法:
现有的跨模态检索模型主要通过度量学习目标进行优化
- 这些方法试图将数据映射到嵌入空间,使得相似数据距离近,不相似数据距离远
与其他跨模态检索任务相比,音频-文本检索仍是一个未被充分探索的领域
本文研究内容:
研究不同度量学习目标对音频-文本检索任务的影响
在AudioCaps和Clotho数据集上对流行的度量学习目标进行了广泛评估
主要发现:
从自监督学习领域改编的NT-Xent损失在不同数据集和训练设置下表现出稳定的性能
NT-Xent损失优于流行的基于三元组的损失函数
1. Introduction
研究任务定义与挑战
audio-text retrieval 任务定义:
输入:一个音频片段或文本描述作为查询。
目标:从另一个模态的候选池中检索出与之匹配的内容(即给定音频找文本,或给定文本找音频)。
挑战:
需要学习鲁棒的特征表示,分别处理音频和文本模态。
需要捕捉细粒度的跨模态交互,并在共享嵌入空间中对齐它们。
跨模态检索研究现状
图像-文本检索和视频-文本检索已经得到广泛研究并取得显著进展。
音频-文本检索关注较少,原因可能包括:
数据集的缺乏
研究主要聚焦于基于标签的音频检索,即查询是单词而非自然语言句子。
早期相关工作:
Chechik et al.:使用传统机器学习方法(SVM、GMM)进行标签音频检索。
Ikawa et al.:研究拟声词(onomatopoeic words)检索音频。
Elizalde et al.:使用 Siamese 网络 对齐音频和文本特征到联合嵌入空间。
局限性: 这些方法仅限于标签查询,而自由形式的自然语言查询 更符合人类习惯。
自由语言音频-文本检索的兴起
音频描述 的发展促进了相关数据集的发布(如 AudioCaps、Clotho),使得基于自由语言的音频-文本检索成为可能。
关键研究:Koepke et al.首次建立自由语言音频检索基准,借鉴视频检索方法,并利用预训练模型缓解数据稀缺问题。
任务特点:
音频和文本都是序列数据,因此比标签检索更具挑战性。
本文主要研究自由语言的音频-文本检索。
音频-文本检索模型架构
与其它跨模态检索模型类似,音频-文本检索模型通常由两个子网络组成:
音频编码器
文本编码器
目标:
将音频和文本映射到 联合嵌入空间,使得语义相似的样本距离近,不相似的样本距离远。
这种嵌入称为 声学语义嵌入(Acoustic Semantic Embeddings, ASE)
- 因为它们通过联合建模音频和语言模态学习得到。
训练目标:与度量学习的目标一致,即优化样本间的相似性度量。
度量学习&对比学习
尽管已有多种度量学习目标用于不同任务,但 尚无明确结论 表明哪种方法最适合音频-文本检索,因为:
某些方法在特定任务或数据上表现良好,但泛化性较差。
本文的目标是在固定训练设置下,比较不同度量学习目标对音频-文本检索的影响。
本文研究的度量学习方法:三元组损失(triplet loss)及其变体
Triplet-Sum(基于铰链的排序损失):
- 计算 mini-batch 内所有负样本的损失之和。
Triplet-Max(难样本挖掘):
- 仅关注 最难的负样本(hardest negative),避免简单负样本主导损失。
Triplet-Weighted(加权框架):
- 根据相似度分数动态加权样本对。
自监督学习中的对比损失:
- NT-Xent(Normalized Temperature-scaled Cross Entropy Loss):
- 基于 softmax,旨在识别 mini-batch 内的正样本对。
- NT-Xent(Normalized Temperature-scaled Cross Entropy Loss):
实验发现:
与普遍观点相反,难样本挖掘的三元组损失(Triplet-Max)对训练设置敏感,可能难以收敛。
NT-Xent 损失 在不同数据集和训练设置下表现稳定,且优于三元组损失。
2. Audio-Text Retrieval with Metric Learning
2.1 问题定义
给定一个音频-文本数据集 D={(ai,ti)}i=1ND = \{(a_i, t_i)\}_{i=1}^ND={(ai,ti)}i=1N,其中:
aia_iai 是音频片段
tit_iti 是配对的文本描述
(ai,ti)(a_i, t_i)(ai,ti) 是 正样本对
(ai,tj≠i)(a_i, t_{j \neq i})(ai,tj=i) 是 负样本对
模型架构:
音频编码器(audio encoder)fff:将音频映射到共享嵌入空间。
文本编码器(text encoder)ggg:将文本映射到共享嵌入空间。
相似度计算:
音频 aia_iai 和文本 tjt_jtj 的相似度 sijs_{ij}sij 使用 余弦相似度 计算:
sij=f(ai)⋅g(tj)∥f(ai)∥2∥g(tj)∥2 s_{ij} = \frac{f(a_i) \cdot g(t_j)}{\|f(a_i)\|_2 \|g(t_j)\|_2} sij=∥f(ai)∥2∥g(tj)∥2f(ai)⋅g(tj)训练目标:使正样本对的相似度 siis_{ii}sii 高于负样本对的相似度 sijs_{ij}sij。
2.2 模型架构
- 由于音频-文本数据稀缺,作者采用 预训练模型 进行迁移学习。
2.2.1 音频编码器
- 采用 PANNs(Pre-trained Audio Neural Networks),在 AudioSet 上预训练(音频分类任务)。
具体结构:
- 使用ResNet-38,舍弃最后两个线性层。
- 在最后一个卷积块输出的特征图上,沿频率维度应用平均池化和最大池化。
- 使用MLP将特征投影到共享嵌入空间:2 个线性层 + ReLU 激活函数。
结构总结:PANNs(ResNet-38 + 池化 + MLP)
2.2.2 文本编码器
采用 BERT(Bidirectional Encoder Representations from Transformers)。
- 具体结构:
- 在每个句子前添加
<CLS>
标记,并以其作为句子表示。 - 同样使用MLP将 BERT 输出投影到共享嵌入空间。
- 在每个句子前添加
- 结构总结:BERT + MLP
- 具体结构:
2.3 损失函数
在训练时,采样一个 mini-batch {ai,ti}i=1B\{a_i, t_i\}_{i=1}^B{ai,ti}i=1B,其中 BBB 是 batch size。
2.3.1 Triplet-Sum Loss
核心思想:最大化正样本对的相似度,同时最小化所有负样本对的相似度。
L=1B∑i=1B∑j≠i[m+sij−sii]++[m+sji−sii]+ \mathcal{L} = \frac{1}{B} \sum_{i=1}^B \sum_{j \neq i} [m + s_{ij} - s_{ii}]_+ + [m + s_{ji} - s_{ii}]_+ L=B1i=1∑Bj=i∑[m+sij−sii]++[m+sji−sii]+- [x]+=max(0,x)[x]_+ = \max(0, x)[x]+=max(0,x) 是 hinge loss。
- mmm 是 边界超参数。
- 第一项:音频到文本检索(audio-to-text)。
- 第二项:文本到音频检索(text-to-audio)。
- 如果正样本对的相似度比所有负样本对高至少 mmm,则损失为 0。
2.3.2 Triplet-Max Loss
核心思想:仅关注最难负样本,避免简单负样本主导损失。
L=1B∑i=1B(maxj≠i[m+sij−sii]++maxj≠i[m+sji−sii]+) \mathcal{L} = \frac{1}{B} \sum_{i=1}^B \left( \max_{j \neq i} [m + s_{ij} - s_{ii}]_+ + \max_{j \neq i} [m + s_{ji} - s_{ii}]_+ \right) L=B1i=1∑B(j=imax[m+sij−sii]++j=imax[m+sji−sii]+)- 仅对最难负样本(相似度最高的负样本)计算损失。
2.3.3 Triplet-Weighted Loss
核心思想:根据相似度动态加权正负样本对。
权重函数:
- 正样本对权重 GposG_{pos}Gpos:Gpos=apsiip+ap−1siip−1+⋯+a1sii+a0G_{pos} = a_p s_{ii}^p + a_{p-1} s_{ii}^{p-1} + \cdots + a_1 s_{ii} + a_0Gpos=apsiip+ap−1siip−1+⋯+a1sii+a0
- 负样本对权重 GnegG_{neg}Gneg:Gneg=bqsijq+bq−1sijq−1+⋯+b1sij+b0G_{neg} = b_q s_{ij}^q + b_{q-1} s_{ij}^{q-1} + \cdots + b_1 s_{ij} + b_0Gneg=bqsijq+bq−1sijq−1+⋯+b1sij+b0
- 正样本相似度 siis_{ii}sii 越高,权重越低;负样本相似度 sijs_{ij}sij 越高,权重越高。
损失函数(最大多项式损失):
L=1B∑i=1B([∑p=0Papsiip+∑q=0Qbqmax{Naiq}]++[∑p=0Papsiip+∑q=0Qbqmax{Ntiq}]+) \mathcal{L} = \frac{1}{B} \sum_{i=1}^B \left( \left[ \sum_{p=0}^P a_p s_{ii}^p + \sum_{q=0}^Q b_q \max \{N_{a_i}^q\} \right]_+ + \left[ \sum_{p=0}^P a_p s_{ii}^p + \sum_{q=0}^Q b_q \max \{N_{t_i}^q\} \right]_+ \right) L=B1i=1∑B [p=0∑Papsiip+q=0∑Qbqmax{Naiq}]++[p=0∑Papsiip+q=0∑Qbqmax{Ntiq}]+ - Na={sij∣j≠i}N_a = \{s_{ij} | j \neq i\}Na={sij∣j=i}:音频 aia_iai 的所有负样本相似度。
- Nt={sji∣i≠j}N_t = \{s_{ji} | i \neq j\}Nt={sji∣i=j}:文本 tit_iti 的所有负样本相似度。
2.3.4 NT-Xent Loss
核心思想:基于 softmax 的对比损失,最大化正样本对的概率。
L=−1B(∑i=1Blogexp(sii/τ)∑j=1Bexp(sij/τ)+∑i=1Blogexp(sii/τ)∑j=1Bexp(sji/τ)) \mathcal{L} = -\frac{1}{B} \left( \sum_{i=1}^B \log \frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^B \exp(s_{ij}/\tau)} + \sum_{i=1}^B \log \frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^B \exp(s_{ji}/\tau)} \right) L=−B1(i=1∑Blog∑j=1Bexp(sij/τ)exp(sii/τ)+i=1∑Blog∑j=1Bexp(sji/τ)exp(sii/τ))- τ\tauτ 是温度超参数。
- 第一项:音频到文本检索。
- 第二项:文本到音频检索。
3. Experiments
3.1 数据集
AudioCaps
- 规模:约 50k 音频片段(10 秒/段),来源自 AudioSet [27]。
- 数据划分:
- 训练集:49,274 段音频,每段 1 条描述。
- 验证集/测试集:494/957 段音频,每段 5 条描述。
Clotho v2
- 特点:音频长度 15-30 秒(均匀分布),来自 Freesound 存档。
- 数据划分:
- 训练集:3,839 段音频。
- 验证集/测试集:各 1,045 段音频。
- 标注:每段音频 5 条描述,长度 8-20 词。
3.2 实验设置
音频特征:
- 提取 对数梅尔频谱图(log mel-spectrograms),参数:
- 汉宁窗(1024 点),步长 320 点,64 个梅尔频带。
- 提取 对数梅尔频谱图(log mel-spectrograms),参数:
训练配置:
- 优化器:Adam ,初始学习率 1×10−41 \times 10^{-4}1×10−4 或 5×10−55 \times 10^{-5}5×10−5,每 20 轮衰减至 1/10。
- 训练轮次:50 epochs。
- 批量大小:AudioCaps(32),Clotho(24)。
- 预训练模型处理:对比 冻结 和 微调 的效果。
- 模型选择:根据验证集 Recall 总和选择最佳模型。
损失函数超参数:
- Triplet 系列损失:边界 $m = 0.24。
- NT-Xent 损失:温度参数 τ=0.07\tau = 0.07τ=0.07。
- Triplet-Weighted 损失:
- 正样本权重:P=2P=2P=2,系数 {a0=0.5,a1=−0.7,a2=0.2}\{a_0=0.5, a_1=-0.7, a_2=0.2\}{a0=0.5,a1=−0.7,a2=0.2}。
- 负样本权重:Q=2Q=2Q=2,系数 {b0=0.03,b1=−0.4,b2=0.9}\{b_0=0.03, b_1=-0.4, b_2=0.9\}{b0=0.03,b1=−0.4,b2=0.9}。
嵌入空间:
- 维度 1024,嵌入向量归一化。
- 硬件:单卡 RTX3090 GPU。
3.3 评估协议
评价指标:
Recall@k(R@k):衡量在前 kkk 个检索结果中找到目标的比例,报告 R@1、R@5、R@10。
稳定性测试:重复 3 次实验(不同随机种子),汇报均值与标准差。