LLM-Based Intelligent Q&A System [Alibaba Cloud: Tianchi Competition]

Published: 2024-12-18

Pipeline:

1. Identify company-name entities in the question against the provided data files. Questions that contain a company name go through semantic retrieval; questions without one go through structured recall (a minimal routing sketch follows this list).

2. Structured recall: Qwen generates SQL from the question, the SQL is executed to obtain the result values, and the values together with the question are handed back to Qwen to produce the final answer.

3. Semantic retrieval: use the company name identified in step 1 and the competition's data files to locate the corresponding prospectus file; split the file into N text chunks; turn each chunk into a vector with the Qwen tokenizer (the code below uses capped token-count vectors) and do the same for the question; rank the chunks by cosine similarity, take the top 5, concatenate them into one text T, and prompt Qwen with the question plus T to generate the answer.
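A minimal sketch of this routing, with hypothetical placeholder helpers (the real implementations are the three scripts listed under "Key source code" below):

# Illustrative routing for step 1; both helpers are placeholders for the scripts below.
def structured_recall(question: str) -> str:
    """Step 2: Qwen writes SQL, the SQL is executed, Qwen phrases the final answer."""
    return f"[SQL path] {question}"

def semantic_retrieval(question: str, company: str) -> str:
    """Step 3: retrieve the top-5 prospectus chunks and ask Qwen to answer from them."""
    return f"[RAG path] {question} ({company})"

def route(question: str, company_names: list[str]) -> str:
    matched = [name for name in company_names if name in question]
    if matched:
        return semantic_retrieval(question, matched[0])
    return structured_recall(question)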

Future optimization directions include, but are not limited to:

Improve recall: for both structured recall and semantic retrieval.

Improve accuracy: mainly for semantic retrieval; optimize the prompt and normalize both the question and the retrieved text.

Model fine-tuning: SQL generation and vector generation could use fine-tuned models.

Model switching: Qwen2.5 7B is used at present; larger models or finance-specific models could be tried.

Scores: overall: 78.49

Structured recall: 89.05

Semantic retrieval: 62.65

Rank: 31/3502

Notes:

Source code for this post: https://download.csdn.net/download/love254443233/90106437

Baseline reference: team 大模型说的队 (source code: FinQwen) — Tongyi-EconML/FinQwen, "a project dedicated to building an open, stable, high-quality financial large model, building an intelligent Q&A system for financial scenarios on top of LLMs, and promoting 'AI + finance' through open source": https://github.com/Tongyi-EconML/FinQwen

Key source code:

Entity extraction:

import csv

import pandas as pd
from modelscope import AutoTokenizer

model_dir = '/data/nfs/baozhi/models/Qwen-7B-Chat'

# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

# Load the classified questions and the PDF-to-company mapping.
new_question_file = pd.read_csv('intermediate/A01_question_classify.csv', delimiter=',', header=0)
company_file = pd.read_csv('files/AF0_pdf_to_company.csv', delimiter=',', header=0)

# Pre-tokenize every company name once so the per-question loop stays cheap.
company_data_csv_list = list()
company_index_list = list()
company_name_list = list()
for cyc in range(len(company_file)):
    company_name_list.append(company_file['公司名称'][cyc])
    company_data_csv_list.append(company_file['csv文件名'][cyc])
    company_index_list.append(tokenizer(company_file['公司名称'][cyc])['input_ids'])
    
    
    
g = open('intermediate/A02_question_classify_entity.csv', 'w', newline='', encoding='utf-8-sig')
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id', '问题', '分类', '对应实体', 'csv文件名'])

for cyc in range(len(new_question_file)):
    tempw_id = new_question_file['问题id'][cyc]
    tempw_q = new_question_file['问题'][cyc]
    tempw_q_class = new_question_file['分类'][cyc]
    tempw_entity = 'N_A'
    tempw_csv_name = 'N_A'

    if tempw_q_class == 'Text':
        # Fuzzy match: score each company by token-set overlap with the question,
        # |Q ∩ C| / (|Q| + |C|), and keep the best-scoring company.
        temp_index_q = tokenizer(tempw_q)['input_ids']
        q_cp_similarity_list = list()
        for cyc2 in range(len(company_file)):
            temp_index_cp = company_index_list[cyc2]
            temp_simi = len(set(temp_index_cp) & set(temp_index_q)) / (len(set(temp_index_cp)) + len(set(temp_index_q)))
            q_cp_similarity_list.append(temp_simi)

        best_index = q_cp_similarity_list.index(max(q_cp_similarity_list))
        tempw_entity = company_name_list[best_index]
        tempw_csv_name = company_data_csv_list[best_index]
        csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])
    elif tempw_q_class == 'SQL':
        # Structured questions carry no company entity; write the row unchanged.
        csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])
    else:
        # Fall back to exact substring matching against the company list.
        for cyc_name in range(len(company_name_list)):
            if company_name_list[cyc_name] in tempw_q:
                tempw_entity = company_name_list[cyc_name]
                tempw_csv_name = company_data_csv_list[cyc_name]
                break
        csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])

g.close()
print('A02_finished')
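For intuition, the company-matching score above is |Q ∩ C| / (|Q| + |C|) over token-id sets (half the Dice coefficient). A tiny self-contained check with made-up token ids:

# Toy check of the overlap score used above; the token ids are made up.
def overlap_score(a, b):
    return len(set(a) & set(b)) / (len(set(a)) + len(set(b)))

question_ids = [11, 23, 42, 57, 89]    # a question that mentions company A
company_a_ids = [23, 42, 57]           # company A's name tokens all appear in the question
company_b_ids = [23, 99, 100, 101]     # company B shares only one token

print(overlap_score(company_a_ids, question_ids))  # 3 / (3 + 5) = 0.375
print(overlap_score(company_b_ids, question_ids))  # 1 / (4 + 5) ≈ 0.111 -> company A wins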

SQL generation:

import csv
import copy
import re

import pandas as pd
from langchain_community.utilities import SQLDatabase  # used by the schema-exploration block below
from modelscope import AutoModelForCausalLM, AutoTokenizer
from modelscope import GenerationConfig

# Tables available in the competition's SQLite database.
table_name_list = ['基金基本信息','基金股票持仓明细','基金债券持仓明细','基金可转债持仓明细','基金日行情表','A股票日行情表','港股票日行情表','A股公司行业划分表','基金规模变动表','基金份额持有人结构']
table_info_dict = {}
n = 5  # number of in-context examples to retrieve per question
# Tokens ignored when matching questions to examples: digits, punctuation, filler words.
deny_list = ['0','1','2','3','4','5','6','7','8','9',',','?','。',
             '一','二','三','四','五','六','七','八','九','零','十',
             '的','小','请','.','?','有多少','帮我','我想','知道',
             '是多少','保留','是什么','-','(',')','(',')',':',
             '哪个','统计','且','和','来','请问','记得','有','它们']


# (Kept for reference) schema exploration via LangChain's SQLDatabase:
# url='sqlite:model_train/other/FinQwen-main/solutions/4_大模型说的队/app/tcdata/bobi.db'
# url="sqlite:data/nfs/baozhi/my_model_train/other/FinQwen-main/bs_challenge_financial_14b_dataset/dataset/bobi.db"
# db0 = SQLDatabase.from_uri(url, sample_rows_in_table_info=0)
# dbd0 = db0.table_info
#
# db2 = SQLDatabase.from_uri(url, sample_rows_in_table_info=2)
# dbd2 = db2.table_info
# list1 = dbd2.split('CREATE TABLE')
# for cyc_piece in range(len(list1)):
#     list1[cyc_piece] = 'CREATE TABLE' + list1[cyc_piece]
# for piece in list1:
#     for word in table_name_list:
#         if word in piece:
#             table_info_dict[word] = piece
question_csv_file_dir = "intermediate/A01_question_classify.csv"
question_csv_file = pd.read_csv(question_csv_file_dir,delimiter = ",",header = 0)
model_dir = '/data/nfs/baozhi/models/Qwen-7B-Chat'
# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cuda:0", trust_remote_code=True, bf16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(model_dir,
                                                           trust_remote_code=True,
                                                           temperature = 0.0001,
                                                           top_p = 1,
                                                           do_sample = False,
                                                           seed = 1234)

print('B01_model_loaded')

# Pre-tokenize the deny list so example/question tokens can be filtered by id.
deny_token_list = list()
for word in deny_list:
    deny_token_list = deny_token_list + tokenizer(word)['input_ids']

def get_prompt_v33(question, index_list):
    """Build a few-shot SQL-generation prompt from the selected in-context examples."""
    Examples = '以下是一些例子:'
    for index in index_list:
        Examples = Examples + "问题:" + example_question_list[index] + '\n'
        Examples = Examples + "SQL:" + example_sql_list[index] + '\n'

    impt2 = """
        你是一个精通SQL语句的程序员。
        我会给你一个问题,请按照问题描述,仿照以下例子写出正确的SQL代码。
    """

    impt2 = impt2 + Examples
    impt2 = impt2 + "问题:" + question + '\n'
    impt2 = impt2 + "SQL:"
    return impt2
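# For illustration (hypothetical content), with short_index_list = [3, 7] the
# returned prompt is a single string of the form:
#
#   你是一个精通SQL语句的程序员。
#   我会给你一个问题,请按照问题描述,仿照以下例子写出正确的SQL代码。
#   以下是一些例子:问题:<example 3 question>
#   SQL:<example 3 SQL>
#   问题:<example 7 question>
#   SQL:<example 7 SQL>
#   问题:<current question>
#   SQL:
#
# The model continues the text after the trailing "SQL:", producing the query.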


# In-context examples: (question, SQL) pairs used for few-shot prompting.
SQL_examples_file = pd.read_csv("files/ICL_EXP.csv", delimiter=",", header=0)

example_question_list = list()
example_sql_list = list()
example_token_list = list()

for cyc in range(len(SQL_examples_file)):
    example_question_list.append(SQL_examples_file['问题'][cyc])
    example_sql_list.append(SQL_examples_file['SQL'][cyc])
    # Tokenize each example question and drop deny-list tokens before matching.
    temp_tokens = tokenizer(SQL_examples_file['问题'][cyc])['input_ids']
    example_token_list.append([x for x in temp_tokens if x not in deny_token_list])

g = open('intermediate/question_SQL_V6.csv', 'w', newline='', encoding='utf-8-sig')
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id', '问题', 'SQL语句', 'prompt'])

pattern1 = r'\d{8}'  # 8-digit dates such as 20211231

for cyc in range(len(question_csv_file)):
    if cyc % 50 == 0:
        print(cyc)
    response2 = 'N_A'
    prompt2 = 'N_A'

    if question_csv_file['分类'][cyc] == 'SQL' and cyc not in [174]:  # question 174 is special-cased and skipped
        temp_question = question_csv_file['问题'][cyc]
        # Blank out 8-digit dates so they don't dominate the similarity matching.
        date_list = re.findall(pattern1, temp_question)
        temp_question2_for_search = temp_question
        for t_date in date_list:
            temp_question2_for_search = temp_question2_for_search.replace(t_date, ' ')
        temp_tokens = tokenizer(temp_question2_for_search)['input_ids']
        temp_tokens = [x for x in temp_tokens if x not in deny_token_list]

        # Compute token-set overlap similarity against every in-context example question.
        similarity_list = list()
        for cyc2 in range(len(SQL_examples_file)):
            similarity_list.append(len(set(temp_tokens) & set(example_token_list[cyc2])) / (len(set(temp_tokens)) + len(set(example_token_list[cyc2]))))

        # Take the indices of the n most similar examples.
        t = copy.deepcopy(similarity_list)
        max_number = []
        max_index = []
        for _ in range(n):
            number = max(t)
            index = t.index(number)
            t[index] = 0
            max_number.append(number)
            max_index.append(index)

        # Greedily keep examples until the prompt would exceed ~2300 characters.
        temp_length_test = ""
        short_index_list = list()
        for index in max_index:
            temp_length_test = temp_length_test + example_question_list[index] + example_sql_list[index]
            if len(temp_length_test) > 2300:
                break
            short_index_list.append(index)

        prompt2 = get_prompt_v33(question_csv_file['问题'][cyc], short_index_list)
        print(f"{str(cyc)} prompt2:{prompt2}")
        response2, history = model.chat(tokenizer, prompt2, history=None)
        print(f"response2 = {response2}, \n history = {history}")
        print("---------------------------------------------------------------------------------")

    csvwriter.writerow([str(question_csv_file['问题id'][cyc]),
                        str(question_csv_file['问题'][cyc]),
                        response2, prompt2])

g.close()  # flush and close the output CSV
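The script above only generates the SQL text; per step 2 of the pipeline, the SQL must still be executed and the result values handed back to Qwen for the final answer. That stage is not shown in the post, so here is a minimal sketch of it, with assumptions labeled: the database path and output file name are guesses (cf. the commented-out URLs above), and model/tokenizer are reused from the script above.

import csv
import sqlite3

import pandas as pd

# Hypothetical execution stage: run each generated SQL statement against the
# competition database and ask Qwen to phrase the final answer.
sql_file = pd.read_csv('intermediate/question_SQL_V6.csv', delimiter=',', header=0)
conn = sqlite3.connect('tcdata/bobi.db')  # assumed path, cf. the commented-out URLs above

g = open('intermediate/question_SQL_result.csv', 'w', newline='', encoding='utf-8-sig')
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id', '问题', '执行结果', '最终答案'])

for cyc in range(len(sql_file)):
    question = sql_file['问题'][cyc]
    sql_text = sql_file['SQL语句'][cyc]
    result = 'N_A'
    answer = 'N_A'
    if sql_text != 'N_A':
        try:
            result = str(conn.execute(sql_text).fetchall())
            # Hand the raw values back to Qwen so it can phrase a complete answer.
            prompt = f"问题:{question}\nSQL查询结果:{result}\n请根据查询结果,用完整的句子回答问题。"
            answer, _ = model.chat(tokenizer, prompt, history=None)
        except Exception as e:
            print(f'{cyc}: SQL execution failed: {e}')
    csvwriter.writerow([str(sql_file['问题id'][cyc]), str(question), result, answer])

g.close()
conn.close()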







Semantic retrieval:

import csv
import json
import math
import re
from collections import Counter

import pandas as pd
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from modelscope import AutoTokenizer

from ai_loader import tongyi


def counter_cosine_similarity(c1, c2):  # cosine similarity over token-count Counters
    terms = set(c1).union(c2)
    dotprod = sum(c1.get(k, 0) * c2.get(k, 0) for k in terms)
    magA = math.sqrt(sum(c1.get(k, 0) ** 2 for k in terms))
    magB = math.sqrt(sum(c2.get(k, 0) ** 2 for k in terms))

    if magA * magB != 0:
        return dotprod / (magA * magB)
    else:
        return 0


# '截至' phrases and explicit dates are stripped from questions before matching.
pattern1 = r'截至'
pattern2 = r'\d{1,4}年\d{1,2}月\d{1,2}日'

q_file_dir = 'intermediate/A02_question_classify_entity.csv'
q_file = pd.read_csv(q_file_dir, delimiter=",", header=0)

c00_file = 'intermediate/C00_text_understanding.csv'
g = open(c00_file, 'w', newline='', encoding='utf-8-sig')
text_file_dir = 'tcdata/pdf_txt_file'
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id', '问题', '问题[标准化后]', '对应实体', 'csv文件名', 'FA', 'top_text'])

# Filler words stripped from questions, and punctuation replaced in chunks, before matching.
stopword_list = ['根据', '招股意见书', '招股意向书', '截至', '千元', '万元', '哪里', '哪个',
                 '知道', "什么", '?', '是',
                 '的', '想', '元', '。', ',', '怎样', '谁', '以及', '了', '对', '?', ',']
bd_list = ['?', '。', ',', '[', ']']

tongyi_model_path = "/data/nfs/baozhi/models/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(tongyi_model_path, trust_remote_code=True)


def text_split(content):
    """ 将文本分割为较小的部分 """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=['\n\n', "\n", "。"],
        keep_separator=False)
    return text_splitter.split_text(content)


def embedding_2_ver(embedding):  # unused helper kept from the baseline
    temp_tokens = list()
    for word_add in embedding:
        temp_tokens.append(word_add)
    return temp_tokens


n = 30   # kept from the baseline; not used in this script
cap = 4  # clip per-token counts so repeated tokens don't dominate the cosine


def text_similarity(text, C_temp_q_tokens):
    """Compute capped-count cosine similarity between a text chunk and the question."""
    temp_s_tokens = tokenizer(text)['input_ids']

    C_temp_s_tokens = Counter(temp_s_tokens)
    # Zero out token id 220 (whitespace-like token) so it doesn't inflate similarity;
    # note the key must be an int to match the tokenizer's int token ids.
    C_temp_s_tokens[220] = 0

    # Clip counts at `cap` so highly repeated tokens don't dominate.
    for token in C_temp_s_tokens:
        if C_temp_s_tokens[token] >= cap:
            C_temp_s_tokens[token] = cap
    return counter_cosine_similarity(C_temp_s_tokens, C_temp_q_tokens)

def process_text_question(question, company, file_path):
    """Retrieve the top-5 chunks from the prospectus and ask Qwen to answer from them."""
    try:
        # Tokenize the normalized question (split on the spaces left by stopword removal).
        temp_q_list = question.split()
        temp_q_tokens = list()
        for word in temp_q_list:
            temp_q_tokens_add = tokenizer(word)['input_ids']
            for word_add in temp_q_tokens_add:
                temp_q_tokens.append(word_add)
        C_temp_q_tokens = Counter(temp_q_tokens)

        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        content = content.replace(' ', '')

        # Score every chunk against the question and keep the five most similar.
        text_list = text_split(content)
        sim_list = list()
        for text in text_list:
            text1 = text
            for bd in bd_list:
                text1 = text1.replace(bd, ' ')
            sim_list.append(text_similarity(text1, C_temp_q_tokens))

        sorted_indices = sorted(enumerate(sim_list), key=lambda x: x[1], reverse=True)
        top_texts = [text_list[index] for index, _ in sorted_indices[:5]]

        # prompt = ChatPromptTemplate.from_template(
        #     "你是一个能精准提取文本信息并回答问题的AI。\n"
        #     "请根据以下资料的所有内容,首先帮我判断能否依据给定材料回答出问题。"
        #     "如果能根据给定材料回答,则提取出最合理的答案来回答问题,并回答出完整内容,不要输出表格:\n\n"
        #     "{text}\n\n"
        #     "请根据以上材料回答:{q}\n\n"
        #     "请按以下格式输出:\n"
        #     "能否根据给定材料回答问题:回答能或否\n"
        #     "答案:").format_messages(q=question, text="\n".join(top_texts))
        prompt = ChatPromptTemplate.from_template(
            "你是一个能精准提取文本信息并回答问题的AI。\n"
            "下面是一段资料,不要计算,不要计算,直接从资料中寻找问题的答案,使用完整的句子回答问题。\n "
            "如果资料不包含问题的答案,回答“不知道。”如果从资料无法得出问题的答案,回答“不知道。”如果答案未在资料中说明,回答“不知道。”如果资料与问题无关或者在资料中找不到问题的答案,回答“不知道。”如果资料没有明确说明问题答案,回答“不知道。”资料:\n\n"
            "{text}\n\n"
            "请根据以上材料回答:{q}\n\n"
            "答案:").format_messages(q=question, text="\n".join(top_texts))

        response = tongyi(prompt[0].content, temperature=0.01, top_p=0.5)
        return (response, top_texts)
    except Exception as e:
        print(f"Error processing question: {e}")
        return 'N_A', []  # safe default so the caller's unpacking doesn't crash


print('C00_Started')
for cyc in range(len(q_file)):  # the competition set has 1000 questions; iterate over all of them
    temp_q = q_file['问题'][cyc]
    temp_class = q_file['分类'][cyc]
    temp_e = q_file['对应实体'][cyc]
    print(cyc)

    if temp_e == 'N_A':
        # No company entity was found, so there is no prospectus to retrieve from.
        csvwriter.writerow([q_file['问题id'][cyc],
                            q_file['问题'][cyc],
                            'N_A', 'N_A', 'N_A', 'N_A', 'N_A'])
        continue

    temp_e = temp_e.replace('\n', '')
    print(f'问题:{temp_q}')
    print(f'分类:{temp_class}')
    print(f'对应实体:{temp_e}')

    # Map the company's csv file name to the corresponding prospectus txt file.
    temp_text_name = q_file['csv文件名'][cyc]
    print(f'csv文件名:{temp_text_name}')
    temp_text_name = temp_text_name.replace('PDF.csv', '') + 'txt'
    temp_csv_dir = text_file_dir + '/' + temp_text_name
    print(f'csv文件名[转换后]:{temp_csv_dir}')

    # Normalize the question: drop spaces, the company name, '截至' phrases,
    # explicit dates, and stopwords, so chunk matching focuses on the key terms.
    temp_q = temp_q.replace(' ', '')
    temp_q = temp_q.replace(temp_e, ' ')
    for word in re.findall(pattern1, temp_q):
        temp_q = temp_q.replace(word, '')
    for word in re.findall(pattern2, temp_q):
        temp_q = temp_q.replace(word, '')
    for word in stopword_list:
        temp_q = temp_q.replace(word, ' ')
    print(f'问题[标准化后]:{temp_q}')

    FA, top_text = process_text_question(temp_q, temp_e, temp_csv_dir)
    print('答案如下:')
    print(FA)
    print("-----------------------------------------------------------")
    csvwriter.writerow([q_file['问题id'][cyc],
                        q_file['问题'][cyc],
                        temp_q, temp_e, temp_text_name, FA, json.dumps(top_text, ensure_ascii=False)])
g.close()
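For intuition about the capped counter-cosine scoring above, here is a tiny self-contained check (plain Chinese words stand in for Qwen token ids; the counts and scores are illustrative only):

import math
from collections import Counter

# Same scoring function as in the script above.
def counter_cosine_similarity(c1, c2):
    terms = set(c1).union(c2)
    dotprod = sum(c1.get(k, 0) * c2.get(k, 0) for k in terms)
    magA = math.sqrt(sum(c1.get(k, 0) ** 2 for k in terms))
    magB = math.sqrt(sum(c2.get(k, 0) ** 2 for k in terms))
    return dotprod / (magA * magB) if magA * magB else 0

cap = 4
question = Counter({'发行': 1, '募集': 1, '资金': 1})
# A chunk that repeats one query word many times but matches nothing else:
spammy = Counter({'发行': 9, '股份': 5, '债券': 5})
# A chunk that matches two query words once each:
relevant = Counter({'发行': 1, '募集': 1, '公司': 1})

capped = Counter({k: min(v, cap) for k, v in spammy.items()})
print(counter_cosine_similarity(spammy, question))    # ≈ 0.454 uncapped
print(counter_cosine_similarity(capped, question))    # ≈ 0.333 after capping
print(counter_cosine_similarity(relevant, question))  # ≈ 0.667: the chunk covering more query terms wins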