NLP - 使用 transformers 翻译

发布于:2024-04-27 ⋅ 阅读:(24) ⋅ 点赞:(0)
from transformers import AutoTokenizer

# Load the tokenizer published with the 'Helsinki-NLP/opus-mt-en-ro' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro',
                                          use_fast=True)

print(tokenizer)

# Trial encoding. The batch holds a single element made of two sentences;
# the recorded output below shows one `input_ids` row for both of them.
tokenizer.batch_encode_plus(
    [['Hello, this one sentence!', 'This is another sentence.']])

PreTrainedTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})
{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

from datasets import load_dataset, load_from_disk

# Load the WMT16 Romanian-English dataset.
dataset = load_dataset(path='wmt16', name='ro-en')
# Offline alternative: dataset = load_from_disk('datas/wmt16/ro-en')

# Subsample every split (the full dataset is too large to train on here).
# Each split is shuffled with seed 1 and truncated to a fixed number of rows.
for split_name, sample_size in (('train', 20000), ('validation', 200),
                                ('test', 200)):
    dataset[split_name] = dataset[split_name].shuffle(1).select(
        range(sample_size))


# Data preprocessing
def preprocess_function(data):
    """Tokenize a batch of translation pairs.

    Args:
        data: a batch whose 'translation' entry is a list of dicts with
            'en' (source text) and 'ro' (target text) keys.

    Returns:
        The tokenizer output for the English sentences, with an extra
        'labels' entry holding the token ids of the Romanian sentences.
        Both sides are truncated to at most 128 tokens.
    """
    pairs = data['translation']
    sources = [pair['en'] for pair in pairs]
    targets = [pair['ro'] for pair in pairs]

    # The source language is encoded directly.
    encoded = tokenizer.batch_encode_plus(sources,
                                          max_length=128,
                                          truncation=True)

    # The target language is encoded inside the target-tokenizer context.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer.batch_encode_plus(targets,
                                                      max_length=128,
                                                      truncation=True)

    encoded['labels'] = target_encoding['input_ids']
    return encoded


# Tokenize every split in batches of 1000 examples using 4 worker processes.
# The raw 'translation' column is dropped, so only the fields returned by
# preprocess_function remain (see the recorded output below).
dataset = dataset.map(function=preprocess_function,
                      batched=True,
                      batch_size=1000,
                      num_proc=4,
                      remove_columns=['translation'])

# Show one processed training example.
print(dataset['train'][0])

# Notebook display of the processed DatasetDict.
dataset

{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})

#这个函数和下面这个工具类等价,但我也是模仿实现的,不确定有没有出入
#from transformers import DataCollatorForSeq2Seq
#DataCollatorForSeq2Seq(tokenizer, model=model)

import torch


# Batch collation function
def collate_fn(data):
    """Collate a list of tokenized examples into one padded tensor batch.

    Args:
        data: list of dicts with 'input_ids', 'attention_mask' and 'labels'
            (each a list of token ids). The dicts are NOT modified.

    Returns:
        The batch produced by ``tokenizer.pad`` (PyTorch tensors) with
        'labels' padded with -100 and an added 'decoder_input_ids' entry:
        the labels shifted one position to the right, starting with the
        pad token, and with every -100 replaced by the pad token id.
    """
    # Look the pad id up once per batch instead of rebuilding the vocab dict.
    pad_id = tokenizer.pad_token_id

    # Length of the longest label sequence in this batch.
    max_length = max(len(i['labels']) for i in data)

    # Pad every label sequence to that length with -100 (the default
    # ignore_index of CrossEntropyLoss). New dicts are built so that the
    # caller's examples are left untouched.
    features = [{
        **i,
        'labels':
        list(i['labels']) + [-100] * (max_length - len(i['labels'])),
    } for i in data]

    # Pad the encoder inputs and stack everything into tensors.
    batch = tokenizer.pad(
        encoded_inputs=features,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt',
    )

    # Build decoder_input_ids: start with pad, then the labels shifted right.
    decoder_input_ids = torch.full_like(batch['labels'],
                                        pad_id,
                                        dtype=torch.long)
    decoder_input_ids[:, 1:] = batch['labels'][:, :-1]
    # -100 is not a valid token id, so it must not reach the decoder embedding.
    decoder_input_ids[decoder_input_ids == -100] = pad_id
    batch['decoder_input_ids'] = decoder_input_ids

    return batch


# Hand-written two-example batch used to try out collate_fn. The label
# sequences have different lengths (3 and 4) so that label padding is exercised.
data = [{
    'input_ids': [21603, 10, 37, 3719, 13],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [10455, 120, 80]
}, {
    'input_ids': [21603, 10, 7086, 8408, 563],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [301, 53, 4074, 1669]
}]

# Notebook display of the resulting decoder inputs (output recorded below).
collate_fn(data)['decoder_input_ids']

tensor([[59542, 10455,   120,    80],
        [59542,   301,    53,  4074]])

import torch

# Training data loader: shuffled batches of 8 built with collate_fn; the last
# incomplete batch is dropped.
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

# Fetch only the first batch; `data` keeps it after the loop exits.
for i, data in enumerate(loader):
    break

# Print the shape and the first two rows of every tensor in that batch.
for k, v in data.items():
    print(k, v.shape, v[:2])

# Notebook display of the number of batches per epoch.
len(loader)

from transformers import AutoModelForSeq2SeqLM, MarianModel

#加载模型
#model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')


# Downstream task model
class Model(torch.nn.Module):
    """Marian encoder-decoder with a separate vocabulary projection head.

    The backbone is the bare ``MarianModel``; the output projection ``fc``
    and the ``final_logits_bias`` buffer are initialised from the full
    pretrained sequence-to-sequence checkpoint so that the initial logits
    match the pretrained translation model.
    """

    def __init__(self):
        super().__init__()
        self.pretrained = MarianModel.from_pretrained(
            'Helsinki-NLP/opus-mt-en-ro')

        # Load the full seq2seq model only to copy its output head and bias.
        parameters = AutoModelForSeq2SeqLM.from_pretrained(
            'Helsinki-NLP/opus-mt-en-ro')
        head = parameters.lm_head

        # Copy the checkpoint's logits bias instead of starting from zeros,
        # otherwise the logits differ from those of the pretrained model.
        self.register_buffer('final_logits_bias',
                             parameters.final_logits_bias.detach().clone())

        # Take the layer sizes from the pretrained head rather than
        # hard-coding them, so load_state_dict is guaranteed to fit.
        self.fc = torch.nn.Linear(head.in_features,
                                  head.out_features,
                                  bias=False)
        self.fc.load_state_dict(head.state_dict())

        # Default ignore_index is -100, the value collate_fn pads labels with.
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        """Run the model and compute the token-level cross-entropy loss.

        Args:
            input_ids: encoder token ids.
            attention_mask: encoder attention mask.
            labels: target token ids, with -100 at positions to ignore.
            decoder_input_ids: decoder inputs (labels shifted right).

        Returns:
            dict with 'loss' (scalar tensor) and 'logits' (one score per
            vocabulary entry for every decoder position).
        """
        hidden = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids)
        hidden = hidden.last_hidden_state

        logits = self.fc(hidden) + self.final_logits_bias

        # Merge the batch and sequence dimensions for the loss.
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())

        return {'loss': loss, 'logits': logits}


model = Model()

# Report the number of parameters, in units of 10,000.
param_count = sum(p.numel() for p in model.parameters())
print(param_count / 10000)

#out = model(**data)
#out['loss'], out['logits'].shape

from datasets import load_metric

# Load the evaluation metric (sacreBLEU).
metric = load_metric(path='sacrebleu')

# Trial computation: predictions are plain strings, references are a list of
# reference strings per prediction.
metric.compute(predictions=['hello there', 'general kenobi'],
               references=[['hello there'], ['general kenobi']])



测试

# Evaluation
def test():
    """Evaluate the global ``model`` on up to 11 test batches and print BLEU.

    Predictions are the teacher-forced argmax of the logits (no generation).
    References are decoded from ``labels``, not from ``decoder_input_ids``:
    the decoder inputs are the labels shifted right, so they start with a
    pad token and lack the final token. Positions padded with -100 are
    excluded from both sides and special tokens are skipped when decoding.
    """
    model.eval()

    # Test data loader
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
    )

    pad_id = tokenizer.pad_token_id

    predictions = []
    references = []
    for i, data in enumerate(loader_test):
        # Forward pass without gradient tracking
        with torch.no_grad():
            out = model(**data)

        labels = data['labels']
        ignored = labels == -100

        # Replace ignored positions by pad so they vanish when decoding.
        pred_ids = out['logits'].argmax(dim=2).masked_fill(ignored, pad_id)
        label_ids = labels.masked_fill(ignored, pad_id)

        pred = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        predictions.extend(pred)
        references.extend(label)

        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])

            print('input_ids=', input_ids)
            print('pred=', pred[0])
            print('label=', label[0])

        if i == 10:
            break

    # The metric expects a list of reference strings per prediction.
    references = [[j] for j in references]
    metric_out = metric.compute(predictions=predictions, references=references)
    print(metric_out)


# Evaluate once before train() is called further down.
test()



from transformers import AdamW
from transformers.optimization import get_scheduler


# Training
def train():
    """Fine-tune the global ``model`` for one pass over ``loader`` and save it.

    Uses AdamW (lr=2e-5) with a linear schedule without warmup and gradient
    clipping at norm 1.0. Every 50 steps it prints the step index, loss,
    token accuracy, BLEU on the current batch and the current learning rate.
    Accuracy and BLEU are computed against ``labels`` and only at positions
    that are not padded with -100. The whole model object is saved to
    'models/7.翻译.model' at the end.
    """
    import os

    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    pad_id = tokenizer.pad_token_id

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        # The optimizer holds every model parameter, so this clears all grads.
        optimizer.zero_grad()

        if i % 50 == 0:
            labels = data['labels']
            scored = labels != -100
            pred_ids = out['logits'].argmax(dim=2)

            # Position t of the logits predicts labels[t]; decoder_input_ids
            # is shifted by one and must not be used as the target here.
            correct = ((pred_ids == labels) & scored).sum().item()
            accuracy = correct / max(scored.sum().item(), 1)

            # Replace ignored positions by pad so they vanish when decoding.
            predictions = tokenizer.batch_decode(
                pred_ids.masked_fill(~scored, pad_id),
                skip_special_tokens=True)
            label_texts = tokenizer.batch_decode(
                labels.masked_fill(~scored, pad_id),
                skip_special_tokens=True)
            references = [[text] for text in label_texts]

            metric_out = metric.compute(predictions=predictions,
                                        references=references)

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), accuracy, metric_out, lr)

    # Make sure the target directory exists so the save cannot fail after
    # the whole training run has completed.
    os.makedirs('models', exist_ok=True)
    torch.save(model, 'models/7.翻译.model')


# Run the fine-tuning loop; it saves the model when it finishes.
train()



# Reload the model object written by train() and evaluate it again.
# NOTE(review): torch.load unpickles the file, so only load files you created.
model = torch.load('models/7.翻译.model')
test()