基于RNN和Transformer的词级语言建模 代码分析 PositionalEncoding

发布于:2024-06-03 ⋅ 阅读:(93) ⋅ 点赞:(0)

基于RNN和Transformer的词级语言建模 代码分析 PositionalEncoding

flyfish

Word-level Language Modeling using RNN and Transformer

word_language_model

PyTorch 提供的 word_language_model 示例展示了如何使用循环神经网络RNN(GRU或LSTM)和 Transformer 模型进行词级语言建模 。默认情况下,训练使用Wikitext-2数据集,generate.py可以使用训练好的模型来生成新文本。

源码地址
https://github.com/pytorch/examples/tree/main/word_language_model

文件:model.py

在 Transformer 模型中,位置编码(Positional Encoding)是用来引入序列中位置信息的一种方法,因为 Transformer 本身没有内置的顺序感知能力。位置编码将位置信息添加到输入嵌入中,从而使模型能够利用单词的顺序信息。

可视化
https://flyfish.blog.csdn.net/article/details/137259966

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# 示例用法
d_model = 512
dropout = 0.1
max_len = 100

pos_encoder = PositionalEncoding(d_model, dropout, max_len)
input_tensor = torch.zeros(20, 32, d_model)  # (sequence_length, batch_size, d_model)
output = pos_encoder(input_tensor)

print(output.shape)  # 输出形状与输入形状相同

torch.Size([20, 32, 512])

为什么使用加法而不是拼接
在位置编码的实现中,通常使用加法而不是拼接(concatenation)有以下几个原因:

维度一致性:

加法:位置编码和输入嵌入的维度是相同的,二者逐元素相加后仍保持相同的维度。
拼接:拼接操作会改变张量的维度,这意味着需要额外的处理步骤来适应这种变化。模型的输入层和随后的层可能需要适应这种维度的变化。
模型复杂度:

加法:保留了输入嵌入的维度,模型结构简单,计算效率高。
拼接:增加了模型的输入维度,可能导致更多的参数和更高的计算复杂度。
信息融合:

加法:通过加法操作,输入嵌入和位置编码的每个元素都得到结合,这使得模型在每个维度上都能感知到位置信息。
拼接:通过拼接操作,位置信息和输入嵌入信息是分开的,需要后续层来融合这两部分信息,融合过程可能不如加法操作直观。

正弦和余弦函数具有不同的周期性,正弦函数是奇函数,而余弦函数是偶函数。通过同时使用正弦和余弦函数,我们可以引入不同频率的周期性模式到位置编码中。这样,模型可以同时学习到不同尺度和不同频率的位置信息,从而更好地理解序列中不同位置之间的关系。


网站公告

今日签到

点亮在社区的每一天
去签到