import torch
import numpy as np
class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
        self.head_size = hidden_dim // heads
        self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads
        self.win = win
        self.hidden = hidden_dim
        # Upper-triangular mask: window column c may only see rows r <= c.
        self.mask = torch.triu(torch.ones([win, win])).to(device)
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win
        window = torch.ones([1, w]).to(device)
        out = self.head(input_data)
        # Broadcast every position across the window: [b, s, hidden] -> [b, hidden, s, w].
        out = out.unsqueeze(-1) @ window
        out = out.permute([0, 2, 1, 3])
        one_list = []
        if state is None:
            # -inf carry so that the first window's max is unaffected by it.
            state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
            state = state.to(device)
        for i in range(0, s, w):
            j = w + i
            one = out[:, :, i:j]
            _, _, r, c = one.shape
            # Mask future rows, then max over the chunk axis: window column c
            # ends up holding max over rows r <= c, i.e. a running maximum.
            if r != self.win:
                one = torch.where(self.mask[:r, :] == 1, one, torch.tensor(float("-inf")).to(device))
            else:
                one = torch.where(self.mask == 1, one, torch.tensor(float("-inf")).to(device))
            if i == 0:
                one = torch.concat([one, state @ window], axis=2)
                state, _ = torch.max(one, axis=2, keepdim=True)
            else:
                state1, _ = torch.max(one, axis=2, keepdim=True)
                # Project the new carry before adding it to the previous one.
                state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))
                state = state1.permute([0, 2, 1]).unsqueeze(-2) + state
                one = torch.concat([one, state], axis=2)
                state, _ = torch.max(one, axis=2, keepdim=True)
            one = state.reshape([b, k, h, w])
            # Keep only the last window column as the carry for the next chunk.
            state = state[..., -1:]
            if r != self.win:
                one = one[..., :r]
            one = one.permute([0, 3, 1, 2])
            one_list.append(one)
        out = torch.concat(one_list, 1)
        out = out.reshape([b, s, -1])
        return out, state
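Within the first window, the masked chunk-max above reduces to a plain per-channel cumulative max of self.head(x); the learned self.state projection only enters from the second window on. A quick check (a sketch, not part of the original; it assumes the module-level device defined further down):

def _check_first_window_cummax():
    # One full window (s == win == 8), so the -inf carry never contributes
    # and self.state is never applied.
    m = MaxState(hidden_dim=16, heads=4, win=8).to(device)
    x = torch.randn(2, 8, 16, device=device)
    out, _ = m(x)
    ref, _ = torch.cummax(m.head(x), dim=1)
    assert torch.allclose(out, ref)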
class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()
        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # ReLU-gated feed-forward: ffn2(ffn1(x) * relu(gate(x))).
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))
        x = x1 * x2
        x = self.ffn2(x)
        return x
class MemoryBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)
        # Fixed identity added to the learned matrix, so the block starts out
        # close to the identity map.
        self.mem = torch.eye(hidden_dim).to(device)

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        return x
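Since mem is the identity, the block is just a residual linear map: MemoryBlock(x) == x + x @ fc. MemoryBlock is not wired into SamOut below; it appears to correspond to the "mem" variant in show_loss()'s legend. A quick check (a sketch, not part of the original):

def _check_memory_block_residual():
    blk = MemoryBlock(8).to(device)
    x = torch.randn(2, 3, 8, device=device)
    # x @ (fc + I) == x @ fc + x
    assert torch.allclose(blk(x), x + x @ blk.fc)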
class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        # Residual around the FFN only; the residual around the whole layer
        # is added by the caller (SamOut.state_forward).
        x = self.layer_norm(self.ffn(x1) + x)
        return x, state
class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)
        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        self.layer_nor = torch.nn.LayerNorm(hidden_size)
        # One projection per layer to fold the positional encoding back in.
        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def state_forward(self, state, pos, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        for i, decoder_layer in enumerate(self.decoder_layers):
            # Broadcast pos to the batch, concatenate with the hidden states,
            # and project back down to hidden_size before every layer.
            x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x  # residual connection around the whole decoder layer
        return x, state

    def pos_forward(self, x):
        if x.shape[1] >= 1024:
            # Two-level positional code for sequences past the table size.
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos
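    # Worked example for pos_forward (illustration, not in the original):
    # position 1500 -> self.pos(1500 // 1024) + self.pos(1500 % 1024)
    #               == self.pos(1) + self.pos(476),
    # so positions beyond the 1024-entry table still get distinct encodings.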
    def forward(self, x0, x1, x2):
        # Three independent passes through the same weights, one per stride
        # (see the index construction in the training script below); states
        # are discarded because training always sees full sequences.
        x0, _ = self.one_forward(x0)
        x1, _ = self.one_forward(x1)
        x2, _ = self.one_forward(x2)
        return x0, x1, x2

    def one_forward(self, x, state=None, seq_len=None):
        x = self.em(x)
        pos = self.pos_forward(x)
        x, state = self.state_forward(state, pos, x)
        return self.head(x), state
device = "cuda"
if __name__ == '__main__':
net = SamOut(235, 256, 16, 4)
net.to(device)
net(torch.randint(0, 200, [2, 8 * 13]).to(device), torch.randint(0, 200, [2, 4 * 13]).to(device),
torch.randint(0, 200, [2, 2 * 13]).to(device))
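For completeness, a minimal greedy-decoding sketch (not part of the original file; it assumes the classes and device defined above). The training script's eval_data() below performs the same full-prefix recompute at every step:

def greedy_decode(net, ids, steps=10):
    # ids: a list of token ids; the whole prefix is re-encoded at each step.
    for _ in range(steps):
        logits, _ = net.one_forward(torch.tensor([ids]).to(device))
        ids.append(int(logits[0, -1].argmax(-1)))
    return ids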
Training
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob
from tqdm import tqdm
from model_d import SamOut
import json
def train():
    train_data, test_data, voc = gen_voc()
    net = SamOut(len(voc), 512, 32, 8)
    net.to("cuda")
    opt = torch.optim.Adam(params=net.parameters(), lr=0.00003)
    # 3 == voc.index("<|pos|>"), the padding id, so padding never contributes.
    loss_func = torch.nn.CrossEntropyLoss(ignore_index=3)
    bar = tqdm(range(100))
    steps = 0
    epoch_loss = []
    label_index0 = np.array([i for i in range(1, 256) if i % 2 == 1])
    input_index0 = label_index0 - 1
    label_index1 = label_index0[::2]
    input_index1 = label_index1 - 1
    label_index2 = label_index1[::2]
    input_index2 = label_index2 - 1
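    # Illustration (not in the original): the three index sets read the same
    # 256-token row at strides 2, 4 and 8, each input predicting the token
    # right after it:
    #   input_index0 = [0, 2, 4, ...]  -> label_index0 = [1, 3, 5, ...]
    #   input_index1 = [0, 4, 8, ...]  -> label_index1 = [1, 5, 9, ...]
    #   input_index2 = [0, 8, 16, ...] -> label_index2 = [1, 9, 17, ...]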
    for epoch in bar:
        np.random.shuffle(train_data)
        loss_list = []
        for i in range(0, len(train_data), 100):
            data = torch.tensor(train_data[i:i + 100]).to("cuda")
            out0, out1, out2 = net(data[:, input_index0].int(),
                                   data[:, input_index1].int(),
                                   data[:, input_index2].int())
            # Average the three per-stride losses (hence the / 3).
            loss = loss_func(out0.reshape([-1, out0.shape[-1]]), data[:, label_index0].reshape([-1]).long()) / 3
            loss += loss_func(out1.reshape([-1, out1.shape[-1]]), data[:, label_index1].reshape([-1]).long()) / 3
            loss += loss_func(out2.reshape([-1, out2.shape[-1]]), data[:, label_index2].reshape([-1]).long()) / 3
            loss_list.append(loss.item())
            bar.set_description("epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
            opt.zero_grad()
            loss.backward()
            opt.step()
            steps += 100
            if steps % 8000 == 0:
                torch.save(net.state_dict(), "model.pth")
        epoch_loss.append(np.mean(loss_list))
        pd.to_pickle(epoch_loss, "loss8")
def val():
    train_data, test_data, voc = gen_voc()
    # Must match the configuration used in train() for the checkpoint to load.
    net = SamOut(len(voc), 512, 32, 8)
    net.to("cuda")
    net.load_state_dict(torch.load("model.pth"))
    net.eval()
    loss_func = torch.nn.CrossEntropyLoss(ignore_index=3)
    bar = tqdm(range(1))
    steps = 0
    for epoch in bar:
        np.random.shuffle(test_data)
        for i in range(0, len(test_data), 100):
            data = torch.tensor(test_data[i:i + 100]).to("cuda")
            with torch.no_grad():
                # Single-stream entry point: SamOut.forward expects three
                # streams, so validation goes through one_forward instead.
                out, _ = net.one_forward(data[:, :-1].int())
            loss = loss_func(out.reshape([-1, out.shape[-1]]), data[:, 1:].reshape([-1]).long())
            bar.set_description("epoch___{}____loss___{:.6f}____steps___{}".format(epoch, loss.item(), steps))
def gen_voc():
    paths = glob("train/*") + glob("test/*")
    q_list = []
    voc_dict = set()
    len_list = []
    for path in tqdm(paths):
        with open(path, "r", encoding="utf-8") as f:
            data = f.read()
        data = json.loads(data)
        if "answer" not in data:
            q = list(data["question"])
        else:
            q = list(data["question"]) + ["<|bos|>"] + list(data["answer"])
        voc_dict.update(set(q))
        q = ["<|sos|>"] + q + ["<|eos|>"]
        len_list.append(len(q))
        q_list.append(q)
    # Index 3 is "<|pos|>", matching padding_idx=3 in SamOut and
    # ignore_index=3 in the loss.
    voc = ["<|sss|>", "<|sos|>", "<|eos|>", "<|pos|>"] + sorted(voc_dict)
    add_voc = ["<|pos_{}|>".format(i) for i in range(18)]
    voc += add_voc
    # Build len(voc) + 6 distinct shuffles of the 18 <|pos_k|> tokens; the
    # while loop reshuffles until the current ordering has not been seen.
    add_voc_list = []
    for _ in tqdm(range(len(voc) + 6)):
        while add_voc in add_voc_list:
            np.random.shuffle(add_voc)
        add_voc_list.append(np.array(add_voc).copy().tolist())
    # Each token (from index 2 on) gets its own fixed shuffle, used below as
    # a padding prefix that depends on the first content token.
    padding_str = {v: i for i, v in zip(add_voc_list, voc[2:])}
    train_list = []
    test_list = []
    for i in tqdm(q_list):
        if len(i) < 256 - 18:
            if "<|bos|>" in i:
                # Prepend enough <|pos_k|> tokens to round the length up to a
                # multiple of 16, then right-pad to 256 with "<|pos|>".
                o = padding_str.get(i[1])[len(i) % 16 - 16:] + i
                train_list.append([voc.index(j) for j in o] + (256 - len(o)) * [voc.index("<|pos|>")])
            else:
                test_list.append([voc.index(j) for j in i] + (256 - len(i)) * [voc.index("<|pos|>")])
    return train_list, test_list, voc
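A worked example of the alignment trick above (a sketch, not in the original; pad_count is a hypothetical helper): the negative slice start len(i) % 16 - 16 always selects between 1 and 16 prefix tokens, rounding the sequence length up to a multiple of 16 so the stride-2/4/8 indexing in train() lines up.

def pad_count(n):
    # Number of <|pos_k|> tokens prepended to a sequence of length n by the
    # slice padding_str[...][n % 16 - 16:].
    return 16 - n % 16 if n % 16 else 16

assert pad_count(21) == 11   # 21 + 11 == 32, a multiple of 16
assert pad_count(32) == 16   # an extra full block when already aligned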
def eval_data():
    train_data, test_data, voc = gen_voc()
    # Must match the configuration used in train() for the checkpoint to load.
    net = SamOut(len(voc), 512, 32, 8)
    net.to("cuda")
    net.load_state_dict(torch.load("model.pth"))
    net.eval()
    for data in test_data:
        # Cut at "<|eos|>" (id 2) and append "<|bos|>" to prompt an answer.
        data = data[:data.index(2)] + [voc.index("<|bos|>")]
        for _ in range(10):
            with torch.no_grad():
                out, _ = net.one_forward(torch.tensor(data).reshape([1, -1]).int().to("cuda"))
            data += [torch.argmax(out, -1)[:, -1].item()]
        print("".join([voc[i] for i in data]))
def show_loss():
    # Per-epoch mean losses from the experiment variants named in the legend.
    for name in ["loss0", "loss1", "loss2", "loss3", "loss4", "loss5", "loss6"]:
        plt.plot(pd.read_pickle(name))
    plt.plot(np.array(pd.read_pickle("loss7")) / 3)  # rescaled by 1/3 for comparability
    plt.plot(np.array(pd.read_pickle("loss8")))
    plt.legend(["sin", "+", "nor", "state", "pos", "s", "mem", "mm", "8"])
    plt.show()
if __name__ == '__main__':
    show_loss()