RWKV 1/2/3/4/5/6

发布于:2024-05-05 ⋅ 阅读:(28) ⋅ 点赞:(0)

RWKV: RNNs Strike Back

RNN / LSTM / GRU

在这里插入图片描述
LSTM在RNN的基础上,引入Gate门控机制,将前一token的信息和当前token的信息进行加权,来缓解长距离token遗忘问题。但仍然无法根本的解决长程遗忘。

在这里插入图片描述
GRU将LSTM的4个门控,缩减到了3个门控:
在这里插入图片描述

总结:RNN/LSTM/GRU这类串行模型在计算时,必须先输入上一个token x t − 1 x_{t-1} xt1得到hidden_states h t − 1 h_{t-1} ht1,才能输入下一个token x t x_{t} xt得到新的hidden_states h t h_{t} ht

  • 好处就是①天然的表达了token之间的时序位置关系!②推理的时候,每次的单个token串行计算,单次推理复杂度O(1),因此内存占用恒定,不随序列长度增加。总推理时间O(N)随序列长度线性增加
  • 坏处就是①长程序列(Long term)的遗忘问题无法解决!②串行计算hidden_states,训练速度慢!

Transformer

在这里插入图片描述
在这里插入图片描述

Attention则从本质上解决了RNN类模型的长程遗忘问题,计算Attention Socre=Q@K的时候,使用矩阵乘法,并行地计算每个tokens之间的相关性,一次性得到所有tokens x 0 x_0 x0 x n x_n xnhidden_states h 0 h_0 h0 h n h_n hn

  • 好处是①使得每个tokens都可以关注到彼此,解决了长距离token的关联问题。②Attention的并行计算hidden_states加快了训练速度!
  • 坏处是①没有办法判断每个token的时序位置关系!(加入PE缓解) ②推理的时候,计算复杂度高O(N^2)!(引入KV Cache缓解)
    在这里插入图片描述

RWKV

在这里插入图片描述

一个形状为 NxM 的矩阵,与另一个形状为 MxP 的矩阵相乘,其运算复杂度来源于乘法操作的次数,时间复杂度为 O(NMP)。

Transformer的高复杂度源自于它需要计算 Q Q Q K T K^T KT 的矩阵乘运算,而这个矩阵的大小取决于输入文本的长度,因此Transformer推理的计算时间复杂度是输入长度的平方级复杂度 O ( N 2 ) O(N^2) O(N2)。而用RNN跑的好处是推理可以不并行计算sequence中的每个token,而是串行的计算每个token,推理复杂度就是O(1),也就是说你可以在显存非常小的设备上跑大模型。当然速度是不会提升的,可能会非常慢,也就是节约显存。因此大模型的硬件限制和部署成本,使用RWKV可以明显降低,甚至可以在PC和手机上部署。

下图中complexity per-layer:每层的时间复杂度(表示);minimum number of sequential operations:最少需要的序列操作数(表示模型推理序列中n个token的并行程度,即需要模型推理多少次才能推理完n个tokens)
在这里插入图片描述

  • RNN Limititions:RNN虽然单次推理1个token时间复杂度 O ( 1 ) O(1) O(1),但无法在时间维度并行化,总时间复杂度 O ( N ) O(N) O(N),限制了可扩展性(scale up)。且在训练长序列容易梯度消失,存在遗忘问题。
    在这里插入图片描述

  • Transformer(Self-Attention)Limititions:Transformer可以单次推理序列全部N个tokens,推理时间复杂度 O ( N 2 ) O(N^2) O(N2),因此长序列推理成本高和内存占用多。
    在这里插入图片描述

  • Attention Free Transformer(AFT):为了降低Attention中 O ( N 2 ) O(N^2) O(N2)计算复杂度的矩阵乘法,AFT将矩阵乘法@替换为元素乘法 ⊙ \odot ,这样计算复杂度从 O ( N 2 ) O(N^2) O(N2)降为了 O ( N ) O(N) O(N)。并对其中的Q做Sigmoid操作,将其中的K与可学习的位置编码 w 相加。AFT的复杂度是因为 w w w 参与到计算中带来的,因此我们可以调整 w w w 的样式来提升AFT的速度。channel-wise就是hidden_dim维度的操作。
    在这里插入图片描述

class AFTFull(nn.Module):
    def __init__(self, max_len, dim, hid_dim=32):
        super().__init__()
        self.max_len = max_len
        self.dim = dim          # token的节点数
        self.hid_dim = hid_dim  # 隐层节点数
        self.wq = nn.Linear(self.dim, self.hid_dim)
        self.wk = nn.Linear(self.dim, self.hid_dim)
        self.wv = nn.Linear(self.dim, self.hid_dim)
        self.ffnn = nn.Linear(self.hid_dim, self.dim)
        self.w = nn.Parameter(torch.Tensor(max_len, max_len))
        nn.init.xavier_uniform_(self.w)
    
    def forward(self, x):
        B, T, _ = x.shape
        Q = self.wq(x).view(B, T, self.hid_dim)
        K = self.wk(x).view(B, T, self.hid_dim)
        V = self.wv(x).view(B, T, self.hid_dim)
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(self.w) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(self.w) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)
        Yt = Yt.view(B, T, self.hid_dim)
        Yt = self.ffnn(Yt)
        return weighted
  • RWKV:就是将Transformer中的AttentionFFN替换为了Time MixingChannel Mining
  • R:作为过去信息的接受程度的接受向量。
  • W:位置权重衰减向量,可训练(learnable)的模型参数。
  • K:key向量,类似传统attention中的K。
  • V:value向量,类似传统attention中的V。

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

  • Time Mixing:
    在这里插入图片描述

  • Channel Mining:
    在这里插入图片描述

像RNN一样解码:
在这里插入图片描述
在这里插入图片描述

import numpy as np
from torch import load as torch_load  # Only for loading the model weights
from tokenizers import Tokenizer

layer_norm = lambda x, w, b : (x - np.mean(x)) / np.std(x) * w + b
exp = np.exp
sigmoid = lambda x : 1/(1 + exp(-x))

def time_mixing(x, last_x, last_num, last_den, decay, bonus, mix_k, mix_v, mix_r, Wk, Wv, Wr, Wout):
    k = Wk @ ( x * mix_k + last_x * (1 - mix_k) )
    v = Wv @ ( x * mix_v + last_x * (1 - mix_v) )
    r = Wr @ ( x * mix_r + last_x * (1 - mix_r) )

    wkv = (last_num + exp(bonus + k) * v) / (last_den + exp(bonus + k))
    rwkv = sigmoid(r) * wkv

    num = exp(-exp(decay)) * last_num + exp(k) * v
    den = exp(-exp(decay)) * last_den + exp(k)
    return Wout @ rwkv, (x,num,den)

def channel_mixing(x, last_x, mix_k, mix_r, Wk, Wr, Wv):
    k = Wk @ ( x * mix_k + last_x * (1 - mix_k) )
    r = Wr @ ( x * mix_r + last_x * (1 - mix_r) )
    vk = Wv @ np.maximum(k, 0)**2
    return sigmoid(r) * vk, x

def RWKV(model, token, state):
    params = lambda prefix : [model[key] for key in model.keys() if key.startswith(prefix)]

    x = params('emb')[0][token]
    x = layer_norm(x, *params('blocks.0.ln0'))

    for i in range(N_LAYER):
        x_ = layer_norm(x, *params(f'blocks.{i}.ln1'))
        dx, state[i][:3] = time_mixing(x_, *state[i][:3], *params(f'blocks.{i}.att'))
        x = x + dx

        x_ = layer_norm(x, *params(f'blocks.{i}.ln2'))
        dx, state[i][3] = channel_mixing(x_, state[i][3], *params(f'blocks.{i}.ffn'))
        x = x + dx

    x = layer_norm(x, *params('ln_out'))
    x = params('head')[0] @ x

    e_x = exp(x-np.max(x))
    probs = e_x / e_x.sum() # Softmax of x

    return probs, state

##########################################################################################################

def sample_probs(probs, temperature=1.0, top_p=0.85):
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = sorted_probs[np.argmax(cumulative_probs > top_p)]
    probs[probs < cutoff] = 0
    probs = probs**(1/temperature)
    return np.random.choice(a=len(probs), p=probs/np.sum(probs))


# Available at https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth
# 这个模型文件821mb,还不算太大
MODEL_FILE = '/data/rwkv/RWKV-4-Pile-430M-20220808-8066.pth'
N_LAYER = 24
N_EMBD = 1024

print(f'\nLoading {MODEL_FILE}')
weights = torch_load(MODEL_FILE, map_location='cpu')
for k in weights.keys():
    if '.time_' in k: weights[k] = weights[k].squeeze()
    weights[k] = weights[k].float().numpy() # convert to f32 type


# Available at https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json
tokenizer = Tokenizer.from_file("/data/rwkv/20B_tokenizer.json")

print(f'\nPreprocessing context')

context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

state = np.zeros((N_LAYER, 4, N_EMBD), dtype=np.float32)
for token in tokenizer.encode(context).ids:
    probs, state = RWKV(weights, token, state)

print(context, end="")
for i in range(100):
    token = sample_probs(probs)
    print(tokenizer.decode([token]), end="", flush=True)
    probs, state = RWKV(weights, token, state)

网站公告

今日签到

点亮在社区的每一天
去签到