【第四十九周】HippoRAG 2 推理阶段代码分析

发布于:2025-06-16 ⋅ 阅读:(16) ⋅ 点赞:(0)

摘要

本周工作重点围绕HippoRAG 2框架的推理阶段展开深入代码分析,主要关注QA模块的核心实现机制。通过分析提示模板构建、语言模型推理以及答案解析三个关键子模块的代码实现,完整梳理了从文档检索到答案生成的端到端工作流程,深入理解了以下几个核心问题:第一,提示模板的构建策略(包括文档组织方式和问题表述形式)直接影响语言模型的推理效果;第二,检索文档的质量与相关性是保证最终答案准确性的前提条件;最后,答案解析环节的健壮性处理对系统稳定性至关重要。这些发现表明了模型性能与提示工程、文档检索质量之间存在的深度耦合关系,为后续系统优化指明了方向。本周工作不仅建立了对RAG系统推理阶段的完整认知框架,更为后续针对性优化问答准确率、提升系统性能提供了关键的技术依据和优化切入点,也为后续的模型微调、提示工程优化以及检索-推理协同优化等工作奠定了坚实的理论基础和实践指导。

Abstract

This week, I focused on in-depth code analysis of the inference phase in the HippoRAG 2 framework, particularly examining the core implementation of the QA module. By analyzing three key components - prompt template construction, language model inference, and answer parsing - I mapped out the complete workflow from document retrieval to answer generation and gained important insights: Firstly, Prompt template design (including document organization and question phrasing) directly impacts the model’s reasoning performance. Secondly, the quality and relevance of retrieved documents are prerequisites for answer accuracy. At last, robust answer parsing is critical for system stability. These findings reveal the strong interdependence between model performance, prompt engineering, and retrieval quality, pointing the way for future optimizations. This analysis has not only built my comprehensive understanding of RAG inference but also provided key technical foundations for improving answer accuracy and system performance. It establishes a solid basis for upcoming work on model fine-tuning, prompt engineering enhancements, and retrieval-inference coordination.

在这里插入图片描述

在主函数中我们可以看到整个项目的关键语句就三句,上周分析完了索引阶段的代码,这周就看一下检索和QA环节的代码,整个检索QA都是有hipporag这个类中的rag_qa方法实现的。


HippoRAG类中的属性和参数:

  • 属性(Attributes):

global_config (BaseConfig):实例的全局配置设置。如果没有提供值,则使用一个 BaseConfig 的实例。
saving_dir (str):用于存储特定 HippoRAG 实例的目录。如果没有提供值,默认为 outputs。
llm_model (BaseLLM):根据全局配置设置使用的语言模型,用于处理任务。
openie (Union[OpenIE, VLLMOfflineOpenIE]):开放信息抽取模块,根据全局配置设置以在线或离线模式进行配置。
graph:由 initialize_graph 方法初始化的图结构实例。
embedding_model (BaseEmbeddingModel):与当前配置相关联的嵌入模型。
chunk_embedding_store (EmbeddingStore):用于管理段落(chunk)嵌入的嵌入存储器。
entity_embedding_store (EmbeddingStore):用于管理实体(entity)嵌入的嵌入存储器。
fact_embedding_store (EmbeddingStore):用于管理事实(fact)嵌入的嵌入存储器。
prompt_template_manager (PromptTemplateManager):用于管理提示模板和角色映射的管理器。
openie_results_path (str):基于全局配置中的数据集名称和 LLM 名称,存储开放信息抽取结果的文件路径。
rerank_filter (Optional[DSPyFilter]):当全局配置中指定了重排序文件路径时,负责执行信息重排序的过滤器。
ready_to_retrieve (bool):标志位,表示系统是否已准备好执行检索操作。

  • 参数(Parameters):

global_config:全局配置对象。默认为 None,此时将初始化一个新的 BaseConfig 对象。
working_dir:用于存储工作文件的目录。默认为 None,此时将基于类名和时间戳构造一个默认目录。
llm_model_name:大语言模型名称,可以通过参数直接传入,也可以通过配置文件指定。
embedding_model_name:嵌入模型名称,可以通过参数直接传入,也可以通过配置文件指定。
llm_base_url:部署好的 LLM 模型的 URL 地址,可以通过参数直接传入,也可以通过配置文件指定。


rag_qa()

def rag_qa(self,
           queries: List[str|QuerySolution],
           gold_docs: List[List[str]] = None,
           gold_answers: List[List[str]] = None) -> Tuple[List[QuerySolution], List[str], List[Dict]] | Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
    """
    使用HippoRAG 2框架执行检索增强生成(RAG)的问答任务
    
    该方法可以处理基于字符串的查询和预处理的QuerySolution对象。根据输入不同,
    可以仅返回答案,或者额外评估检索和答案质量(使用召回率@k、精确匹配和F1分数指标)
    
    参数:
        queries (List[Union[str, QuerySolution]]): 查询列表,可以是字符串或QuerySolution实例。
            如果是字符串,将执行检索过程。
        gold_docs (Optional[List[List[str]]]): 每个查询对应的标准答案文档列表的列表。
            如果要执行文档级评估,则需要提供此参数。默认为None。
        gold_answers (Optional[List[List[str]]]): 每个查询对应的标准答案列表的列表。
            如果要评估问答(QA)答案质量,则需要提供此参数。默认为None。
    
    返回:
        Union[
            Tuple[List[QuerySolution], List[str], List[Dict]],
            Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]
        ]: 返回一个元组,总是包含以下内容:
            - QuerySolution对象列表,包含每个查询的答案和元数据
            - 响应消息列表
            - 每个查询的元数据字典列表
            如果启用了评估,元组还会包含:
            - 检索阶段的整体结果字典(如果适用)
            - QA评估指标的整体结果(精确匹配和F1分数)
    """
    # 如果有提供标准答案,初始化QA评估器
    if gold_answers is not None:
        qa_em_evaluator = QAExactMatch(global_config=self.global_config)  # 精确匹配评估器
        qa_f1_evaluator = QAF1Score(global_config=self.global_config)   # F1分数评估器

    # 检索阶段(如果需要)
    overall_retrieval_result = None  # 初始化检索结果

    # 如果查询是字符串而非QuerySolution对象,则需要进行检索
    if not isinstance(queries[0], QuerySolution):
        if gold_docs is not None:
            # 如果有提供标准文档,执行带评估的检索
            queries, overall_retrieval_result = self.retrieve(queries=queries, gold_docs=gold_docs)
        else:
            # 否则执行普通检索
            queries = self.retrieve(queries=queries)

    # 执行问答阶段
    queries_solutions, all_response_message, all_metadata = self.qa(queries)

    # 评估问答结果
    if gold_answers is not None:
        # 计算精确匹配指标
        overall_qa_em_result, example_qa_em_results = qa_em_evaluator.calculate_metric_scores(
            gold_answers=gold_answers, 
            predicted_answers=[qa_result.answer for qa_result in queries_solutions],
            aggregation_fn=np.max)  # 使用最大值作为聚合函数
        
        # 计算F1分数指标
        overall_qa_f1_result, example_qa_f1_results = qa_f1_evaluator.calculate_metric_scores(
            gold_answers=gold_answers, 
            predicted_answers=[qa_result.answer for qa_result in queries_solutions],
            aggregation_fn=np.max)

        # 合并QA结果并四舍五入到4位小数
        overall_qa_em_result.update(overall_qa_f1_result)
        overall_qa_results = overall_qa_em_result
        overall_qa_results = {k: round(float(v), 4) for k, v in overall_qa_results.items()}
        logger.info(f"Evaluation results for QA: {overall_qa_results}")

        # 保存检索和QA结果到QuerySolution对象
        for idx, q in enumerate(queries_solutions):
            q.gold_answers = list(gold_answers[idx])  # 保存标准答案
            if gold_docs is not None:
                q.gold_docs = gold_docs[idx]  # 保存标准文档

        # 返回完整结果(包含评估指标)
        return queries_solutions, all_response_message, all_metadata, overall_retrieval_result, overall_qa_results
    else:
        # 如果不需评估,返回基本结果
        return queries_solutions, all_response_message, all_metadata

rag_qa这个方法里又有很多小的方法互相嵌套,我们逐一来看下关键方法的作用。

retrieve()

def retrieve(self,
             queries: List[str],
             num_to_retrieve: int = None,
             gold_docs: List[List[str]] = None) -> List[QuerySolution] | Tuple[List[QuerySolution], Dict]:
    """
    使用HippoRAG 2框架执行文档检索,包含以下步骤:
    - 事实检索
    - 识别记忆优化事实选择
    - 密集段落评分
    - 基于个性化PageRank的重新排序

    参数:
        queries: List[str]
            需要检索文档的查询字符串列表
        num_to_retrieve: int, 可选
            每个查询要检索的最大文档数量。如果未指定,默认使用全局配置中的`retrieval_top_k`值
        gold_docs: List[List[str]], 可选
            每个查询对应的标准答案文档列表。如果启用了检索性能评估(全局配置中的`do_eval_retrieval`),
            则需要提供此参数

    返回:
        List[QuerySolution] 或 (List[QuerySolution], Dict)
            如果未启用检索评估,返回QuerySolution对象列表,每个对象包含对应查询检索到的文档及其分数。
            如果启用了评估,还会返回一个包含检索结果评估指标的字典

    注意
    -----
    - 对于重新排序后没有相关事实的长查询,将默认返回密集段落检索的结果
    """
    retrieve_start_time = time.time()  # 记录检索开始时间

    # 设置要检索的文档数量
    if num_to_retrieve is None:
        num_to_retrieve = self.global_config.retrieval_top_k  # 使用全局配置的默认值

    # 如果有提供标准文档,初始化检索召回评估器
    if gold_docs is not None:
        retrieval_recall_evaluator = RetrievalRecall(global_config=self.global_config)

    # 确保检索对象已准备就绪
    if not self.ready_to_retrieve:
        self.prepare_retrieval_objects()

    # 获取查询的嵌入表示
    self.get_query_embeddings(queries)

    retrieval_results = []  # 初始化检索结果列表

    # 遍历所有查询进行检索
    for q_idx, query in tqdm(enumerate(queries), desc="Retrieving", total=len(queries)):
        # 重新排序阶段开始
        rerank_start = time.time()
        # 获取查询与事实的匹配分数
        query_fact_scores = self.get_fact_scores(query)
        # 对事实进行重新排序,获取top-k事实
        top_k_fact_indices, top_k_facts, rerank_log = self.rerank_facts(query, query_fact_scores)
        rerank_end = time.time()

        # 累计重新排序时间
        self.rerank_time += rerank_end - rerank_start

        # 如果重新排序后没有找到事实,回退到密集段落检索
        if len(top_k_facts) == 0:
            logger.info('No facts found after reranking, return DPR results')
            sorted_doc_ids, sorted_doc_scores = self.dense_passage_retrieval(query)
        else:
            # 使用图搜索基于事实实体检索文档
            sorted_doc_ids, sorted_doc_scores = self.graph_search_with_fact_entities(
                query=query,
                link_top_k=self.global_config.linking_top_k,
                query_fact_scores=query_fact_scores,
                top_k_facts=top_k_facts,
                top_k_fact_indices=top_k_fact_indices,
                passage_node_weight=self.global_config.passage_node_weight)

        # 获取top-k文档内容
        top_k_docs = [self.chunk_embedding_store.get_row(self.passage_node_keys[idx])["content"] 
                      for idx in sorted_doc_ids[:num_to_retrieve]]

        # 将检索结果保存到QuerySolution对象
        retrieval_results.append(QuerySolution(
            question=query, 
            docs=top_k_docs, 
            doc_scores=sorted_doc_scores[:num_to_retrieve]))

    retrieve_end_time = time.time()  # 记录检索结束时间

    # 累计总检索时间
    self.all_retrieval_time += retrieve_end_time - retrieve_start_time

    # 记录各阶段耗时
    logger.info(f"Total Retrieval Time {self.all_retrieval_time:.2f}s")
    logger.info(f"Total Recognition Memory Time {self.rerank_time:.2f}s")
    logger.info(f"Total PPR Time {self.ppr_time:.2f}s")
    logger.info(f"Total Misc Time {self.all_retrieval_time - (self.rerank_time + self.ppr_time):.2f}s")

    # 评估检索结果
    if gold_docs is not None:
        # 设置评估的k值列表
        k_list = [1, 2, 5, 10, 20, 30, 50, 100, 150, 200]
        # 计算检索召回率指标
        overall_retrieval_result, example_retrieval_results = retrieval_recall_evaluator.calculate_metric_scores(
            gold_docs=gold_docs, 
            retrieved_docs=[retrieval_result.docs for retrieval_result in retrieval_results], 
            k_list=k_list)
        logger.info(f"Evaluation results for retrieval: {overall_retrieval_result}")

        # 返回检索结果和评估指标
        return retrieval_results, overall_retrieval_result
    else:
        # 仅返回检索结果
        return retrieval_results

qa()

def qa(self, queries: List[QuerySolution]) -> Tuple[List[QuerySolution], List[str], List[Dict]]:
    """
    使用语言模型对一组查询解决方案执行问答(QA)推理
    
    参数:
        queries: List[QuerySolution]
            包含用户查询、检索到的文档和其他相关信息的QuerySolution对象列表
    
    返回:
        Tuple[List[QuerySolution], List[str], List[Dict]]
            包含以下内容的元组:
            - 更新后的QuerySolution对象列表(包含预测答案)
            - 来自语言模型的原始响应消息列表
            - 与结果关联的元数据字典列表
    """
    # 初始化存储所有QA消息的列表
    all_qa_messages = []

    # 遍历所有查询解决方案,收集QA提示
    for query_solution in tqdm(queries, desc="Collecting QA prompts"):
        # 获取检索到的文档(根据配置取前qa_top_k个)
        retrieved_passages = query_solution.docs[:self.global_config.qa_top_k]

        # 构建用户提示(prompt)
        prompt_user = ''
        for passage in retrieved_passages:
            prompt_user += f'Wikipedia Title: {passage}\n\n'  # 添加文档标题
        prompt_user += 'Question: ' + query_solution.question + '\nThought: '  # 添加问题

        # 检查是否存在针对当前数据集的定制提示模板
        if self.prompt_template_manager.is_template_name_valid(name=f'rag_qa_{self.global_config.dataset}'):
            # 使用该数据集的定制模板
            prompt_dataset_name = self.global_config.dataset
        else:
            # 没有定制模板则使用MUSIQUE的默认模板
            logger.debug(
                f"rag_qa_{self.global_config.dataset} does not have a customized prompt template. Using MUSIQUE's prompt template instead.")
            prompt_dataset_name = 'musique'
        
        # 渲染提示模板并添加到列表中
        all_qa_messages.append(
            self.prompt_template_manager.render(name=f'rag_qa_{prompt_dataset_name}', prompt_user=prompt_user))

    # 使用语言模型对所有QA消息进行推理
    all_qa_results = [self.llm_model.infer(qa_messages) for qa_messages in tqdm(all_qa_messages, desc="QA Reading")]

    # 解压推理结果(响应消息、元数据、缓存命中)
    all_response_message, all_metadata, all_cache_hit = zip(*all_qa_results)
    all_response_message, all_metadata = list(all_response_message), list(all_metadata)

    # 处理响应并提取预测答案
    queries_solutions = []
    for query_solution_idx, query_solution in tqdm(enumerate(queries), desc="Extraction Answers from LLM Response"):
        response_content = all_response_message[query_solution_idx]
        try:
            # 尝试从响应内容中提取答案(查找"Answer:"后的内容)
            pred_ans = response_content.split('Answer:')[1].strip()
        except Exception as e:
            # 如果解析失败,记录警告并使用原始响应内容
            logger.warning(f"Error in parsing the answer from the raw LLM QA inference response: {str(e)}!")
            pred_ans = response_content

        # 将预测答案存入查询解决方案对象
        query_solution.answer = pred_ans
        queries_solutions.append(query_solution)

    # 返回处理后的结果
    return queries_solutions, all_response_message, all_metadata

总结

本周通过对HippoRAG 2推理阶段的代码剖析,我系统掌握了QA模块从文档检索到答案生成的全流程实现机制。研究发现提示模板设计、检索文档质量和答案解析处理是影响系统性能的三大关键因素,这些要素之间存在着紧密的协同关系。这项工作不仅深化了我对RAG系统推理过程的理解,更重要的是为后续的模型优化、提示工程改进和系统调优提供了明确的技术路线和优化方向,为构建高性能的问答系统奠定了坚实基础。


网站公告

今日签到

点亮在社区的每一天
去签到