从代码学习深度学习 - 自然语言推断:微调BERT PyTorch版

发布于:2025-07-16 ⋅ 阅读:(23) ⋅ 点赞:(0)


前言

自然语言推断(NLI)是自然语言处理(NLP)领域一个核心且富有挑战性的任务。它的目标是判断两个句子——“前提(Premise)”和“假设(Hypothesis)”——之间的逻辑关系。这种关系通常分为三类:

  1. 蕴含(Entailment): 假设的意义可以从前提中推断出来。
  2. 矛盾(Contradiction): 假设的意义与前提相矛盾。
  3. 中性(Neutral): 前提和假设之间没有明确的逻辑关系。

例如:

  • 前提: 一个人在马上。
  • 假设: 一个人在动物身上。
  • 关系: 蕴含

近年来,以BERT(Bidirectional Encoder Representations from Transformers)为代表的预训练语言模型在众多NLP任务中取得了革命性的突破。其强大的上下文理解能力,使其成为解决NLI等任务的理想选择。

本篇博客将带领大家,通过PyTorch代码,一步步实现如何“微调(Fine-tuning)”一个预训练好的BERT模型,使其适应并高效地完成自然语言推断任务。我们将使用经典的SNLI(Stanford Natural Language Inference)数据集,并详细剖析从数据加载、模型构建到最终训练的全过程。
在这里插入图片描述

完整代码:[通过网盘分享的文件:自然语言推断:微调BERT.rar
链接: https://pan.baidu.com/s/1OxS-BU0MSOJXXB5wJA394w?pwd=8rc6 提取码: 8rc6
–来自百度网盘超级会员v6的分享]


加载预训练的BERT

微调的第一步,是加载一个已经在海量文本数据上(如维基百科)预训练好的BERT模型。这个预训练过程让BERT学会了通用的语言知识,我们要做的是在这个基础上,针对我们的特定任务(NLI)进行“微调”。

我们定义一个函数 load_pretrained_model 来加载模型及其对应的词汇表。词汇表(Vocabulary)是词元(token)到索引(index)的映射,是模型处理文本的基础。

import json
import os
import torch
import utils_for_vocab
import utils_for_model
import utils_for_train


def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                         num_heads, num_layers, dropout, max_len, devices):
    """
    加载预训练的BERT模型和词汇表
    
    参数:
        pretrained_model (str): 预训练模型名称,用于构建数据目录路径
        num_hiddens (int): 隐藏层维度 [256]
        ffn_num_hiddens (int): 前馈网络隐藏层维度 [512]
        num_heads (int): 多头注意力机制的头数 [4]
        num_layers (int): Transformer层数 [2]
        dropout (float): dropout比例 [0.1]
        max_len (int): 最大序列长度 [512]
        devices (list): 可用的GPU设备列表
    
    返回:
        bert (BERTModel): 加载了预训练参数的BERT模型
        vocab (Vocab): 词汇表对象
    """
    # 构建数据目录路径
    data_dir = pretrained_model + ".torch"
    
    # 定义空词表以加载预定义词表
    vocab = utils_for_vocab.Vocab()
    
    # 从JSON文件加载词汇表的索引到词汇的映射
    # vocab.idx_to_token: list,维度为 [vocab_size],存储索引到词汇的映射
    vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
    
    # 构建词汇到索引的映射字典
    # vocab.token_to_idx: dict,存储词汇到索引的映射
    vocab.token_to_idx = {
   token: idx for idx, token in enumerate(vocab.idx_to_token)}
    
    # 创建BERT模型实例
    # bert: BERTModel对象,包含编码器和预训练任务头
    bert = utils_for_model.BERTModel(
        len(vocab),                    # vocab_size: 词汇表大小 [vocab_size]
        num_hiddens,                   # num_hiddens: 隐藏层维度 [256]
        norm_shape=[256],              # norm_shape: 层归一化的形状 [256]
        ffn_num_input=256,             # ffn_num_input: 前馈网络输入维度 [256]
        ffn_num_hiddens=ffn_num_hiddens,  # ffn_num_hiddens: 前馈网络隐藏层维度 [512]
        num_heads=4,                   # num_heads: 多头注意力头数 [4]
        num_layers=2,                  # num_layers: Transformer层数 [2]
        dropout=0.2,                   # dropout: dropout比例 [0.2]
        max_len=max_len,               # max_len: 最大序列长度 [512]
        key_size=256,                  # key_size: 注意力机制中key的维度 [256]
        query_size=256,                # query_size: 注意力机制中query的维度 [256]
        value_size=256,                # value_size: 注意力机制中value的维度 [256]
        hid_in_features=256,           # hid_in_features: 隐藏层输入特征维度 [256]
        mlm_in_features=256,           # mlm_in_features: 掩码语言模型输入特征维度 [256]
        nsp_in_features=256            # nsp_in_features: 下一句预测任务输入特征维度 [256]
    )
    
    # 加载预训练的BERT模型参数
    # torch.load返回的是state_dict,包含模型的所有参数
    bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))
    
    return bert, vocab


# 获取所有可用的GPU设备
# devices: list,包含可用GPU设备的列表
devices = utils_for_train.try_all_gpus()

# 加载预训练的BERT模型和词汇表
# bert: BERTModel对象,已加载预训练参数
# vocab: Vocab对象,包含词汇表映射
bert, vocab = load_pretrained_model(
    'bert.small',                # pretrained_model: 预训练模型名称
    num_hiddens=256,             # num_hiddens: 隐藏层维度 [256]
    ffn_num_hiddens=512,         # ffn_num_hiddens: 前馈网络隐藏层维度 [512]
    num_heads=4,                 # num_heads: 多头注意力头数 [4]
    num_layers=2,                # num_layers: Transformer层数 [2]
    dropout=0.1,                 # dropout: dropout比例 [0.1]
    max_len=512,                 # max_len: 最大序列长度 [512]
    devices=devices              # devices: 可用GPU设备列表
)

这里我们加载了一个小型的BERT模型(bert.small),它包含2个Transformer层,隐藏层维度为256。加载完成后,我们可以打印bert对象,查看其详细的模型结构。

bert
```输出的模型结构会非常详细,它清晰地展示了BERT的内部组件,包括词元嵌入(`token_embedding`)、片段嵌入(`segment_embedding`)、由多个编码器块(`EncoderBlock`)组成的编码器(`encoder`),以及用于预训练的MLM(`MaskLM`)和NSP(`NextSentencePred`)任务头。在微调阶段,我们主要关心的是`encoder`部分。

## 微调BERT的数据集

数据是模型训练的“养料”。对于NLI任务,我们需要将SNLI数据集处理成BERT能够理解的格式。

### 数据读取与预处理

首先,我们需要一个函数来读取SNLI数据集的原始文本文件。该文件是制表符分隔的,我们需要从中抽取出前提、假设和标签。

```python
# 该函数位于 utils_for_data.py
def read_snli(data_dir, is_train):
    """
    将SNLI数据集解析为前提、假设和标签
    """
    # ... (代码见附录)

BERT处理成对的句子(如前提和假设)时,需要一种特殊的输入格式。两个句子被拼接在一起,并用特殊标记隔开:
[CLS] 前提词元 [SEP] 假设词元 [SEP]

  • [CLS]:位于序列开头,它的最终隐藏状态被用作整个序列的聚合表示,通常用于分类任务。
  • [SEP]:用于分隔两个句子。

我们还需要一个“片段索引(Segment ID)”,用来区分哪个词元属于前提(标记为0),哪个属于假设(标记为1)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
图2: BERT处理句子对(如NLI任务)的输入格式图示。

get_tokens_and_segments 函数负责实现这个格式转换。

# 该函数位于 utils_for_data.py
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """
    获取输入序列的词元及其片段索引
    """
    # ... (代码见附录)

构建PyTorch数据集

为了与PyTorch的 DataLoader 高效配合,我们创建一个自定义的Dataset类——SNLIBERTDataset。这个类封装了所有的数据预处理逻辑:

  1. 词元化(Tokenization): 将句子切分成词元。
  2. 格式化: 调用get_tokens_and_segments构建BERT输入格式。
  3. 截断(Truncation): 由于BERT输入有最大长度限制(如128或512),需要将过长的句子对进行截断。
  4. 填充(Padding): 将所有序列填充到相同的最大长度,以便进行批量处理。
  5. 数值化: 将词元转换为词汇表中的索引。

这个类还巧妙地使用了Python的multiprocessing库来并行处理数据,极大地加速了预处理过程。

import torch
import multiprocessing
import utils_for_data
import utils_for_vocab


class SNLIBERTDataset(torch.utils.data.Dataset):
    """
    用于BERT模型的SNLI数据集处理类
    
    该类继承自torch.utils.data.Dataset,用于处理Stanford Natural Language Inference (SNLI)
    数据集,将其转换为适合BERT模型训练的格式。
    """
    
    def __init__(self, dataset, max_len, vocab=None):
        """
        初始化SNLI BERT数据集
        """
        # 对前提和假设句子进行词元化处理
        all_premise_hypothesis_tokens = [[\
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(\
            *[utils_for_vocab.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]
        
        # 将标签转换为张量
        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        
        # 预处理所有的词元对,生成模型输入格式
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        
        print('read ' + str(len(self.all_token_ids)) + ' examples')
    
    def _preprocess(self, all_premise_hypothesis_tokens):
        """
        使用多进程预处理所有的前提-假设词元对
        """
        pool = multiprocessing.Pool(4)  # 使用4个进程
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch

网站公告

今日签到

点亮在社区的每一天
去签到