[Test] Writing a typical attention-mechanism example with Python and the PyTorch framework

Published: 2023-09-22

Below is a simple implementation of an attention mechanism. It contains an Encoder, a Decoder, and the attention computation itself, wired together in a small Seq2Seq model.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, inputs, hidden):
        # inputs: (batch, seq_len, input_size) -> outputs: (batch, seq_len, hidden_size)
        outputs, hidden = self.rnn(inputs, hidden)
        return outputs, hidden

    def init_hidden(self, batch_size):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))

class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Project the LSTM output back to the target feature size.
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, hidden):
        # inputs: (batch, 1, input_size) -- one target step concatenated with the context vector
        output, hidden = self.rnn(inputs, hidden)
        output = self.linear(output)  # (batch, 1, output_size)
        return output, hidden

    def init_hidden(self, batch_size):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))

class Attention(nn.Module):
    # Additive (Bahdanau-style) attention over the encoder outputs.
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        # hidden: (num_layers, batch, hidden_size) -- decoder hidden state
        # encoder_outputs: (batch, max_len, hidden_size)
        max_len = encoder_outputs.size(1)

        # Take the top layer's hidden state and repeat it along the source length.
        H = hidden[-1].unsqueeze(1).repeat(1, max_len, 1)  # (batch, max_len, hidden_size)

        attn_energies = self.score(H, encoder_outputs)       # (batch, max_len)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)  # (batch, 1, max_len)

    def score(self, hidden, encoder_outputs):
        # energy: (batch, max_len, hidden_size)
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2)))
        energy = energy.transpose(1, 2)                             # (batch, hidden_size, max_len)
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # (batch, 1, hidden_size)
        energy = torch.bmm(v, energy)                               # (batch, 1, max_len)
        return energy.squeeze(1)

class Seq2Seq(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder(input_size, hidden_size, num_layers)
        # The decoder consumes the current target step concatenated with the context vector.
        self.decoder = Decoder(output_size + hidden_size, hidden_size, output_size, num_layers)
        self.attention = Attention(hidden_size)

    def forward(self, inputs, targets):
        batch_size = inputs.size(0)
        target_len = targets.size(1)

        encoder_outputs, encoder_hidden = self.encoder(inputs, self.encoder.init_hidden(batch_size))
        decoder_hidden = encoder_hidden

        outputs = []
        for t in range(target_len):
            # Teacher forcing: feed the ground-truth target of the current step.
            decoder_input = targets[:, t, :].unsqueeze(1)                      # (batch, 1, output_size)
            attn_weights = self.attention(decoder_hidden[0], encoder_outputs)  # (batch, 1, max_len)
            context = attn_weights.bmm(encoder_outputs)                        # (batch, 1, hidden_size)
            decoder_input = torch.cat([decoder_input, context], 2)

            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs.append(decoder_output)

        return torch.cat(outputs, dim=1)  # (batch, target_len, output_size)
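
The Attention class above implements additive (Bahdanau-style) attention: score(h_t, h_s) = v^T · tanh(W[h_t ; h_s]), followed by a softmax over the source positions. As a quick shape sanity check (the sizes below are arbitrary and only for illustration), the module can be exercised on its own:

attn = Attention(hidden_size=128)
decoder_hidden = torch.zeros(1, 32, 128)         # (num_layers, batch, hidden_size)
encoder_outputs = torch.randn(32, 10, 128)       # (batch, max_len, hidden_size)
weights = attn(decoder_hidden, encoder_outputs)  # (32, 1, 10)
print(weights.shape, weights.sum(dim=2))         # the weights for each example sum to 1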

Usage:

  1. Define the model:
model = Seq2Seq(input_size=100, hidden_size=128, output_size=200, num_layers=1)
  2. Run a forward pass:
input_data = torch.randn(32, 10, 100)   # input data of shape (batch, seq_len, input_size) = (32, 10, 100)
target_data = torch.randn(32, 10, 200)  # target data of shape (batch, seq_len, output_size) = (32, 10, 200)
output = model(input_data, target_data)
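
The snippet above only runs a forward pass. Since the toy targets here are continuous, a minimal training-loop sketch could look like the following (nn.MSELoss and Adam are illustrative choices, not requirements of the model):

import torch.optim as optim

criterion = nn.MSELoss()                      # illustrative loss for continuous toy targets
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):                       # a few epochs, just for demonstration
    optimizer.zero_grad()
    output = model(input_data, target_data)   # teacher forcing: targets are fed step by step
    loss = criterion(output, target_data)     # output has shape (32, 10, 200), same as target_data
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")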

The above is a simple implementation of an attention mechanism; you can modify and adapt the code to fit your own needs.