Accelerating samout

Published 2024-08-22
import torch


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads
        self.win = win
        self.hidden = hidden_dim
        self.mask = torch.triu(torch.ones([win, win])).to(device)  # upper-triangular window mask (reads the module-level device)
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        # input_data: [batch, seq, hidden]
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

        window = torch.ones([1, w]).to(device)

        out = self.head(input_data)       # [b, s, hidden]

        out = out.unsqueeze(-1) @ window  # replicate each step's features across the window: [b, s, hidden, w]

        out = out.permute([0, 2, 1, 3])   # [b, hidden, s, w]

        one_list = []
        if state is None:
            state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
            state = state.to(device)
        for i in range(0, s, w):
            j = w + i
            one = out[:, :, i:j]  # current window along the sequence axis
            _, _, r, c = one.shape
            # mask out future steps within the window
            if r != self.win:
                one = torch.where(self.mask[:r, :] == 1, one, torch.tensor(float("-inf")).to(device))
            else:
                one = torch.where(self.mask == 1, one, torch.tensor(float("-inf")).to(device))

            if i == 0:
                # first window: append the carried state column, then take the running max
                one = torch.concat([one, state @ window], dim=2)
                state, _ = torch.max(one, dim=2, keepdim=True)
            else:
                # later windows: project the window max, fold in the carried state,
                # then take the running max over the extended window
                state1, _ = torch.max(one, dim=2, keepdim=True)
                state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))
                state = state1.permute([0, 2, 1]).unsqueeze(-2) + state
                one = torch.concat([one, state], dim=2)
                state, _ = torch.max(one, dim=2, keepdim=True)

            one = state.reshape([b, k, h, w])  # column t holds the running max up to step t
            state = state[..., -1:]            # carry only the last column to the next window
            if r != self.win:
                one = one[..., :r]             # trim a short final window
            one = one.permute([0, 3, 1, 2])    # [b, r, heads, head_size]
            one_list.append(one)

        out = torch.concat(one_list, 1)  # stitch the windows back together
        out = out.reshape([b, s, -1])    # [b, s, hidden]

        return out, state
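
A quick shape check for the MaxState block (a minimal sketch, not part of the original post; the constructor reads the module-level `device`, so the sketch sets it to "cpu" first):

device = "cpu"
block = MaxState(hidden_dim=64, heads=4, win=8)
x = torch.randn(2, 16, 64)   # [batch, seq, hidden]; seq is a multiple of win
y, state = block(x)
print(y.shape)               # torch.Size([2, 16, 64])
print(state.shape)           # torch.Size([2, 64, 1, 1])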


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))  # ReLU-gated branch
        x = self.ffn2(x1 * x2)
        return x


class MemoryBlock(torch.nn.Module):
    # retained from an earlier experiment; not referenced by SamOut below
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()

        # Xavier-initialize the trainable weight
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)

        # fixed identity "memory" added to the trainable weight
        self.mem = torch.eye(hidden_dim).to(device)

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        return x
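
Since `mem` is the identity matrix, the block computes `x @ fc + x`, i.e. a linear map with a built-in residual. A minimal sketch (not from the original post; assumes `device = "cpu"` as in the MaxState check above):

device = "cpu"
mb = MemoryBlock(8)
y = mb(torch.randn(2, 8))  # equals x @ fc + x because mem is the identity
print(y.shape)             # torch.Size([2, 8])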


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        self.self_attention = MaxState(hidden_size, num_heads, 8)  # window size 8
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)  # index 3 is <|pos|>
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        self.layer_nor = torch.nn.LayerNorm(hidden_size)
        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def state_forward(self, state, pos, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        for i, decoder_layer in enumerate(self.decoder_layers):
            # broadcast the position embedding to the batch and re-inject it at every layer
            x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
        return x, state

    def pos_forward(self, x):
        if x.shape[1] >= 1024:
            # beyond the table size, combine quotient and remainder position embeddings
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x0, x1, x2):
        # three strided views of the same text; see the training script below
        x0, _ = self.one_forward(x0)
        x1, _ = self.one_forward(x1)
        x2, _ = self.one_forward(x2)
        return x0, x1, x2

    def one_forward(self, x, state=None, seq_len=None):
        x = self.em(x)

        pos = self.pos_forward(x)

        x, state = self.state_forward(state, pos, x)

        return self.head(x), state


device = "cuda"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device), torch.randint(0, 200, [2, 4 * 13]).to(device),
        torch.randint(0, 200, [2, 2 * 13]).to(device))
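
A hypothetical incremental-decoding loop built on the state that one_forward returns (an assumed usage, not shown in the original post; note that pos_forward restarts positions at 0 on every call, so fed-back tokens reuse early positions):

ids = torch.randint(0, 200, [1, 8]).to(device)
logits, state = net.one_forward(ids)             # prime the per-layer states
for _ in range(5):
    next_id = logits[:, -1:].argmax(-1)          # greedy pick for the last step
    logits, state = net.one_forward(next_id, state)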

Training

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob

from tqdm import tqdm
from model_d import SamOut
import json


def train():
    train_data, test_data, voc = gen_voc()

    net = SamOut(len(voc), 512, 32, 8)
    # net.load_state_dict(torch.load("model.pth"))
    net.to("cuda")
    opt = torch.optim.Adam(params=net.parameters(), lr=0.00003)
    # one CrossEntropyLoss per stream; index 3 (<|pos|>) is the padding token
    loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
    loss_func1 = torch.nn.CrossEntropyLoss(ignore_index=3)
    loss_func2 = torch.nn.CrossEntropyLoss(ignore_index=3)


    bar = tqdm(range(100))
    steps = 0
    epoch_loss = []
    # stream 0 predicts each odd position from the even position before it;
    # streams 1 and 2 subsample those targets by 2 and 4 for coarser strides
    label_index0 = np.array([i for i in range(1, 256) if i % 2 == 1])
    input_index0 = label_index0 - 1
    label_index1 = label_index0[::2]
    input_index1 = label_index1 - 1
    label_index2 = label_index1[::2]
    input_index2 = label_index2 - 1
    for epoch in bar:
        np.random.shuffle(train_data)
        loss_list = []
        for i in range(0, len(train_data), 100):
            j = i + 100
            data = train_data[i:j]

            # one strided view of the batch per stream
            out0, out1, out2 = net(torch.Tensor(data)[:, input_index0].int().to("cuda"),
                                   torch.Tensor(data)[:, input_index1].int().to("cuda"),
                                   torch.Tensor(data)[:, input_index2].int().to("cuda"))
            loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
                              torch.Tensor(data)[:, label_index0].reshape([-1]).long().to("cuda"))/3
            loss += loss_func1(out1.reshape([-1, out1.shape[-1]]),
                               torch.Tensor(data)[:, label_index1].reshape([-1]).long().to("cuda"))/3
            loss += loss_func2(out2.reshape([-1, out2.shape[-1]]),
                               torch.Tensor(data)[:, label_index2].reshape([-1]).long().to("cuda"))/3

            loss_list.append(loss.item())
            bar.set_description("epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
            opt.zero_grad()
            loss.backward()
            opt.step()
            steps += 100
            if steps % 8000 == 0:
                torch.save(net.state_dict(), "model.pth")  # periodic checkpoint
        epoch_loss.append(np.mean(loss_list))
        pd.to_pickle(epoch_loss, "loss8")
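
For reference, the three index sets built at the top of train() interleave like this (a standalone sketch; it only assumes the numpy import already at the top of the script):

label_index0 = np.arange(1, 256, 2)      # same values as the list comprehension in train()
print(label_index0[:4].tolist())         # [1, 3, 5, 7]    stream 0 targets
print(label_index0[::2][:4].tolist())    # [1, 5, 9, 13]   stream 1 targets
print(label_index0[::4][:4].tolist())    # [1, 9, 17, 25]  stream 2 targets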


def val():
    train_data, test_data, voc = gen_voc()

    net = SamOut(len(voc), 256, 16, 4)  # note: must match the sizes the checkpoint was trained with
    net.to("cuda")

    net.load_state_dict(torch.load("model.pth"))
    net.eval()

    loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)

    bar = tqdm(range(1))
    steps = 0
    for epoch in bar:
        np.random.shuffle(test_data)
        for i in range(0, len(test_data), 100):
            j = i + 100
            data = test_data[i:j]
            out, _ = net.one_forward(torch.Tensor(data)[:, :-1].int().to("cuda"))
            loss = loss_func0(out.reshape([-1, out.shape[-1]]),
                              torch.Tensor(data)[:, 1:].reshape([-1]).long().to("cuda"))

            bar.set_description("epoch___{}____loss___{:.6f}____steps___{}".format(epoch, loss.item(), steps))


def gen_voc():
    paths = glob("train/*") + glob("test/*")
    q_list = []
    voc_dict = set()
    len_list = []
    for path in tqdm(paths):
        with open(path, "r", encoding="utf-8") as f:
            data = f.read()
            data = json.loads(data)
        if "answer" not in data:
            q = list(data["question"])
        else:

            q = list(data["question"]) + ["<|bos|>"] + list(data["answer"])
        voc_dict.update(set(q))
        q = ["<|sos|>"] + q + ["<|eos|>"]
        len_list.append(len(q))

        q_list.append(q)
    voc = ["<|sss|>", "<|sos|>", "<|eos|>", "<|pos|>"] + sorted(voc_dict)
    add_voc = ["<|pos_{}|>".format(i) for i in range(18)]
    voc += add_voc
    add_voc_list = []
    # assign each vocab entry a distinct random permutation of the pos tokens
    # (unseeded, so the assignment differs between runs)
    for _ in tqdm(range(len(voc) + 6)):
        while add_voc in add_voc_list:
            np.random.shuffle(add_voc)
        add_voc_list.append(list(add_voc))
    padding_str = {v: i for i, v in zip(add_voc_list, voc[2:])}  # token -> its pos-token permutation
    # plt.plot(sorted(len_list))
    # plt.show()
    train_list = []
    test_list = []

    for i in tqdm(q_list):
        if len(i) < 256 - 18:
            if "<|bos|>" in i:
                # left-pad with pos tokens until the length is a multiple of 16,
                # then right-pad with <|pos|> up to the fixed length of 256
                o = padding_str.get(i[1])[len(i) % 16 - 16:] + i
                train_list.append([voc.index(j) for j in o] + (256 - len(o)) * [voc.index("<|pos|>")])
            else:
                test_list.append([voc.index(j) for j in i] + (256 - len(i)) * [voc.index("<|pos|>")])
    return train_list, test_list, voc
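
A small worked example of the prefix-padding arithmetic in gen_voc (the tokens here are hypothetical stand-ins for one entry of padding_str):

perm = ["<|pos_{}|>".format(k) for k in range(18)]   # one assigned permutation
sample = list("abcdefgh")                            # hypothetical 8-token sample
prefix = perm[len(sample) % 16 - 16:]                # negative index keeps 16 - 8 = 8 tokens
padded = prefix + sample
assert len(padded) % 16 == 0                         # length is now window-aligned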


def eval_data():
    train_data, test_data, voc = gen_voc()

    net = SamOut(len(voc), 256, 16, 4)  # note: must match the sizes the checkpoint was trained with
    net.to("cuda")

    net.load_state_dict(torch.load("model.pth"))
    net.eval()

    for data in test_data:
        # truncate at <|eos|> (index 2) and append <|bos|> to prompt an answer
        data = data[:data.index(2)] + [voc.index("<|bos|>")]
        for _ in range(10):
            out, _ = net.one_forward(torch.Tensor(data).reshape([1, -1]).int().to("cuda"))
            data += [torch.argmax(out, -1)[:, -1].item()]
            print("".join([voc[i] for i in data]))


def show_loss():
    # each pickle holds the epoch-loss curve of one experiment variant
    losses = [pd.read_pickle("loss{}".format(i)) for i in range(9)]
    for i, loss in enumerate(losses):
        if i == 7:
            loss = np.array(loss) / 3  # loss7 is rescaled by 1/3 before plotting
        plt.plot(loss)
    plt.legend(["sin", "+", "nor", "state", "pos", "s", "mem", "mm", "8"])
    plt.show()


if __name__ == '__main__':
    show_loss()
    # train()
    # val()
    # eval_data()

