从代码学习深度学习 - 序列到序列学习 GRU编解码器 PyTorch 版

发布于:2025-04-09 ⋅ 阅读:(86) ⋅ 点赞:(0)


前言

Seq2Seq 模型的核心思想是将一个输入序列(例如英语句子)通过编码器(Encoder)转化为一个固定长度的上下文向量,再由解码器(Decoder)根据该向量生成目标序列(例如法语句子)。这种编码-解码的架构最初由 RNN 实现,后来发展出 LSTM 和 Transformer 等变种。在本文中,我们将聚焦于基于 RNN 的经典实现,并通过 PyTorch 代码逐步拆解其关键组件。

本文的代码来源于一个完整的机器翻译任务示例,数据集为英语-法语翻译对。我们将从数据加载与预处理开始,逐步构建编码器和解码器,最后通过 BLEU 分数评估翻译效果。所有代码都经过注释,确保易于理解,同时保留了附件中的完整性。

让我们开始吧!


一、数据加载与预处理

Seq2Seq 模型的第一步是准备数据。我们需要将原始的英语-法语翻译对数据加载到内存中,并对其进行预处理和词元化(tokenization),以便后续输入到模型中。以下是相关代码及其解释:

1.1 读取数据

from collections import Counter  # 用于词频统计
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作

def read_data_nmt():
    """
    载入“英语-法语”数据集
    
    返回值:
        str: 文件内容的完整字符串
    """
    with open('fra.txt', 'r', encoding='utf-8') as f:
        return f.read()

read_data_nmt 函数简单地读取名为 fra.txt 的文件,该文件包含英语和法语的翻译对,每行以制表符分隔。它返回整个文件的字符串内容,为后续处理奠定基础。

1.2 预处理数据

def preprocess_nmt(text):
    """
    预处理“英语-法语”数据集
    
    参数:
        text (str): 输入的原始文本字符串
    
    返回值:
        str: 处理后的文本字符串
    """
    def no_space(char, prev_char):
        """
        判断当前字符是否需要前置空格
        """
        return char in set(',.!?') and prev_char != ' '

    # 使用空格替换不间断空格(\u202f)和非断行空格(\xa0),并转换为小写
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    
    # 在单词和标点符号之间插入空格
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
           
    return ''.join(out)

preprocess_nmt 函数对文本进行标准化处理:

  1. 将特殊空格字符替换为普通空格,并将所有字符转换为小写。
  2. 在标点符号(如逗号、句号)前插入空格,便于后续按空格分割词元。这种处理确保标点符号被视为独立的词元,而不是粘附在单词上。

1.3 词元化

def tokenize_nmt(text, num_examples=None):
    """
    词元化“英语-法语”数据集
    
    参数:
        text (str): 输入的文本字符串,每行包含英语和法语句子,用制表符分隔
        num_examples (int, optional): 最大处理样本数,默认值为 None 表示处理全部
    
    返回值:
        tuple: 包含两个列表的元组
            - source (list): 英语句子词元列表
            - target (list): 法语句子词元列表
    """
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if num_examples and i > num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target

tokenize_nmt 函数将预处理后的文本按行分割,并进一步将每行按制表符分为英语和法语部分,然后按空格分割成词元列表。它返回两个列表:source(英语词元列表)和 target(法语词元列表)。

1.4 词频统计

def count_corpus(tokens):
    """
    统计词元的频率
    
    参数:
        tokens: 词元列表,可以是一维或二维列表
    
    返回值:
        Counter: Counter 对象,统计每个词元的出现次数
    """
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

count_corpus 函数使用 Counter 类统计词元的出现频率,支持一维和二维列表输入。它是构建词汇表的基础工具。

1.5 构建词汇表

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """初始化词表"""
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        """统计词元频率"""
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

Vocab 类用于构建词汇表并管理词元与索引之间的映射:

  • 初始化时接受词元列表、最小频率阈值和预留特殊词元(如 <pad><bos><eos>)。
  • 内部使用 count_corpus 统计词频,并按频率排序。
  • 提供 __getitem__to_tokens 方法,分别用于词元到索引和索引到词元的转换。
  • <unk> 表示未知词元,默认索引为 0。

1.6 截断与填充

def truncate_pad(line, num_steps, padding_token):
    """
    截断或填充文本序列
    
    参数:
        line (list): 输入的文本序列(词元列表)
        num_steps (int): 目标序列长度
        padding_token (str): 用于填充的标记
    
    返回值:
        list: 截断或填充后的序列,长度为 num_steps
    """
    if len(line) > num_steps:
        return line[:num_steps]
    return line + [padding_token] * (num_steps - len(line))

truncate_pad 函数确保所有序列长度一致:

  • 如果序列长度超过 num_steps,则截断。
  • 如果不足,则用 padding_token(通常是 <pad>)填充。

1.7 转换为张量

def build_array_nmt(lines, vocab, num_steps):
    """
    将机器翻译的文本序列转换为小批量
    
    参数:
        lines (list): 文本序列列表,每个元素是一个词元列表
        vocab (dict): 词汇表,将词元映射为索引
        num_steps (int): 目标序列长度
    
    返回值:
        tuple: 包含两个元素的元组
            - array (torch.Tensor): 转换后的张量,形状为 (样本数, num_steps)
            - valid_len (np.ndarray): 每个序列的有效长度,形状为 (样本数,)
    """
    lines =

网站公告

今日签到

点亮在社区的每一天
去签到