政务问答系统模型动转静,插入milvus.ipynb

发布于:2024-06-13 ⋅ 阅读:(131) ⋅ 点赞:(0)

import os
import paddle
from paddlenlp.transformers import AutoModel, AutoTokenizer
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from paddle import inference
from tqdm import tqdm
from paddlenlp.data import Pad, Tuple

class SimCSE(nn.Layer):
    def __init__(self, pretrained_model, dropout=None, margin=0.0, scale=20, output_emb_size=None):
        super().__init__()
        self.ptm = pretrained_model#预训练模型
        #dropout is not None和dropout是不一样的,dropout=0.时,dropout是False,dropout is not None是True
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        self.output_emb_size = output_emb_size
        if output_emb_size > 0:#如果output_emb_size>0,线性转换
            weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            self.emb_reduce_linear = paddle.nn.Linear(768, output_emb_size, weight_attr=weight_attr)
        self.margin = margin
        self.scale = scale

    @paddle.jit.to_static(
        input_spec=[
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),
        ]
    )
    def get_pooled_embedding(
        self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, with_pooler=True
    ):
        # Note: cls_embedding is poolerd embedding with act tanh
        sequence_output, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids, attention_mask)
        if with_pooler is False:#如果ptm不返回池化层,把[CLS]输出作为池化输出
            cls_embedding = sequence_output[:, 0, :]
        if self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)#向量单位化(b,d)
        return cls_embedding
    def get_semantic_embedding(self, data_loader):
        self.eval()
        with paddle.no_grad():
            for batch_data in data_loader:
                input_ids, token_type_ids = batch_data
                text_


网站公告

今日签到

点亮在社区的每一天
去签到