文章目录
前言
大家好!欢迎来到新一期的“从代码学习深度学习”系列。今天,我们将深入探索自然语言处理(NLP)领域最具影响力的模型之一——BERT(Bidirectional Encoder Representations from Transformers)。自2018年问世以来,BERT凭借其强大的语言理解能力,彻底改变了NLP领域的格局。
BERT的核心思想在于通过两个巧妙的预训练任务,在一个巨大的文本语料库上学习语言的深层双向表示。这两个任务分别是:
- 掩码语言模型 (Masked Language Model, MLM):随机遮盖句子中的一些词,让模型去预测这些被遮盖的词是什么。这迫使模型不仅要理解单个词的含义,还要理解词与词之间的上下文关系。
- 下一句预测 (Next Sentence Prediction, NSP):给定两个句子,让模型判断第二个句子是否是第一个句子的真实下一句。这帮助模型学习句子间的逻辑关系和连贯性。
在本篇文章中,我们将通过一个完整的PyTorch项目,一步步地实现BERT的预训练过程。我们将从最基础的组件(如缩放点积注意力)开始,搭建完整的BERT模型,然后处理专门为预训练设计的数据集,并最终完成整个训练流程。我们的目标不仅仅是运行代码,更是要通过每一行代码,深入理解BERT的内在机制。
让我们开始这场编码之旅,亲手揭开BERT神秘的面纱!
完整代码:下载链接
一、数据准备:为BERT量身打造“教科书”
模型是学习者,而数据是教科书。要让BERT学会语言,我们首先需要为它准备一本精心设计的教科书。这个过程远比简单的文本加载要复杂,它需要为MLM和NSP两大任务生成特定的训练样本。我们将使用维基百科(WikiText-2)数据集作为原始语料。
下面,我们将详细分解数据处理的每一步。所有这些功能都封装在utils_for_data.py
文件中。
1.1 数据处理工具函数 (utils_for_data.py
)
这个脚本包含了从原始文本文件到最终可供模型训练的批次数据的所有处理逻辑。
# --- START OF FILE utils_for_data.py ---
import os
import random
def _read_wiki(data_dir):
"""
读取维基百科训练数据文件并进行预处理
参数:
data_dir (str): 数据目录路径 [标量]
返回:
paragraphs (list): 处理后的段落列表 [N x M], 其中N为段落数量,M为每个段落的句子数量
"""
# 构建完整的文件路径
# file_name: 字符串类型,表示wiki训练数据的完整文件路径 [标量]
file_name = os.path.join(data_dir, 'wiki.train.tokens')
# 以只读模式打开文件
with open(file_name, 'r') as f:
# 读取文件所有行
# lines: 列表类型,包含文件中的所有行 [行数 x 1],每个元素为字符串
lines = f.readlines()
# 处理文本数据:转换为小写并按句点分割
# 筛选条件:只保留分割后至少包含2个部分的行(即至少有一个句点分隔符)
paragraphs = [] # paragraphs: 嵌套列表 [段落数 x 句子数],外层为段落,内层为句子
for line in lines: # line: 字符串类型,表示文件中的单行文本 [标量]
# 将当前行去除首尾空白字符,转换为小写,然后按' . '分割
# split_result: 列表类型,包含分割后的句子片段 [句子数 x 1]
split_result = line.strip().lower().split(' . ')
# 只保留分割后至少有2个部分的段落(确保至少有一个完整的句子)
if len(split_result) >= 2:
paragraphs.append(split_result)
# 随机打乱段落顺序,增加数据的随机性
random.shuffle(paragraphs)
return paragraphs
import random
def _get_next_sentence(sentence, next_sentence, paragraphs):
"""
为BERT模型的下一句预测任务生成训练样本
参数:
sentence (str): 当前句子
next_sentence (str): 候选的下一个句子
paragraphs (list): 段落列表 [段落数 x 句子数],每个段落包含多个句子
返回:
sentence (str): 当前句子(保持不变)
next_sentence (str): 最终确定的下一个句子
is_next (bool): 标识next_sentence是否为sentence的真实下一句
"""
# 生成0到1之间的随机数,用于决定是否使用真实的下一句
# random.random(): 浮点数类型,范围[0.0, 1.0) [标量]
random_prob = random.random()
# 50%的概率保持原始的下一句关系(正样本)
if random_prob < 0.5:
# 保持输入的next_sentence不变,表示这是真实的下一句
# is_next: 布尔值,True表示next_sentence确实是sentence的下一句 [标量]
is_next = True
else:
# 50%的概率生成随机的下一句(负样本)
# 从paragraphs中随机选择一个段落,再从该段落中随机选择一个句子
# random.choice(paragraphs): 随机选择的段落 [句子数],是paragraphs中的一个元素
# random.choice(random.choice(paragraphs)): 从随机选择的段落中随机选择的句子 [标量]
random_paragraph = random.choice(paragraphs) # random_paragraph: 列表类型 [句子数 x 1]
next_sentence = random.choice(random_paragraph) # 重新赋值next_sentence为随机句子
# is_next: 布尔值,False表示next_sentence不是sentence的真实下一句 [标量]
is_next = False
# 返回三元组:原句子、处理后的下一句、是否为真实下一句的标识
return sentence, next_sentence, is_next
def get_tokens_and_segments(tokens_a, tokens_b=None):
"""
获取BERT输入序列的词元及其片段索引
参数:
tokens_a (list): 第一个句子的词元列表 [词元数A x 1],每个元素为字符串
tokens_b (list, optional): 第二个句子的词元列表 [词元数B x 1],可选参数
返回:
tokens (list): 完整的词元序列 [总词元数 x 1],包含特殊词元
segments (list): 片段标识序列 [总词元数 x 1],0表示片段A,1表示片段B
"""
# 构建基础词元序列:[<cls>] + 第一个句子 + [<sep>]
# tokens: 列表类型,存储完整的词元序列 [词元数 x 1]
tokens = ['<cls>'] + tokens_a + ['<sep>']
# 为第一个片段创建标识序列,全部标记为0
# segments: 列表类型,存储片段标识 [词元数 x 1]
# len(tokens_a) + 2: tokens_a的长度 + <cls>和<sep>两个特殊词元
segments = [0] * (len(tokens_a) + 2)
# 如果存在第二个句子,添加到序列中
if tokens_b is not None:
# 将第二个句子和结束符添加到词元序列:tokens_b + [<sep>]
tokens += tokens_b + ['<sep>']
# 为第二个片段创建标识序列,全部标记为1
# len(tokens_b) + 1: tokens_b的长度 + 1个<sep>词元
segments += [1] * (len(tokens_b) + 1)
return tokens, segments
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
"""
从单个段落中生成下一句预测(NSP)任务的训练数据
参数:
paragraph (list): 单个段落,包含多个句子 [句子数 x 1],每个元素为字符串
paragraphs (list): 所有段落的列表 [段落数 x 句子数],用于生成负样本
vocab (Vocab): 词汇表对象 ,用于词汇处理
max_len (int): 序列的最大长度限制 [标量]
返回:
nsp_data_from_paragraph (list): NSP训练数据列表 [样本数 x 3]
每个样本包含(tokens, segments, is_next)
"""
# 初始化存储NSP数据的列表
# nsp_data_from_paragraph: 列表类型 [样本数 x 3],每个元素为(tokens, segments, is_next)元组
nsp_data_from_paragraph = []
# 遍历段落中相邻的句子对,生成NSP训练样本
# i: 整数类型,当前句子的索引 [标量]
# range(len(paragraph) - 1): 确保i+1不会越界
for i in range(len(paragraph) - 1):
# 获取句子对和标签
# paragraph[i]: 字符串类型,当前句子 [标量]
# paragraph[i + 1]: 字符串类型,下一个句子 [标量]
# tokens_a: 列表类型,第一个句子的词元 [词元数A x 1]
# tokens_b: 列表类型,第二个句子的词元 [词元数B x 1]
# is_next: 布尔值类型,标识tokens_b是否为tokens_a的真实下一句 [标量]
tokens_a, tokens_b, is_next = _get_next_sentence(
paragraph[i], paragraph[i + 1], paragraphs)
# 检查序列长度是否超过限制
# 需要为1个'<cls>'词元和2个'<sep>'词元预留3个位置
# len(tokens_a): 整数类型,第一个句子的词元数量 [标量]
# len(tokens_b): 整数类型,第二个句子的词元数量 [标量]
# 总长度 = tokens_a长度 + tokens_b长度 + 3个特殊词元
total_length = len(tokens_a) + len(tokens_b) + 3
# 如果总长度超过最大限制,跳过当前样本
if total_length > max_len:
continue
# 将两个句子组合成BERT输入格式,生成tokens和segments
# tokens: 列表类型,包含特殊词元的完整词元序列 [序列长度 x 1]
# 格式:[<cls>, tokens_a, <sep>, tokens_b, <sep>]
# segments: 列表类型,段落标识序列 [序列长度 x 1]
# 0表示第一个句子,1表示第二个句子
tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
# 将处理好的数据添加到结果列表中
# (tokens, segments, is_next): 元组类型,包含完整的NSP训练样本
nsp_data_from_paragraph.append((tokens, segments, is_next))
return nsp_data_from_paragraph
import random
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab):
"""
为遮蔽语言模型(MLM)任务替换词元,实现BERT的预训练策略
参数:
tokens (list): 原始词元序列 [序列长度 x 1],每个元素为字符串
candidate_pred_positions (list): 候选预测位置的索引列表 [候选位置数 x 1],每个元素为整数
num_mlm_preds (int): 需要进行MLM预测的词元数量 [标量]
vocab (Vocab): 词汇表对象,包含idx_to_token属性
返回:
mlm_input_tokens (list): 处理后的输入词元序列 [序列长度 x 1],部分词元被遮蔽或替换
pred_positions_and_labels (list): 预测位置和标签对的列表 [预测数量 x 2]
每个元素为(位置索引, 原始词元)的元组
"""
# 创建原始词元序列的副本,避免修改原始数据
# mlm_input_tokens: 列表类型,用于MLM任务的输入词元序列 [序列长度 x 1]
mlm_input_tokens = [token for token in tokens]
# 初始化预测位置和标签的存储列表
# pred_positions_and_labels: 列表类型 [预测数量 x 2],存储(位置, 原始词元)对
pred_positions_and_labels = []
# 随机打乱候选预测位置,确保随机性
# 这样可以随机选择15%的词元进行遮蔽语言模型预测
random.shuffle(candidate_pred_positions)
# 遍历打乱后的候选位置,进行词元替换
for mlm_pred_position in candidate_pred_positions: # mlm_pred_position: 整数,当前处理的位置索引 [标量]
# 如果已达到所需的预测数量,停止处理
if len(pred_positions_and_labels) >= num_mlm_preds:
break
# 初始化遮蔽后的词元
# masked_token: 字符串类型,用于替换原词元的新词元 [标量]
masked_token = None
# 根据BERT的MLM策略进行词元替换:
# 80%的概率:将词元替换为"<mask>"
if random.random() < 0.8:
masked_token = '<mask>'
else:
# 在剩余20%的情况下进一步细分:
# 10%的概率(20% * 50%):保持原词元不变
if random.random() < 0.5:
masked_token = tokens[mlm_pred_position]
# 10%的概率(20% * 50%):用词汇表中的随机词元替换
else:
# vocab.idx_to_token: 列表类型,词汇表的索引到词元映射 [词汇表大小 x 1]
masked_token = random.choice(vocab.idx_to_token)
# 将当前位置的词元替换为遮蔽后的词元
mlm_input_tokens[mlm_pred_position] = masked_token
# 记录预测位置和对应的原始词元(作为标签)
# (mlm_pred_position, tokens[mlm_pred_position]): 元组类型,包含位置索引和原始词元
pred_positions_and_labels.append(
(mlm_pred_position, tokens[mlm_pred_position]))
return mlm_input_tokens, pred_positions_and_labels
def _get_mlm_data_from_tokens(tokens, vocab):
"""
从词元序列中生成遮蔽语言模型(MLM)任务的训练数据
参数:
tokens (list): 输入词元序列 [序列长度 x 1],每个元素为字符串
vocab (Vocab): 词汇表对象 ,用于词元到索引的转换
返回:
mlm_input_ids (list): MLM输入的词元索引序列 [序列长度 x 1],每个元素为整数
pred_positions (list): 需要预测的位置索引列表 [预测数量 x 1],每个元素为整数
mlm_pred_label_ids (list): 预测位置对应的真实标签索引 [预测数量 x 1],每个元素为整数
"""
# 初始化候选预测位置列表
# candidate_pred_positions: 列表类型 [候选位置数 x 1],存储可以进行MLM预测的位置索引
candidate_pred_positions = []
# 遍历所有词元,找出可以进行MLM预测的位置
# i: 整数类型,词元在序列中的位置索引 [标量]
# token: 字符串类型,当前位置的词元 [标量]
for i, token in enumerate(tokens):
# 跳过特殊词元,这些词元在MLM任务中不进行预测
# '<cls>'和'<sep>'是BERT的特殊标记,不参与遮蔽预测
if token in ['<cls>', '<sep>']:
continue
# 将非特殊词元的位置添加到候选列表中
candidate_pred_positions.append(i)
# 计算需要进行MLM预测的词元数量
# 按照BERT论文,遮蔽15%的词元进行预测
# max(1, ...): 确保至少预测1个词元,避免空预测
# round(...): 四舍五入到最近的整数
# num_mlm_preds: 整数类型,需要预测的词元数量 [标量]
num_mlm_preds = max(1, round(len(tokens) * 0.15))
# 调用词元替换函数,生成MLM输入和预测标签
# mlm_input_tokens: 列表类型,经过遮蔽处理的输入词元序列 [序列长度 x 1]
# pred_positions_and_labels: 列表类型,预测位置和标签的元组列表 [预测数量 x 2]
mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
tokens, candidate_pred_positions, num_mlm_preds, vocab)
# 按位置索引对预测数据进行排序,确保位置顺序的一致性
# key=lambda x: x[0]: 按元组的第一个元素(位置索引)进行排序
pred_positions_and_labels = sorted(pred_positions_and_labels,
key=lambda x: x[0])
# 提取预测位置列表
# pred_positions: 列表类型 [预测数量 x 1],包含所有需要预测的位置索引
# v[0]: 元组中的位置索引(第一个元素)
pred_positions = [v[0] for v in pred_positions_and_labels]
# 提取预测标签列表(原始词元)
# mlm_pred_labels: 列表类型 [预测数量 x 1],包含预测位置的原始词元
# v[1]: 元组中的原始词元(第二个元素)
mlm_pred_labels = [v[1] for v in pred_positions_and_labels]
# 将词元转换为词汇表索引并返回
# vocab[mlm_input_tokens]: 将MLM输入词元转换为索引序列 [序列长度 x 1]
# vocab[mlm_pred_labels]: 将预测标签词元转换为索引序列 [预测数量 x 1]
return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]
import torch
def _pad_bert_inputs(examples, max_len, vocab):
"""
对BERT输入样本进行填充,使所有样本具有相同的长度
参数:
examples (list): BERT训练样本列表 [样本数 x 5]
每个样本包含(token_ids, pred_positions, mlm_pred_label_ids, segments, is_next)
max_len (int): 序列的最大长度 [标量]
vocab (Vocab): 词汇表对象,用于获取填充词元索引
返回:
all_token_ids (list): 填充后的词元ID张量列表 [样本数 x max_len]
all_segments (list): 填充后的片段ID张量列表 [样本数 x max_len]
valid_lens (list): 有效长度张量列表 [样本数 x 1]
all_pred_positions (list): 填充后的预测位置张量列表 [样本数 x max_num_mlm_preds]
all_mlm_weights (list): MLM权重张量列表 [样本数 x max_num_mlm_preds]
all_mlm_labels (list): 填充后的MLM标签张量列表 [样本数 x max_num_mlm_preds]
nsp_labels (list): NSP标签张量列表 [样本数 x 1]
"""
# 计算MLM任务的最大预测数量(15%的词元)
# max_num_mlm_preds: 整数类型,单个样本最多预测的词元数量 [标量]
max_num_mlm_preds = round(max_len * 0.15)
# 初始化存储填充后数据的列表
# all_token_ids: 列表类型,存储所有样本的词元ID张量 [样本数 x max_len]
# all_segments: 列表类型,存储所有样本的片段ID张量 [样本数 x max_len]
# valid_lens: 列表类型,存储所有样本的有效长度张量 [样本数 x 1]
all_token_ids, all_segments, valid_lens = [], [], []
# all_pred_positions: 列表类型,存储所有样本的预测位置张量 [样本数 x max_num_mlm_preds]
# all_mlm_weights: 列表类型,存储所有样本的MLM权重张量 [样本数 x max_num_mlm_preds]
# all_mlm_labels: 列表类型,存储所有样本的MLM标签张量 [样本数 x max_num_mlm_preds]
all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
# nsp_labels: 列表类型,存储所有样本的NSP标签张量 [样本数 x 1]
nsp_labels = []
# 遍历所有训练样本进行填充处理
for (token_ids, pred_positions, mlm_pred_label_ids, segments, is_next) in examples:
# token_ids: 列表类型,当前样本的词元ID序列 [当前序列长度 x 1]
# pred_positions: 列表类型,当前样本的MLM预测位置 [当前预测数量 x 1]
# mlm_pred_label_ids: 列表类型,当前样本的MLM标签ID [当前预测数量 x 1]
# segments: 列表类型,当前样本的片段ID序列 [当前序列长度 x 1]
# is_next: 布尔值,当前样本的NSP标签 [标量]
# 对词元ID序列进行填充到max_len长度
# 使用'<pad>'词元的索引进行填充
# vocab['<pad>']: 整数类型,填充词元在词汇表中的索引 [标量]
padded_token_ids = token_ids + [vocab['<pad>']] * (max_len - len(token_ids))
all_token_ids.append(torch.tensor(padded_token_ids, dtype=torch.long))
# 对片段ID序列进行填充到max_len长度
# 填充部分使用0(表示不属于任何片段)
padded_segments = segments + [0] * (max_len - len(segments))
all_segments.append(torch.tensor(padded_segments, dtype=torch.long))
# 记录有效长度(不包括填充词元的数量)
# 用于在注意力计算中忽略填充部分
valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
# 对MLM预测位置进行填充到max_num_mlm_preds长度
# 填充部分使用0(无效位置)
padded_pred_positions = pred_positions + [0] * (max_num_mlm_preds - len(pred_positions))
all_pred_positions.append(torch.tensor(padded_pred_positions, dtype=torch.long))
# 创建MLM权重向量,用于在损失计算中过滤填充位置
# 真实预测位置权重为1.0,填充位置权重为0.0
# len(mlm_pred_label_ids): 当前样本实际的MLM预测数量
mlm_weights = ([1.0] * len(mlm_pred_label_ids) +
[0.0] * (max_num_mlm_preds - len(pred_positions)))
all_mlm_weights.append(torch.tensor(mlm_weights, dtype=torch.float32))
# 对MLM标签进行填充到max_num_mlm_preds长度
# 填充部分使用0(对应损失计算中会被权重过滤掉)
padded_mlm_labels = mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids))
all_mlm_labels.append(torch.tensor(padded_mlm_labels, dtype=torch.long))
# 转换NSP标签为张量
# is_next: 布尔值转换为整数(True->1, False->0)
nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
# 返回所有填充后的数据张量
return (all_token_ids, all_segments, valid_lens, all_pred_positions,
all_mlm_weights, all_mlm_labels, nsp_labels)
import torch
from collections import Counter
def tokenize(lines, token='word'):
"""
将文本行拆分为单词或字符词元
参数:
lines (list): 文本行列表 [行数 x 1],每个元素为字符串
token (str): 词元化类型 [标量],'word'表示按单词分割,'char'表示按字符分割
返回:
tokenized_lines (list): 词元化后的文本 [行数 x 词元数],嵌套列表结构
"""
if token == 'word':
# 按空格分割每行文本为单词列表
return [line.split() for line in lines]
elif token == 'char':
# 将每行文本转换为字符列表
return [list(line) for line in lines]
else:
print('错误:未知词元类型:' + token)
def count_corpus(tokens):
"""
统计词元出现频率
参数:
tokens (list): 词元列表 [词元数 x 1] 或嵌套列表 [序列数 x 词元数]
返回:
counter (Counter): 词元频率统计对象
"""
# 处理嵌套列表情况,将所有词元展平
if len(tokens) == 0 or isinstance(tokens[0], list):
# tokens是二维列表,需要展平
tokens = [token for line in tokens for token in line]
return Counter(tokens)
class Vocab:
"""
文本词汇表类,用于管理词元到索引的映射
"""
def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
"""
初始化词汇表
参数:
tokens (list): 词元列表 [词元数 x 1] 或嵌套列表 [序列数 x 词元数]
min_freq (int): 最小词频阈值 [标量],低于此频率的词元将被忽略
reserved_tokens (list): 保留词元列表 [保留词元数 x 1],如特殊标记
"""
if tokens is None:
tokens = []
if reserved_tokens is None:
reserved_tokens = []
# 统计词元出现频率并按频率降序排列</