昇思25天学习打卡营第13天 |昇思MindSpore 基于 MindSpore 实现 BERT 对话情绪识别

发布于:2024-08-01 ⋅ 阅读:(122) ⋅ 点赞:(0)

一、模型简介
BERT 是 Google 于 2018 年末开发并发布的一种新型语言模型,在众多自然语言处理任务中发挥重要作用。其创新点在于 pre-train 方法,即采用了 Masked Language Model 和 Next Sentence Prediction 两种方法分别捕捉词语和句子级别的表征。

二、调用库的功能介绍

  1. mindspore:提供了深度学习框架的核心功能,用于构建、训练和推理模型。
  2. mindspore.dataset:包含数据处理相关的模块,如文本处理、数据集生成和转换等。
  3. mindnlp._legacy.engine:提供了训练和评估模型的相关类和回调函数。
  4. mindnlp._legacy.metrics:用于定义和计算模型评估指标。

三、函数介绍

1. SentimentDataset

  • 参数path,表示数据集文件的路径。
  • 功能:读取指定路径的数据集文件,提取其中的标签和文本数据。
  • 例句
sentiment_dataset = SentimentDataset("data/train.tsv")

2. process_dataset 函数

  • 参数
    • source:数据集的来源。
    • tokenizer:用于文本分词的工具。
    • max_seq_len(默认值 64):序列的最大长度。
    • batch_size(默认值 32):批次大小。
    • shuffle(默认值 True):是否打乱数据集。
  • 功能:对数据集进行加载、转换、分词和批处理等预处理操作。
  • 例句
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)

3. BertForSequenceClassification.from_pretrained 函数

  • 参数
    • 'bert-base-chinese':预训练模型的名称。
    • num_labels=3:分类的类别数量。
  • 功能:从预训练模型加载并构建用于序列分类的 BERT 模型。
  • 例句
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)

4. auto_mixed_precision 函数

  • 参数
    • model:要进行混合精度处理的模型。
    • 'O1':混合精度的模式。
  • 功能:对模型进行自动混合精度操作,以提高训练速度。
  • 例句
model = auto_mixed_precision(model, 'O1')

5. CheckpointCallback

  • 参数
    • save_path='checkpoint':保存检查点的路径。
    • ckpt_name='bert_emotect':检查点的名称。
    • epochs=1:保存的间隔周期。
    • keep_checkpoint_max=2:保留的最大检查点数量。
  • 功能:在训练过程中按照指定的间隔和策略保存模型检查点。
  • 例句
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)

6. BestModelCallback

  • 参数
    • save_path='checkpoint':保存最佳模型的路径。
    • ckpt_name='bert_emotect_best':最佳模型的名称。
    • auto_load=True:是否自动加载最佳模型。
  • 功能:在训练过程中保存表现最佳的模型。
  • 例句
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

四、数据集

  1. 提供了一份已标注、经过分词预处理的机器人聊天数据集,来自百度飞桨团队。
  2. 数据由两列组成,以制表符分隔,第一列是情绪分类类别(0 表示消极,1 表示中性,2 表示积极),第二列是以空格分词的中文文本。

五、数据加载和数据预处理

  1. 新建 process_dataset 函数用于数据加载和预处理。
    • 包括数据格式转换、Tokenize 处理和 pad 操作。
    • 针对昇腾 NPU 环境,采用静态 Shape 处理。
  2. 加载预训练的 BertTokenizer ,并对训练集、验证集和测试集进行处理。

六、模型构建

  1. 通过 BertForSequenceClassification 构建情感分类的 BERT 模型。
  2. 加载预训练权重,设置情感三分类的超参数自动构建模型。
  3. 采用自动混合精度操作,实例化优化器和评价指标。
  4. 设置模型训练的权重保存策略,构建训练器并开始训练。

七、模型验证
使用验证数据集对训练好的模型进行验证,评价指标为准确率。

八、模型推理

  1. 遍历推理数据集,展示推理结果与标签。
  2. 自定义推理数据,展示模型泛化能力。