高级RAG策略学习(一)——自适应检索系统

发布于:2025-09-03 ⋅ 阅读:(17) ⋅ 点赞:(0)

LangChain 自适应检索系统知识点总结

1. 环境配置与依赖导入

1.1 核心依赖

from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain.docstore.document import Document
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from dotenv import load_dotenv

1.2 环境变量配置

# Load environment variables from a .env file.
load_dotenv()

# Set the variable conditionally to avoid a NoneType error when the key is absent.
if os.getenv('OPENAI_API_KEY'):
    os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# DashScope's OpenAI-compatible API endpoint.
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# An explicit OPENAI_API_BASE overrides the DashScope default.
BASE_URL = os.getenv("OPENAI_API_BASE") or DASHSCOPE_BASE_URL

知识点:

  • 使用 dotenv 管理环境变量
  • 条件性设置环境变量避免运行时错误
  • 支持多种 API 端点的灵活配置

2. Pydantic 模型定义与结构化输出

2.1 查询分类模型

class categories_options(BaseModel):
    """Query classification model for structured output"""
    # NOTE(review): plain str — the four allowed labels are only suggested via
    # the description, not enforced by the schema.
    category: str = Field(
        description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual",
        example="Factual"
    )

2.2 相关性评分模型

class relevant_score(BaseModel):
    """Relevance score for structured LLM output (1-10 per the ranking prompts)."""
    score: float = Field(
        description="The relevance score of the document to the query", 
        example=8.0
    )

2.3 文档选择模型

class SelectedIndices(BaseModel):
    """Model for document selection indices"""
    # Zero-based positions into the candidate document list.
    indices: List[int] = Field(
        description="Indices of selected documents", 
        example=[0, 1, 2, 3]
    )

2.4 子查询生成模型

class SubQueries(BaseModel):
    """Model for generating analytical sub-queries"""
    # Each sub-query targets one facet of the original analytical question.
    sub_queries: List[str] = Field(
        description="List of sub-queries for comprehensive analysis",
        example=["What is the population of New York?", "What is the GDP of New York?"]
    )

知识点:

  • 使用 Pydantic 模型确保 LLM 输出的结构化和类型安全
  • Field 函数提供详细的字段描述和示例
  • 支持复杂数据类型如 List[str] 和 List[int]

3. 大模型配置与初始化

3.1 嵌入模型配置

# DashScope embedding model used to vectorize text chunks.
embeddings = DashScopeEmbeddings(
    model="text-embedding-v1",
    dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
)

3.2 聊天模型配置

# NOTE(review): excerpt from a class __init__ — ``self`` is the enclosing
# instance; BASE_URL comes from the environment-setup section above.
self.llm = ChatOpenAI(
    temperature=0, 
    model="qwen-plus", 
    max_tokens=4000,
    api_key=os.getenv("DASHSCOPE_API_KEY"), 
    base_url=BASE_URL
)

知识点:

  • DashScopeEmbeddings 用于文本向量化
  • ChatOpenAI 支持兼容 OpenAI API 的模型
  • 通过 base_url 参数支持自定义 API 端点
  • temperature=0 确保输出的确定性

4. 向量数据库构建与管理

4.1 文本分割与向量化

class BaseRetrievalStrategy:
    """Builds embeddings, a FAISS index, and a retriever over the input text."""

    def __init__(self, text, max_tokens=4000):
        # max_tokens is unused here — presumably reserved for LLM
        # configuration; TODO confirm against the full code version.
        self.embeddings = DashScopeEmbeddings(
            model="text-embedding-v1",
            dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
        )
        
        # Split the joined input into fixed-size 1000-char chunks, no overlap.
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        self.texts = text_splitter.split_text(" ".join(text))
        
        # Build the vector store from the chunks.
        self.vectorstore = FAISS.from_texts(self.texts, self.embeddings)
        
        # Retriever returning the 10 nearest chunks per search.
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 10})

知识点:

  • CharacterTextSplitter 用于文本分块,控制 chunk_size 和 chunk_overlap
  • FAISS.from_texts() 快速构建向量数据库
  • as_retriever() 方法将向量存储转换为检索器
  • search_kwargs 参数控制检索数量

5. 智能查询分类系统

5.1 查询分类器实现

class QueryClassifier:
    """Classifies a query into Factual / Analytical / Opinion / Contextual."""

    def __init__(self):
        # Deterministic (temperature=0) Qwen model via the OpenAI-compatible API.
        self.llm = ChatOpenAI(
            temperature=0, 
            model="qwen-plus", 
            max_tokens=4000,
            api_key=os.getenv("DASHSCOPE_API_KEY"), 
            base_url=BASE_URL
        )
        
        # Structured output: replies are parsed into categories_options.
        self.structured_llm = self.llm.with_structured_output(categories_options)

    def classify(self, query):
        """Return the category label the LLM assigns to *query*."""
        prompt = f"""
        Classify the following query into one of these categories:
        - Factual: Questions seeking specific facts or information
        - Analytical: Questions requiring analysis, comparison, or reasoning
        - Opinion: Questions asking for opinions, recommendations, or subjective views
        - Contextual: Questions that depend on context or require understanding of relationships
        
        Query: {query}
        """
        
        result = self.structured_llm.invoke(prompt)
        return result.category

知识点:

  • with_structured_output() 方法确保 LLM 输出符合 Pydantic 模型
  • 通过详细的 prompt 指导模型进行准确分类
  • 返回结构化的分类结果

6. 多策略检索系统

6.1 事实性检索策略

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    """Factual strategy: enhance the query, retrieve, then LLM-score and rank."""

    def retrieve(self, query, k=4):
        # Query enhancement via the LLM.
        # NOTE(review): assumes self.llm exists — it is not set in this
        # article's BaseRetrievalStrategy.__init__; confirm against full code.
        enhanced_query_prompt = f"""
        Enhance the following query to improve factual information retrieval:
        Original query: {query}
        Enhanced query:
        """
        enhanced_query = self.llm.invoke(enhanced_query_prompt).content
        
        # Retrieve documents with the enhanced query.
        docs = self.retriever.get_relevant_documents(enhanced_query)
        
        # Score each document's relevance 1-10.
        scored_docs = []
        for doc in docs:
            score_prompt = f"""
            Rate the relevance of this document to the query on a scale of 1-10:
            Query: {query}
            Document: {doc.page_content}
            """
            
            # NOTE(review): assumes self.structured_llm (relevant_score schema)
            # is defined elsewhere — TODO confirm.
            score_result = self.structured_llm.invoke(score_prompt)
            scored_docs.append((doc, score_result.score))
        
        # Sort by score descending and return the top k.
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, score in scored_docs[:k]]

6.2 分析性检索策略

class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    """Analytical strategy: decompose into sub-queries, retrieve per sub-query,
    deduplicate, then LLM-pick a diverse subset."""

    def retrieve(self, query, k=4):
        # Generate sub-queries for comprehensive coverage.
        sub_query_prompt = f"""
        Break down this analytical query into 2-3 specific sub-queries:
        Main query: {query}
        """
        
        # NOTE(review): assumes self.structured_llm_sub (SubQueries schema) is
        # defined elsewhere — not set in this article's base class; confirm.
        sub_queries_result = self.structured_llm_sub.invoke(sub_query_prompt)
        sub_queries = sub_queries_result.sub_queries
        
        # Retrieve documents for every sub-query.
        all_docs = []
        for sub_query in sub_queries:
            docs = self.retriever.get_relevant_documents(sub_query)
            all_docs.extend(docs)
        
        # Deduplicate by page content (dict keeps the last doc per content).
        unique_docs = list({doc.page_content: doc for doc in all_docs}.values())
        
        if len(unique_docs) <= k:
            return unique_docs
        
        # Ask the LLM for the most diverse, relevant subset (indices only).
        diversity_prompt = f"""
        Select {k} most diverse and relevant documents for this analytical query: {query}
        Documents: {[f"{i}: {doc.page_content[:200]}..." for i, doc in enumerate(unique_docs)]}
        """
        
        # NOTE(review): assumes self.structured_llm_selection (SelectedIndices
        # schema) is defined elsewhere — confirm.
        selection_result = self.structured_llm_selection.invoke(diversity_prompt)
        selected_indices = selection_result.indices
        
        return [unique_docs[i] for i in selected_indices if i < len(unique_docs)][:k]

6.3 观点性检索策略

class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    """Opinion strategy: enhance the query, categorize stances, then pick a
    stance-diverse subset."""

    def retrieve(self, query, k=3):
        # Identify opinion-related terms and enhance the query.
        opinion_prompt = f"""
        Identify key opinion-related terms in this query: {query}
        Enhanced query for opinion retrieval:
        """
        enhanced_query = self.llm.invoke(opinion_prompt).content
        
        # Retrieve candidate documents.
        docs = self.retriever.get_relevant_documents(enhanced_query)
        
        # Label each document's stance on the query.
        categorized_docs = []
        for doc in docs:
            category_prompt = f"""
            Categorize this document's stance on the query:
            Query: {query}
            Document: {doc.page_content}
            Categories: Supporting, Opposing, Neutral, Mixed
            """
            
            category = self.llm.invoke(category_prompt).content
            categorized_docs.append((doc, category))
        
        # Prefer stance diversity when choosing the final k documents.
        # NOTE(review): the ``or len(selected_docs) < k`` clause admits every
        # document until k are collected, so duplicate stances can slip in —
        # verify this is the intended trade-off.
        selected_docs = []
        categories_seen = set()
        
        for doc, category in categorized_docs:
            if category not in categories_seen or len(selected_docs) < k:
                selected_docs.append(doc)
                categories_seen.add(category)
                if len(selected_docs) >= k:
                    break
        
        return selected_docs

6.4 上下文检索策略

class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    """Contextual strategy: fold user context in, rewrite the query, then
    re-rank retrieved documents by contextual relevance."""

    def retrieve(self, query, k=4, user_context=None):
        # Merge the user's context into the query when provided.
        if user_context:
            context_prompt = f"""
            Integrate user context with the query:
            User context: {user_context}
            Original query: {query}
            Contextualized query:
            """
            contextualized_query = self.llm.invoke(context_prompt).content
        else:
            contextualized_query = query
        
        # Rewrite to surface contextual relationships.
        rewrite_prompt = f"""
        Rewrite this query to better capture contextual relationships:
        Query: {contextualized_query}
        Rewritten query:
        """
        rewritten_query = self.llm.invoke(rewrite_prompt).content
        
        # Retrieve with the rewritten query.
        docs = self.retriever.get_relevant_documents(rewritten_query)
        
        # Score each document's contextual relevance (1-10).
        scored_docs = []
        for doc in docs:
            context_score_prompt = f"""
            Rate how well this document addresses the contextual aspects of the query (1-10):
            Query: {query}
            Document: {doc.page_content}
            """
            
            # NOTE(review): assumes self.structured_llm (relevant_score schema)
            # exists — not defined in this article's base class; confirm.
            score_result = self.structured_llm.invoke(context_score_prompt)
            scored_docs.append((doc, score_result.score))
        
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, score in scored_docs[:k]]

知识点:

  • 每种策略针对不同类型的查询优化检索逻辑
  • 使用 LLM 进行查询增强、文档评分和选择
  • 实现文档去重、多样性选择和上下文整合
  • 结构化输出确保评分和选择的准确性

7. 自定义检索器实现

7.1 适配 LangChain BaseRetriever

class PydanticAdaptiveRetriever(BaseRetriever):
    """LangChain BaseRetriever adapter that delegates to an AdaptiveRetriever."""

    # Excluded from pydantic serialization; plain Python object, not a model.
    adaptive_retriever: 'AdaptiveRetriever' = Field(exclude=True)

    class Config:
        # Allow the non-pydantic AdaptiveRetriever as a field type.
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        # Public sync entry point (legacy API).
        return self.adaptive_retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        # Async wrapper; the underlying call is synchronous.
        return self.get_relevant_documents(query)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        # Hook used by BaseRetriever.invoke().
        return self.adaptive_retriever.get_relevant_documents(query)
        
    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        # Hook used by BaseRetriever.ainvoke(); still calls the sync delegate.
        return self.adaptive_retriever.get_relevant_documents(query)

7.2 自适应检索器核心逻辑

class AdaptiveRetriever:
    """Classifies each query and dispatches to the matching strategy."""

    def __init__(self, texts: List[str]):
        self.classifier = QueryClassifier()
        # One strategy (and one vector index) per query category.
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        # NOTE(review): an unexpected category label raises KeyError here —
        # the classifier schema does not hard-enforce the four options.
        category = self.classifier.classify(query)
        strategy = self.strategies[category]
        return strategy.retrieve(query)

知识点:

  • 继承 BaseRetriever 实现自定义检索器
  • 实现同步和异步检索方法
  • Field(exclude=True) 排除不需要序列化的字段
  • arbitrary_types_allowed = True 允许自定义类型

8. RAG 系统集成与链式调用

8.1 完整 RAG 系统实现

class AdaptiveRAG:
    """Full RAG pipeline: adaptive retrieval + prompt template + LLM chain."""

    def __init__(self, texts: List[str]):
        # Wire the adaptive retriever behind the LangChain-compatible adapter.
        adaptive_retriever = AdaptiveRetriever(texts)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        
        # Deterministic Qwen model via the OpenAI-compatible endpoint.
        self.llm = ChatOpenAI(
            temperature=0, 
            model="qwen-plus", 
            max_tokens=4000,
            api_key=os.getenv("DASHSCOPE_API_KEY"), 
            base_url=BASE_URL
        )

        # Custom prompt: answer strictly from the retrieved context.
        prompt_template = """Use the following pieces of context to answer the question at the end. 
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        Answer:"""
        
        prompt = PromptTemplate(
            template=prompt_template, 
            input_variables=["context", "question"]
        )

        # LCEL pipe: prompt -> llm.
        self.llm_chain = prompt | self.llm

    def answer(self, query: str) -> str:
        # Retrieve supporting documents via the adaptive router.
        docs = self.retriever.get_relevant_documents(query)
        
        # Join the retrieved chunks into a single context string.
        input_data = {
            "context": "\n".join([doc.page_content for doc in docs]), 
            "question": query
        }
        
        # NOTE(review): despite the ``-> str`` annotation this returns the
        # chat model's message object (callers read ``.content``) — confirm.
        return self.llm_chain.invoke(input_data)

8.2 系统使用示例

# Initialize the knowledge base.
texts = [
    "The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
]

# Build the RAG system.
rag_system = AdaptiveRAG(texts)

# Exercise each query type; answers are message objects, hence .content.
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?").content
opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?").content
contextual_result = rag_system.answer("How does the Earth's position in the Solar System influence its habitability?").content

知识点:

  • PromptTemplate 定义结构化的提示模板
  • 使用管道操作符 | 创建 LLM 链
  • invoke() 方法执行链式调用
  • 文档内容拼接和上下文构建

9. 核心 LangChain 概念总结

9.1 知识库定义

  • 文本分割: 使用 CharacterTextSplitter 将长文本分块
  • 向量化: 通过 DashScopeEmbeddings 将文本转换为向量
  • 向量存储: 使用 FAISS 构建高效的向量数据库
  • 检索器: 通过 as_retriever() 创建检索接口

9.2 规范化输出

  • Pydantic 模型: 定义结构化的数据模型
  • 结构化输出: 使用 with_structured_output() 确保 LLM 输出格式
  • 类型安全: 通过类型注解和 Field 描述确保数据质量

9.3 向量库调用

  • FAISS 集成: 快速构建和查询向量数据库
  • 相似性搜索: 通过 get_relevant_documents() 检索相关文档
  • 检索参数: 使用 search_kwargs 控制检索行为

9.4 大模型选择性路由

  • 查询分类: 使用 LLM 对查询进行智能分类
  • 策略模式: 根据查询类型选择不同的检索策略
  • 动态路由: 运行时根据分类结果选择合适的处理策略

9.5 链式调用与集成

  • 提示模板: 使用 PromptTemplate 标准化输入格式
  • 链式操作: 通过管道操作符连接不同组件
  • 异步支持: 实现同步和异步检索方法
  • 自定义检索器: 继承 BaseRetriever 实现复杂检索逻辑

10. 最佳实践与注意事项

  1. 环境变量管理: 使用条件判断避免 NoneType 错误
  2. 结构化输出: 始终使用 Pydantic 模型确保输出格式
  3. 错误处理: 在关键操作中添加适当的错误处理逻辑
  4. 性能优化: 合理设置文档分块大小和检索数量
  5. 模块化设计: 将不同功能封装为独立的类和方法
  6. 异步支持: 为高并发场景实现异步方法
  7. 可扩展性: 使用策略模式便于添加新的检索策略

完整代码

# -*- coding: utf-8 -*-
# @Time: 2025/8/29 00:59
# @Author: 陈伟峰
# @Email: swpucwf@126.com
# @File: adaptive_retrieval.py
# @Software: PyCharm
import dashscope
from langchain_community.embeddings import DashScopeEmbeddings
import os
from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from typing import Dict, Any, List
from langchain.docstore.document import Document
from langchain_openai import ChatOpenAI
from pydantic import BaseModel,Field
# from langchain_core.pydantic_v1 import BaseModel, Field
# from langchain_core.pydantic_v1 import Field as LCField

# Load environment variables from a .env file
load_dotenv()

# Set the OpenAI API key environment variable (if needed)
# Conditional assignment avoids a TypeError when the key is absent.
if os.getenv('OPENAI_API_KEY'):
    os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# DashScope OpenAI-compatible base url
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# An explicit OPENAI_API_BASE overrides the DashScope default.
BASE_URL = os.getenv("OPENAI_API_BASE") or DASHSCOPE_BASE_URL
print(f"Using OpenAI-compatible base_url: {BASE_URL}")

class categories_options(BaseModel):
    """Structured-output schema: labels a query as one of Factual, Analytical,
    Opinion, or Contextual (enforced via the prompt, not the type)."""

    category: str = Field(
        example="Factual",
        description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual",
    )


class QueryClassifier:
    """Routes a user query into one of four categories (Factual, Analytical,
    Opinion, Contextual) using an LLM with structured output."""

    def __init__(self):
        # temperature=0 keeps classification deterministic across calls.
        self.llm = ChatOpenAI(model="qwen-plus", temperature=0, max_tokens=4000,
                              api_key=os.getenv("DASHSCOPE_API_KEY"), base_url=BASE_URL)

        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="Classify the following query into one of these categories: Factual, Analytical, Opinion, or Contextual.\nQuery: {query}\nCategory:"
        )

        # with_structured_output parses the reply into categories_options.
        self.chain = self.prompt | self.llm.with_structured_output(categories_options)

    def classify(self, query):
        """Return the category name (str) chosen for *query*."""
        print("classifying query")  # fixed typo: was "clasiffying query"
        return self.chain.invoke({"query": query}).category


class BaseRetrievalStrategy:
    """Base class: embeds and indexes the input text in FAISS and exposes a
    plain similarity-search retrieve()."""

    def __init__(self, text, max_tokens=4000):
        """
        Args:
            text: list of raw document strings to index.
            max_tokens: kept for interface compatibility (currently unused).
        """
        self.embeddings = DashScopeEmbeddings(
            model="text-embedding-v4",
            dashscope_api_key=os.getenv('DASHSCOPE_API_KEY')
        )

        text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        # BUG FIX: previously indexed the module-level global ``texts`` instead
        # of the ``text`` argument, so callers could not index their own data.
        self.documents = text_splitter.create_documents(text)

        self.db = FAISS.from_documents(self.documents, self.embeddings)
        self.llm = ChatOpenAI(temperature=0, model="qwen-plus", max_tokens=4000,
                              api_key=os.getenv("DASHSCOPE_API_KEY"), base_url=BASE_URL)

    def retrieve(self, query, k=4):
        """Return the top-*k* documents by vector similarity."""
        return self.db.similarity_search(query, k=k)



class relevant_score(BaseModel):
    """Structured-output schema holding a single document-relevance score
    (the ranking prompts request a 1-10 scale)."""

    score: float = Field(
        example=8.0,
        description="The relevance score of the document to the query",
    )


class FactualRetrievalStrategy(BaseRetrievalStrategy):
    """Factual-query strategy: enhance the query, over-fetch candidates,
    LLM-rank them, and keep the top k."""

    def retrieve(self, query, k=4):
        print("retrieving factual")
        # Step 1: let the LLM rewrite the query for better recall.
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )
        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke({"query": query}).content
        print(f'enhanced query: {enhanced_query}')  # fixed typo: was "enhande"

        # Step 2: over-fetch (2k) so the ranking step has candidates to discard.
        docs = self.db.similarity_search(enhanced_query, k=k * 2)

        # Step 3: score each candidate 1-10 with structured output.
        ranking_prompt = PromptTemplate(
            input_variables=["query", "doc"],
            template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        ranked_docs = []
        print("ranking docs")
        for doc in docs:
            input_data = {"query": enhanced_query, "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        # Step 4: keep the k best-scoring documents.
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]


class SelectedIndices(BaseModel):
    """Structured-output schema: zero-based positions of the documents an
    LLM selected from a candidate list."""

    indices: List[int] = Field(
        example=[0, 1, 2, 3],
        description="Indices of selected documents",
    )


class SubQueries(BaseModel):
    """Structured-output schema: decomposition of one analytical query into
    focused sub-questions."""

    sub_queries: List[str] = Field(
        example=["What is the population of New York?", "What is the GDP of New York?"],
        description="List of sub-queries for comprehensive analysis",
    )


class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    """Analytical-query strategy: decompose into sub-queries, retrieve per
    sub-query, then LLM-select a diverse subset."""

    def retrieve(self, query, k=4):
        print("retrieving analytical")
        # Generate k sub-questions covering different facets of the query.
        sub_queries_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Generate {k} sub-questions for: {query}"
        )
        # (removed redundant ``llm = self.llm`` alias)
        sub_queries_chain = sub_queries_prompt | self.llm.with_structured_output(SubQueries)

        input_data = {"query": query, "k": k}
        sub_queries = sub_queries_chain.invoke(input_data).sub_queries
        print(f'sub queries for comprehensive analysis: {sub_queries}')

        # Retrieve a couple of documents per sub-query.
        all_docs = []
        for sub_query in sub_queries:
            all_docs.extend(self.db.similarity_search(sub_query, k=2))

        # Ask the LLM for a diverse, relevant subset (indices only).
        diversity_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="""Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {docs}\n
            Return only the indices of selected documents as a list of integers."""
        )
        diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices_result = diversity_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        # Guard against out-of-range indices and cap at k even if the LLM
        # returns more than requested (previously uncapped).
        return [all_docs[i] for i in selected_indices_result if i < len(all_docs)][:k]


class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    """Opinion-query strategy: enumerate viewpoints, retrieve per viewpoint,
    then LLM-select the most representative, diverse documents."""

    def retrieve(self, query, k=3):
        print("retrieving opinion")
        # Ask the LLM for k distinct viewpoints on the topic.
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        # Robustness: drop blank lines so we don't run searches on empty
        # viewpoint strings (LLM replies often contain empty lines).
        viewpoints = [v for v in viewpoints if v.strip()]
        print(f'viewpoints: {viewpoints}')

        all_docs = []
        for viewpoint in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

        # LLM classifies the candidates and picks k representative indices.
        opinion_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
        )
        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)

        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices = opinion_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        # Guard against out-of-range indices and cap at k (previously uncapped).
        return [all_docs[i] for i in selected_indices if i < len(all_docs)][:k]

class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    """Contextual-query strategy: fold user context into the query, retrieve,
    then re-rank candidates by context-aware relevance."""

    def retrieve(self, query, k=4, user_context=None):
        print("retrieving contextual")
        effective_context = user_context or "No specific context provided"

        # Reformulate the query in light of the user's context.
        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
        )
        reformulated = (context_prompt | self.llm).invoke(
            {"query": query, "context": effective_context}
        ).content
        print(f'contextualized query: {reformulated}')

        # Over-fetch candidates for the re-ranking step below.
        candidates = self.db.similarity_search(reformulated, k=k*2)

        # Score each candidate 1-10, taking the user context into account.
        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"
        )
        scorer = ranking_prompt | self.llm.with_structured_output(relevant_score)
        print("ranking docs")

        scored = [
            (doc, float(scorer.invoke({
                "query": reformulated,
                "context": effective_context,
                "doc": doc.page_content,
            }).score))
            for doc in candidates
        ]

        # Highest-scoring k documents win.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in scored[:k]]

class AdaptiveRetriever:
    """Classifies a query, then dispatches to the matching retrieval strategy."""

    def __init__(self, texts: List[str]):
        self.classifier = QueryClassifier()
        # One strategy instance per category; each builds its own FAISS index.
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Classify *query* and delegate retrieval to the chosen strategy."""
        category = self.classifier.classify(query)
        # Robustness: the classifier's schema does not hard-enforce the four
        # labels, so fall back to the Factual strategy on an unknown category
        # instead of raising KeyError.
        strategy = self.strategies.get(category, self.strategies["Factual"])
        return strategy.retrieve(query)

class PydanticAdaptiveRetriever(BaseRetriever):
    """LangChain-compatible wrapper that delegates to an AdaptiveRetriever."""

    # Excluded from serialization; AdaptiveRetriever is not a pydantic model.
    adaptive_retriever: 'AdaptiveRetriever' = Field(exclude=True)

    class Config:
        # Allow the non-pydantic AdaptiveRetriever as a field type.
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        # Public sync entry point (legacy API).
        return self.adaptive_retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        # Async wrapper; the underlying call is synchronous.
        return self.get_relevant_documents(query)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        # Hook used by BaseRetriever.invoke().
        return self.adaptive_retriever.get_relevant_documents(query)

    # BUG FIX: this was mistakenly declared ``async def _get_relevant_documents``
    # — a duplicate name that shadowed the sync hook above, making synchronous
    # retrieval return a coroutine. Renamed to the async hook name.
    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        return self.adaptive_retriever.get_relevant_documents(query)


class AdaptiveRAG:
    """End-to-end adaptive RAG: routed retrieval + prompt template + LLM answer."""

    def __init__(self, texts: List[str]):
        adaptive_retriever = AdaptiveRetriever(texts)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        self.llm = ChatOpenAI(temperature=0, model="qwen-plus", max_tokens=4000,
                              api_key=os.getenv("DASHSCOPE_API_KEY"), base_url=BASE_URL)

        # Custom prompt: answer strictly from the retrieved context.
        prompt_template = """Use the following pieces of context to answer the question at the end. 
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        Answer:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # LCEL pipe: prompt -> llm; invoking it yields the model's message object.
        self.llm_chain = prompt | self.llm

    def answer(self, query: str):
        """Retrieve context for *query* and generate an answer.

        Returns the chat model's message object — callers read ``.content``.
        (The previous ``-> str`` annotation was incorrect and has been removed.)
        """
        docs = self.retriever.get_relevant_documents(query)
        input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
        return self.llm_chain.invoke(input_data)

# Usage demo: build the RAG system over a one-document knowledge base and run
# one query per category; each answer is printed as soon as it is produced.
texts = [
    "The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
]
rag_system = AdaptiveRAG(texts)

demo_queries = [
    "What is the distance between the Earth and the Sun?",
    "How does the Earth's distance from the Sun affect its climate?",
    "What are the different theories about the origin of life on Earth?",
    "How does the Earth's position in the Solar System influence its habitability?",
]

results = []
for question in demo_queries:
    answer_text = rag_system.answer(question).content
    print(f"Answer: {answer_text}")
    results.append(answer_text)

# Preserve the original named result variables for downstream use.
factual_result, analytical_result, opinion_result, contextual_result = results

(此处为原文插图占位符:运行结果截图未能随文本导出)


网站公告

今日签到

点亮在社区的每一天
去签到