python-pytorch seq2seq + Luong dot attention notes 1.0.0

Published: 2024-05-20

Reusable part

The main task is to arrange the data into the following format:
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你是张学友吗"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "肯定不是"]

At the same time, set embedding_size, vocab_size, hidden_size, and seq_len, and implement word2index, index2word, encoder_input, decoder_input, and target_input.

The code is as follows:

# Alternative: load question/answer pairs from a "question----answer" file
# def getAQ():
#     ask = []
#     answer = []
#     with open("./data/flink.txt", "r", encoding="utf-8") as f:
#         for line in f.readlines():
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n", ""))
#     return answer, ask
#
# seq_answer, seq_example = getAQ()



import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdm
 
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你有父母吗"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "我没有父母"]



# Tokenize and collect the vocabulary
example_cut = []
answer_cut = []
word_all = []
# Word segmentation with jieba
for i in seq_example:
    example_cut.append(list(jieba.cut(i)))
for i in seq_answer:
    answer_cut.append(list(jieba.cut(i)))
# Collect every distinct word
for i in example_cut + answer_cut:
    for word in i:
        if word not in word_all:
            word_all.append(word)
# Word-to-index table (indices 0-2 are reserved for special tokens)
word2index = {w: i + 3 for i, w in enumerate(word_all)}
# Padding token
word2index['PAD'] = 0
# Start-of-sentence token
word2index['SOS'] = 1
# End-of-sentence token
word2index['EOS'] = 2
index2word = {value: key for key, value in word2index.items()}
# A few parameters
vocab_size = len(word2index)
# +1 leaves room for the SOS/EOS token added to each sequence
seq_length = max([len(i) for i in example_cut + answer_cut]) + 1
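The code above stops at the vocabulary, but the encoder_input, decoder_input, and target_input promised earlier are not shown. Below is a minimal, self-contained sketch of one common way to build them: pad everything to seq_length with PAD, prepend SOS to the decoder input (for teacher forcing), and append EOS to the target. The pre-tokenized toy data and the helper name `pad_seq` are my own illustration, not from the original note:

```python
# Tiny pre-tokenized stand-ins for example_cut / answer_cut
example_cut = [['你', '认识', '我', '吗'], ['你', '是', '谁']]
answer_cut = [['当然', '认识'], ['我', '是', '机器人']]

# Same vocabulary scheme as above: 0-2 reserved for special tokens
word2index = {'PAD': 0, 'SOS': 1, 'EOS': 2}
for sent in example_cut + answer_cut:
    for w in sent:
        word2index.setdefault(w, len(word2index))

# +1 leaves room for the SOS/EOS token
seq_length = max(len(s) for s in example_cut + answer_cut) + 1

def pad_seq(tokens, word2index, seq_length):
    # Map tokens to indices, then pad with PAD up to seq_length
    ids = [word2index[w] for w in tokens]
    return ids + [word2index['PAD']] * (seq_length - len(ids))

# Encoder reads the question as-is
encoder_input = [pad_seq(s, word2index, seq_length) for s in example_cut]
# Decoder input starts with SOS (teacher forcing)...
decoder_input = [pad_seq(['SOS'] + s, word2index, seq_length) for s in answer_cut]
# ...and the target is the same answer shifted left, ending with EOS
target_input = [pad_seq(s + ['EOS'], word2index, seq_length) for s in answer_cut]
```

In the note's actual pipeline these lists would come from the jieba output above, and each would be wrapped in `torch.tensor(...)` before training.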

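The Luong dot attention from the title does not appear in this reusable part. For completeness, here is a minimal sketch of the dot-score variant, where score(h_t, h_s) = h_t · h_s, implemented with plain torch ops; the function name and tensor shapes are my assumptions, not code from the original note:

```python
import torch
import torch.nn.functional as F

def luong_dot_attention(decoder_hidden, encoder_outputs):
    # decoder_hidden:  [batch, hidden]           current decoder state h_t
    # encoder_outputs: [batch, seq_len, hidden]  all encoder states h_s
    # Dot score: score(h_t, h_s) = h_t . h_s for every source position
    scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2))  # [batch, seq_len, 1]
    weights = F.softmax(scores.squeeze(2), dim=1)                     # [batch, seq_len]
    # Context vector: attention-weighted sum of encoder outputs
    context = torch.bmm(weights.unsqueeze(1), encoder_outputs)        # [batch, 1, hidden]
    return context.squeeze(1), weights

# Toy shapes: batch=2, seq_len=4, hidden=8
enc = torch.randn(2, 4, 8)
dec = torch.randn(2, 8)
context, weights = luong_dot_attention(dec, enc)
```

The context vector is then typically concatenated with h_t and passed through a linear layer plus tanh to form the attentional hidden state used for prediction.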