抽取语料库的语义向量并构建 Milvus 索引库

发布于:2024-07-07 ⋅ 阅读:(122) ⋅ 点赞:(0)

import random
import time

import os
import sys
from tqdm import tqdm
import numpy as np
import paddle
from paddle import inference
from paddlenlp.transformers import AutoModel, AutoTokenizer
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.utils.log import logger

# Add the current working directory ('.') to the module search path.
sys.path.append('.')

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    """Tokenize every text in ``example`` and flatten the results.

    Args:
        example: Mapping whose values are the texts to tokenize; the keys
            are ignored.
        tokenizer: Callable accepting ``text``, ``max_seq_len`` and
            ``pad_to_max_seq_len`` keyword arguments and returning a mapping
            with ``"input_ids"`` and ``"token_type_ids"`` entries.
        max_seq_length: Forwarded to the tokenizer as ``max_seq_len``.
        pad_to_max_seq_len: Forwarded to the tokenizer unchanged.

    Returns:
        A flat list ``[input_ids, token_type_ids, ...]`` holding one pair
        per text, in the mapping's iteration order.
    """
    features = []
    for text in example.values():
        encoding = tokenizer(text=text,
                             max_seq_len=max_seq_length,
                             pad_to_max_seq_len=pad_to_max_seq_len)
        features.extend((encoding["input_ids"], encoding["token_type_ids"]))
    return features

# Settings for the embedding-extraction step.
model_dir = './output/yysy/'  # directory searched for the exported inference model files
corpus_file = './datasets/yysy/milvus/milvus_data_s.csv'  # corpus text, read line by line
max_seq_length = 64
batch_size = 64
device = 'gpu'
cpu_threads = 8
model_name_or_path = 'rocketqa-zh-base-query-encoder'  # tokenizer to load

class Predictor(object):
    """Embeds texts in batches with an exported Paddle inference model.

    The model is loaded from ``inference.get_pooled_embedding.pdmodel`` and
    ``inference.get_pooled_embedding.pdiparams`` inside ``model_dir``.
    """

    def __init__(self,
                 model_dir,
                 device="gpu",
                 max_seq_length=128,
                 batch_size=32,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        """Build the inference predictor.

        Args:
            model_dir: Directory containing the exported model files.
            device: ``"gpu"``, ``"cpu"`` or ``"xpu"``; any other value
                leaves the device configuration untouched.
            max_seq_length: Sequence length passed to the tokenizer.
            batch_size: Number of texts sent through the model at once.
            use_tensorrt: Enable the TensorRT engine (GPU only).
            precision: ``"fp16"``, ``"fp32"`` or ``"int8"``; looked up only
                when ``device`` is ``"gpu"``.
            cpu_threads: Math-library thread count (CPU only).
            enable_mkldnn: Enable MKL-DNN (CPU only).

        Raises:
            ValueError: If either model file does not exist.
            KeyError: If ``device`` is ``"gpu"`` and ``precision`` is not
                one of the supported names.
        """
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        # os.path.join yields the right path whether or not model_dir ends
        # with a separator; plain concatenation silently glued the directory
        # name onto the file name when it did not.
        model_file = os.path.join(model_dir,
                                  "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(model_dir,
                                   "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("not find model file path {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError("not find params file path {}".format(params_file))
        config = paddle.inference.Config(model_file, params_file)
        if device == "gpu":
            config.enable_use_gpu(100, 0)
            precision_map = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
                "int8": inference.PrecisionType.Int8
            }
            precision_mode = precision_map[precision]
            if use_tensorrt:
                config.enable_tensorrt_engine(max_batch_size=batch_size,
                                              min_subgraph_size=30,
                                              precision_mode=precision_mode)
        elif device == "cpu":
            config.disable_gpu()
            if enable_mkldnn:
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            config.enable_xpu(100)
        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        self.output_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])

    def _embed_batch(self, examples, batchify_fn):
        """Collate one batch, run the model, and return its output array."""
        input_ids, segment_ids = batchify_fn(examples)
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(segment_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

    def predict(self, data, tokenizer, save_path='yysy_corpus_embedding'):
        """Embed every text in ``data`` and save the stacked result.

        Args:
            data: Iterable of mappings; each one is passed to
                ``convert_example`` and is expected to yield exactly one
                ``(input_ids, token_type_ids)`` pair.
            tokenizer: Tokenizer providing ``pad_token_id`` and
                ``pad_token_type_id`` and usable by ``convert_example``.
            save_path: Destination handed to ``np.save``.

        Returns:
            The model outputs of all batches concatenated along axis 0.

        Raises:
            ValueError: If ``data`` contains no texts.
        """
        # Fail fast on an empty sized input instead of surfacing an opaque
        # np.concatenate error after the whole pipeline has been set up.
        if hasattr(data, "__len__") and len(data) == 0:
            raise ValueError("predict() received no texts to embed")
        batchify_fn = Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id,
                dtype="int64"),  # segment
        )
        all_embeddings = []
        examples = []
        for text in tqdm(data):
            input_ids, segment_ids = convert_example(
                text,
                tokenizer,
                max_seq_length=self.max_seq_length,
                pad_to_max_seq_len=True)
            examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
                all_embeddings.append(self._embed_batch(examples, batchify_fn))
                examples = []
        # Flush the final, possibly smaller, batch.
        if examples:
            all_embeddings.append(self._embed_batch(examples, batchify_fn))
        if not all_embeddings:
            # Reached only for unsized iterables that turned out to be empty.
            raise ValueError("predict() received no texts to embed")
        all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save(save_path, all_embeddings)
        return all_embeddings
def read_text(file_path):
    """Read a text file into a dict keyed by zero-based line index.

    Args:
        file_path: Path of the file to read; it is opened in text mode with
            the platform default encoding.

    Returns:
        A dict mapping each line's index to that line with leading and
        trailing whitespace stripped.
    """
    # The context manager closes the handle even if reading raises.
    with open(file_path) as file:
        return {idx: line.strip() for idx, line in enumerate(file)}

# Module-level statements must start at column 0; a stray leading space here
# is an IndentationError.
predictor = Predictor(model_dir, device, max_seq_length, batch_size)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
id2corpus = read_text(corpus_file)

# One single-entry dict per corpus line: the texts used to build the index library.
corpus_list = [{idx: text} for idx, text in id2corpus.items()]

# Embeds every text; predict() writes the result to disk with np.save.
predictor.predict(corpus_list, tokenizer)

# Milvus server address.
MILVUS_HOST = '127.0.0.1'
MILVUS_PORT = 19530

# Collection layout and naming.
data_dim = 256
top_k = 20
collection_name = 'literature_search'
partition_tag = 'partition_1'
embedding_name = 'embeddings'

# Index built on the vector field.
index_config = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=1000),
)

# Query-time parameters; note that nprobe is taken from top_k.
search_params = dict(
    metric_type="L2",
    params=dict(nprobe=top_k),
)

from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)

# Banner template used for progress messages.
fmt = "\n=== {:30} ===\n"
# Declared max_length of the VARCHAR "text" field.
text_max_len = 1000

# Collection layout: id, raw text, embedding vector.
fields = [
    FieldSchema(
        name="pk",
        dtype=DataType.INT64,
        is_primary=True,  # primary key
        auto_id=False,  # ids are not auto-generated
        max_length=100,
    ),  # id
    FieldSchema(
        name="text",
        dtype=DataType.VARCHAR,
        max_length=text_max_len,
    ),  # text
    FieldSchema(
        name="embeddings",
        dtype=DataType.FLOAT_VECTOR,
        dim=data_dim,
    ),  # embedding
]
schema = CollectionSchema(fields, "Neural Search Index")

class VecToMilvus():
    """Writes semantic vectors into a Milvus collection.

    Every method is best-effort: exceptions are caught and reported with
    ``print`` instead of being propagated to the caller.
    """

    def __init__(self):
        """Connect to the Milvus server at MILVUS_HOST:MILVUS_PORT."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    def has_collection(self, collection_name):
        """Return whether ``collection_name`` exists; False if the check fails."""
        try:
            has = utility.has_collection(collection_name)
            print(f"Does collection {collection_name} exist in Milvus: {has}")
            return has
        except Exception as e:
            print("Milvus has_table error:", e)
            # Explicit falsy result instead of falling off the end with None.
            return False

    def creat_collection(self, collection_name):
        """Create ``collection_name`` with the module-level ``schema``."""
        try:
            print(fmt.format("Create collection {}".format(collection_name)))
            self.collection = Collection(collection_name,
                                         schema,
                                         consistency_level="Strong")
        except Exception as e:
            print("Milvus create collection error:", e)

    def drop_collection(self, collection_name):
        """Drop ``collection_name``."""
        try:
            utility.drop_collection(collection_name)
        except Exception as e:
            print("Milvus delete collection error:", e)

    def create_index(self, index_name):
        """Build an index on field ``index_name`` and load the collection."""
        try:
            print(fmt.format("Start Creating index"))
            self.collection.create_index(index_name, index_config)
            print(fmt.format("Start loading"))
            self.collection.load()
        except Exception as e:
            print("Milvus create index error:", e)

    def has_partition(self, partition_tag):
        """Return whether the partition exists; False if the check fails."""
        try:
            result = self.collection.has_partition(partition_tag)
            return result
        except Exception as e:
            print("Milvus has partition error: ", e)
            # Explicit falsy result instead of falling off the end with None.
            return False

    def create_partition(self, partition_tag):
        """Create partition ``partition_tag`` in the current collection."""
        try:
            self.collection.create_partition(partition_tag)
            print('create partition {} successfully'.format(partition_tag))
        except Exception as e:
            print('Milvus create partition error: ', e)

    def insert(self, entities, collection_name, index_name, partition_tag=None):
        """Insert ``entities`` into ``collection_name``.

        The collection and its index are created first when the collection
        does not exist yet, and the partition is created when
        ``partition_tag`` is given but missing.

        Args:
            entities: Column data handed to ``Collection.insert``.
            collection_name: Target collection.
            index_name: Field to index when the collection is created.
            partition_tag: Optional partition to insert into.
        """
        try:
            if not self.has_collection(collection_name):
                self.creat_collection(collection_name)
                self.create_index(index_name)
            else:
                self.collection = Collection(collection_name)
            if (partition_tag
                    is not None) and (not self.has_partition(partition_tag)):
                self.create_partition(partition_tag)
            self.collection.insert(entities, partition_name=partition_tag)
            print(
                f"Number of entities in Milvus: {self.collection.num_entities}"
            )  # check the num_entites
        except Exception as e:
            print("Milvus insert error:", e)

class RecallByMilvus():
    """Recalls vectors from a Milvus collection.

    Failures are caught and reported with ``print`` instead of being
    propagated to the caller.
    """

    def __init__(self):
        """Connect to the Milvus server at MILVUS_HOST:MILVUS_PORT."""
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None

    def get_collection(self, collection_name):
        """Open ``collection_name`` and keep it as ``self.collection``."""
        try:
            print(fmt.format("Connect collection {}".format(collection_name)))
            self.collection = Collection(collection_name)
        except Exception as e:
            # This method only opens a collection, so the message no longer
            # claims that a creation failed.
            print("Milvus get collection error:", e)

    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=None,
               output_fields=None):
        """Search ``collection_name`` for the nearest neighbours of ``vectors``.

        Args:
            vectors: Query vectors.
            embedding_name: Name of the vector field to search.
            collection_name: Collection to search.
            partition_names: Partitions to search; None is treated as an
                empty list.
            output_fields: Fields to return with each hit; None is treated
                as an empty list.

        Returns:
            The result of ``Collection.search`` limited to ``top_k`` hits,
            or None if the search raised.
        """
        # Avoid mutable default arguments; normalise None to fresh lists so
        # the values passed to Milvus match the previous defaults.
        partition_names = [] if partition_names is None else partition_names
        output_fields = [] if output_fields is None else output_fields
        try:
            self.get_collection(collection_name)
            result = self.collection.search(vectors,
                                            embedding_name,
                                            search_params,
                                            limit=top_k,
                                            partition_names=partition_names,
                                            output_fields=output_fields)
            return result
        except Exception as e:
            print('Milvus recall error: ', e)
            return None

# Settings for the Milvus insertion and search steps.
data_path = './datasets/yysy/milvus/milvus_data_s.csv'  # corpus text file

embedding_path = './yysy_corpus_embedding.npy'  # saved embedding matrix
index = 18  # row of the embedding matrix used as the demo query
batch_size = 5000  # rows per insert call

def read_text(file_path):
    """Read a text file into a list of whitespace-stripped lines.

    Args:
        file_path: Path of the file to read; it is opened in text mode with
            the platform default encoding.

    Returns:
        A list with one entry per line, in file order, each stripped of
        leading and trailing whitespace.
    """
    # The context manager closes the handle even if reading raises.
    with open(file_path) as file:
        return [line.strip() for line in file]

corpus_list_embed = read_text(data_path)

corpus_list_embed[:5]  # notebook-style peek; no effect when run as a script

embeddings = np.load(embedding_path)

embedding_ids = list(range(embeddings.shape[0]))  # ids 0..N-1, one per embedding row

client = VecToMilvus()

# Module-level statements must start at column 0; a stray leading space here
# is an IndentationError.
client.has_collection(collection_name)

# Removes the collection named collection_name before the inserts that follow.
client.drop_collection(collection_name)

data_size = len(embedding_ids)

# Inspect the truncated length of entries 10000..14999. Slicing, unlike
# indexing each position, does not raise when the corpus is shorter.
x = [text[:1000] for text in corpus_list_embed[10000:15000]]

max((len(i) for i in x), default=0)

# Preview the batch start offsets.
for i in range(0, data_size, batch_size):
    print(i)

# Insert the corpus in chunks; i takes the values 0, batch_size, 2*batch_size, ...
for i in range(0, data_size, batch_size):
    cur_end = min(i + batch_size, data_size)  # keep the last chunk within bounds
    batch_emb = embeddings[i:cur_end]  # embedding rows of this chunk
    # NOTE(review): texts are truncated by character count; confirm whether
    # Milvus measures VARCHAR max_length in bytes, in which case multi-byte
    # text could still exceed the declared limit.
    entities = [
        list(range(i, cur_end)),  # ids
        [corpus_list_embed[j][:text_max_len - 1] for j in range(i, cur_end)],  # texts
        batch_emb,  # embeddings
    ]
    client.insert(collection_name=collection_name,
                  entities=entities,
                  index_name=embedding_name,
                  partition_tag=partition_tag)

recall_client = RecallByMilvus()

# Keep only row `index` of the matrix as a single-row query batch.
embeddings = embeddings[np.arange(index, index + 1)]

time_start = time.time()  # start
result = recall_client.search(embeddings,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.time()  # end

sum_t = time_end - time_start
print('time cost', sum_t, 's')

# search() reports failures by printing and yields None in that case, so
# guard the iteration instead of raising TypeError on None.
if result is not None:
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, text field: {hit.entity.get('text')}")


网站公告

今日签到

点亮在社区的每一天
去签到