bert 相似度任务训练完整版

发布于:2024-03-04 ⋅ 阅读:(95) ⋅ 点赞:(0)

任务

之前写了一个相似度任务的版本:bert 相似度任务训练简单版本,faiss 寻找相似 topk-CSDN博客

相似度用的是 0,1,相当于分类任务,现在我们相似度有评分,不再是 0,1 了,分数为 0-5,数字越大代表两个句子越相似,这一次的比较完整,评估,验证集,相似度模型都有了。

数据集

链接:https://pan.baidu.com/s/1B1-PKAKNoT_JwMYJx_zT1g 
提取码:er1z 
原始数据好几千条,我训练数据用了部分 2500 条,验证,测试 300 左右,使用 cpu 也用了好几个小时

train.py

import torch
import os
import time
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW, get_cosine_schedule_with_warmup
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np


# 设备选择
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'


# 定义文本相似度数据集类
class TextSimilarityDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len=128):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                text1, text2, similarity_score = line.strip().split('\t')
                inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=max_len)
                inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=max_len)
                self.data.append({
                    'input_ids1': inputs1['input_ids'],
                    'attention_mask1': inputs1['attention_mask'],
                    'input_ids2': inputs2['input_ids'],
                    'attention_mask2': inputs2['attention_mask'],
                    'similarity_score': float(similarity_score),
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def cosine_similarity_torch(vec1, vec2, eps=1e-8):
    dot_product = torch.mm(vec1, vec2.t())
    norm1 = torch.norm(vec