第2天:认识LSTM

发布于:2025-06-07 ⋅ 阅读:(18) ⋅ 点赞:(0)

目标

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: pytorch

(二)具体步骤
1. 什么是LSTM

LSTM(Long Short-Term Memory,长短期记忆网络)是一种特殊的循环神经网络(RNN),专门设计来解决传统RNN在处理长序列时遇到的梯度消失问题。

📖 LSTM的发展背景

传统RNN在处理长序列时面临两个主要问题:

  • 梯度消失:随着序列长度增加,早期信息的梯度会急剧衰减
  • 梯度爆炸:梯度可能变得过大,导致训练不稳定
    LSTM通过引入"门控机制"和"细胞状态"来解决这些问题。

🔧 LSTM的核心组件

LSTM单元包含三个门和一个细胞状态:

1. 遗忘门(Forget Gate)

f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
  • 作用:决定从细胞状态中丢弃什么信息
  • 输出:0到1之间的值,0表示完全遗忘,1表示完全保留

2. 输入门(Input Gate)

i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C)
  • 作用:决定什么新信息被存储在细胞状态中
  • 两部分:决定更新什么值 + 创建候选值

3. 输出门(Output Gate)

o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o_t * tanh(C_t)
  • 作用:决定输出什么部分的细胞状态

4. 细胞状态(Cell State)

C_t = f_t * C_{t-1} + i_t * C̃_t
  • 作用:LSTM的"记忆",信息可以在其中流动

🎯 LSTM的工作流程

让我用一个形象的比喻来解释:
想象LSTM是一个智能的信息管理系统

  1. 遗忘门像一个"删除键",决定删除哪些过时信息
  2. 输入门像一个"筛选器",决定接收哪些新信息
  3. 细胞状态像一个"主内存",存储重要信息
  4. 输出门像一个"发布器",决定输出什么信息

📊 LSTM vs 传统RNN对比

特征 传统RNN LSTM
记忆能力 短期记忆 长短期记忆
梯度问题 梯度消失严重 有效缓解
参数数量 较少 较多(约4倍)
训练复杂度 简单 复杂
长序列处理 困难 擅长
2. 网络结构
import torch  
import torch.nn as nn  
  
class SimpleLSTM(nn.Module):  
    def __init__(self, input_size, hidden_size, num_layers, output_size):  
        """  
        类初始化  
        :param input_size: 每个时间步的输入特征维度  
        :param hidden_size:  LSTM隐藏状态的维度,也决定了LSTM内部门控单元的大小  
        :param num_layers: LSTM的层数  
        :param output_size: 最终输出的维度  
        """        super(SimpleLSTM, self).__init__()  
  
        # 定义LSTM层  
        # 其中batch_first=True:指定输入张量的格式为(batch_size, seq_len, input_size)  
        # 如果不设置,默认格式是(seq_len, batch_size, input_size)  
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  
        # 定义一个线性层,将LSTM输出映射到期望的输出维度  
        self.fc = nn.Linear(hidden_size, output_size)  
  
    def forward(self, x):  
        # LSTMn层的前向传播,默认返回output和(hidden, cell_state)  
        # lstm_out:shape(batch_size, seq_len, hidden_size)        # hn:最终的隐藏状态,形状为(num_layers, batch_size, hidden_size)  
        # cn:最终的记忆状态,形状为(num_layers, batch_size, hidden_size)与hn相同  
        lstm_out, (hn, cn) = self.lstm(x)  
  
        # 取最后一个时间步输出  
        lstm_out = lstm_out[:, -1, :]  
  
        # 通过全连接层将LSTM输出映射到输出维度  
        output = self.fc(lstm_out)  
  
        return output  
  
# 参数设置  
input_size = 10 # 输入特征的维度  
hidden_size = 20 # LSTM隐藏层的维度  
num_layers = 2 # LSTM的层数  
output_size = 1 #  输出的维度  
  
# 创建模型实例  
model = SimpleLSTM(input_size, hidden_size, num_layers, output_size)  
  
# 打印模型结构  
print(model)  
  
# 示例输入(batch_size, seq_len, input_size)  
x = torch.randn(5, 15, input_size)  # 本例相当于(5, 15, 10)  
  
# 前向传播  
output = model(x)  
# 计算过程如下:  
# 1. 输入:(5, 15, 10)  
# 2. LSTM处理:(5, 15, 10) -> (5, 15, 20)  
# 3. 取最后的时间步: (5, 15, 20) -> (5, 20)  
# 4. 全连接层:(5, 20) -> (5, 1)  
  
# 输出结果  
print("输入shape为:", x.shape)  
print("输出shape为:", output.shape)

image.png

(三)总结
LSTM的典型应用
1. 自然语言处理
  • 机器翻译
  • 情感分析
  • 文本生成
2. 时间序列预测
  • 股票价格预测
  • 天气预报
  • 销售预测
3. 语音识别
  • 语音到文本转换
  • 语音合成
4. 其他序列任务
  • 视频分析
  • 生物序列分析
  • 异常检测

网站公告

今日签到

点亮在社区的每一天
去签到