NLP-transformer学习:(8)trainer 使用方法

发布于:2025-02-11 ⋅ 阅读:(98) ⋅ 点赞:(0)

NLP-transformer学习:(8)trainer 使用方法

在这里插入图片描述
11月工作996压力较大,任务完成后,目前休息了一个月,2025年新的一天继续开始补基础。
本章节是单独的 NLP-transformer学习 章节,主要实践了evaluate。同时,最近将学习代码传到:https://github.com/MexWayne/mexwayne_transformers-code,
作者的代码版本有些细节我发现到目前不能完全行的通,为了尊重原作者,我这里保持了大部分的内容,并标明了来源,欢迎大家一起学习。


一、整体代码

这里没什么好讲的说实话,我这里将整体代码附上:

# import the related package
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

import torch
import evaluate
from transformers import DataCollatorWithPadding


def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples


def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc


if __name__ == "__main__":
    # download the dataset
    dataset = load_dataset("csv", data_files="/home/mex/Desktop/learn_transformer/mexwayne_transformers_NLP/01-Getting_Started/07-trainer/ChnSentiCorp_htl_all.csv", split="train")
    dataset = dataset.filter(lambda x: x["review"] is not None)
    print("load dataset:")
    print(dataset)

    # split the dataset into 0.1 and 0.9, the 0.1 for test_dataset
    datasets = dataset.train_test_split(test_size=0.1)
    print("split dataset:")
    print(dataset)

    # build tokenize the data
    tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
    tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
    print(tokenized_datasets)

    # load the model
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    print("model config:")
    print(model.config)

    # build the evaluate module
    acc_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")


    # build train args 
    train_args = TrainingArguments(output_dir="./checkpoints",  # train model output path 
                               per_device_train_batch_size=64,  # train batch_size
                               per_device_eval_batch_size=128,  # test batch_size
                               logging_steps=10,                # log 
                               eval_strategy="epoch",     # evaluation strategy
                               save_strategy="epoch",           # save the model every epoch 
                               save_total_limit=3,              # only keep 3 model save
                               learning_rate=2e-5,              #  
                               weight_decay=0.01,               #  
                               metric_for_best_model="f1",      #  
                               load_best_model_at_end=True)     # save the best model after train 
    print("train_args")
    print(train_args)


    # build trainer
    trainer = Trainer(model=model, 
                      args=train_args, 
                      train_dataset=tokenized_datasets["train"], 
                      eval_dataset=tokenized_datasets["test"], 
                      data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                      compute_metrics=eval_metric)
    
    # start trainnig
    trainer.train()

    # evaluation
    trainer.evaluate(tokenized_datasets["test"])
    #####################################################


    #####################################################
    # try one new, to predict a results
    trainer.predict(tokenized_datasets["test"])
    from transformers import pipeline
    
    model.config.id2label = id2_label
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
    sen = "我觉得还是蛮不错的哈!"
    print(pipe(sen))

其中依赖的 csv 在我的github 仓库下:
https://github.com/MexWayne/mexwayne_transformers-code/blob/master/01-Getting_Started/07-trainer/ChnSentiCorp_htl_all.csv

当你正确的配置环境后可以看到数据被正确加载
在这里插入图片描述
在这里插入图片描述
以及相关的trainer 可以训练数据

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、遇到的问题

问题1:如下图

在这里插入图片描述
遇到这个问题
conda install accelerate 不解决问题
真正的解决方案:

要用最新的
pip install -U accelerate
pip install -U transformers

问题2:如下图

注意这里 我用了 一个 库
在这里插入图片描述
要手动键入,1 和 2 都不行,3可以,这是最终结果
在这里插入图片描述

问题3:如下图

笔者之前一直能提交的仓库突然不行了
在这里插入图片描述
后来查了下,用 ssh -T git@github.com 可以次测试通断
在这里插入图片描述
发现笔者的 22port 用不了
后来采取如下方法在这里插入图片描述
当需要恢复22port 时,只需要删除这个config 即可


网站公告

今日签到

点亮在社区的每一天
去签到