前置知识
什么是Embedding?
嵌入(Embedding)是一种常用的降维技术,能将文本、图像、音频、视频等各种形式的数据统一表示为同维度的向量。这种表示方式将原始数据中的高维、复杂特征提取并编码为一个 N 维向量(trait vector),每一维表示数据在某一语义特征上的表现。
无论是句子、图片,甚至声音片段,都可以通过嵌入模型转换为向量表示,从而便于统一处理、存储和比较。
检索为什么要用向量,为什么向量值越接近,相似度越高?
因为在现实中,文本、图片、音频等原始数据无法直接进行比较。而将它们转化为“向量”后,我们就可以用数学方法(如计算角度或距离)判断“两个东西像不像”。
url: https://static.docs-hub.com/cosinesimilarity_1742029975633.html
title: "余弦相似度计算"
host: static.docs-hub.com
通过测量两个向量的夹角的余弦值来度量它们之间的相似性,公式如下:
- 两个向量夹角为0°时,cos(θ)=1,方向相同 —> 非常相似
- 两个向量夹角为90°时,cos(θ)=0,方向完全相反,余弦相似度的值为-1 —> 向量无关
- 两个向量夹角为180°时,cos(θ)=0,方向完全相反,余弦相似度的值为-1 —> 极不相似
🧪 小例子:用余弦相似度判断两个句子的相似性
我们有两个中文句子:
- 句子 A:「这个程序代码太乱,那个代码规范」
- 句子 B:「这个程序代码不规范,那个更规范」
我们的目标是:判断它们是否相似?用向量来“算一算”。
Step 1:分词
- 句子 A 分词:
这个 / 程序 / 代码 / 太乱 / 那个 / 代码 / 规范
- 句子 B 分词:
这个 / 程序 / 代码 / 不 / 规范 / 那个 / 更 / 规范
Step 2:构建词袋(Bag-of-Words)
提取两个句子中出现的所有关键词,共计 8 个词:
词袋={这个,程序,代码,太乱,那个,规范,不,更}\text{词袋} = {\text{这个,程序,代码,太乱,那个,规范,不,更}}
Step 3:统计词频并构造向量
A ⃗ = [ 1 , 1 , 2 , 1 , 1 , 1 , 0 , 0 ] \vec{A} = [1,\ 1,\ 2,\ 1,\ 1,\ 1,\ 0,\ 0] A=[1, 1, 2, 1, 1, 1, 0, 0]
B ⃗ = [ 1 , 1 , 1 , 0 , 1 , 2 , 1 , 1 ] \vec{B} = [1,\ 1,\ 1,\ 0,\ 1,\ 2,\ 1,\ 1] B=[1, 1, 1, 0, 1, 2, 1, 1]
Step 4:计算余弦相似度(Cosine Similarity)
余弦相似度的公式为:
cos ( θ ) = A ⃗ ⋅ B ⃗ ∥ A ⃗ ∥ ⋅ ∥ B ⃗ ∥ \cos(\theta) = \frac{\vec{A} \cdot \vec{B}}{\|\vec{A}\| \cdot \|\vec{B}\|} cos(θ)=∥A∥⋅∥B∥A⋅B
所以相似度为:
A ⃗ ⋅ B ⃗ = 1 × 1 + 1 × 1 + 2 × 1 + 1 × 0 + 1 × 1 + 1 × 2 + 0 × 1 + 0 × 1 = 7 \vec{A} \cdot \vec{B} = 1 \times 1 + 1 \times 1 + 2 \times 1 + 1 \times 0 + 1 \times 1 + 1 \times 2 + 0 \times 1 + 0 \times 1 = 7 A⋅B=1×1+1×1+2×1+1×0+1×1+1×2+0×1+0×1=7
向量模长:
∥ A ⃗ ∥ = 1 2 + 1 2 + 2 2 + 1 2 + 1 2 + 1 2 + 0 2 + 0 2 = 11 \|\vec{A}\| = \sqrt{1^2 + 1^2 + 2^2 + 1^2 + 1^2 + 1^2 + 0^2 + 0^2} = \sqrt{11} ∥A∥=12+12+22+12+12+12+02+02=11
∥ B ⃗ ∥ = 1 2 + 1 2 + 1 2 + 0 2 + 1 2 + 2 2 + 1 2 + 1 2 = 12 \|\vec{B}\| = \sqrt{1^2 + 1^2 + 1^2 + 0^2 + 1^2 + 2^2 + 1^2 + 1^2} = \sqrt{12} ∥B∥=12+12+12+02+12+22+12+12=12
✅ 结果说明:
两个向量的余弦值约为 0.738,接近 1,说明两个句子在词的层面上较为相似。
⚠️ 局限性提示
尽管从“词频向量”的角度看两个句子很相似,但实际上它们在语义上可能差异巨大。
例如:
- 「这个代码很规范」 与 「这个代码不规范」
- 在词袋中几乎只差一个“不”,但语义完全相反
这正是向量模型的局限性所在: 它不理解语序,也无法捕捉否定、强调、主谓关系等语言逻辑。
什么是向量数据库?
向量数据库(Vector Database)是一种专门用于存储、检索高维向量数据的数据库,它通过计算向量之间的相似度(如余弦相似度、欧氏距离)来快速找到最匹配的结果。
传统数据库 vs 向量数据库 (传统数据库查表,向量数据库查相似度)
对比项 | 传统数据库 (MySQL) | 向量数据库(FAISS/Milvus) |
---|---|---|
存储内容 | 结构化数据(数字、字符串) | 高维向量(非结构化数据) |
查询方式 | 精确匹配(WHERE id=1 ) |
相似度搜索(找最像的向量) |
适用场景 | 订单管理、用户信息 | AI 推荐、语义搜索、图片检索 |
典型产品 | MySQL、PostgreSQL | FAISS、Milvus、Pinecone |
🚀 向量数据库的工作原理
- 原始数据:收集并预处理文本、图像、音频等原始输入
- 向量化:利用 BERT、CLIP 等模型,将多模态数据映射到同维度空间
- 索引构建:采用 HNSW、IVF 等高效结构,加速大规模向量检索
- 相似度检索:基于余弦距离/欧氏距离等度量,返回 Top‑K 最相似项。
🔧 主流向量数据库对比
数据库 | 开发者 | 特点 | 适用场景 |
---|---|---|---|
FAISS | Meta (Facebook) | 高性能,适合静态数据 | 大规模向量检索(推荐系统) |
Milvus | Zilliz | 开源,支持动态数据更新 | 实时搜索(电商、AI 客服) |
Pinecone | Pinecone | 全托管云服务,简单易用 | 快速部署 AI 应用 |
PostgreSQL + pgvector | PostgreSQL Global Development | 关系型数据库 + 向量扩展,支持复杂查询 | 业务+向量混合查询 |
Qdrant | Qdrant | 开源,实时同步,支持过滤、分片 | 个性化推荐、实时分析 |
什么是召回?
一句话总结:“该找的东西,你找回来多少?”
举个例子🌰
想象你在图书馆找 10本 关于“深度学习”的书:
- 你带回来 8本 相关的书(找对了)
- 但漏了 2本 在书架上(没找到)
👉 召回率 = 8 / (8+2) = 80%
核心关注点:有没有漏掉该找的东西?
📊 召回 vs 精确率(关键区别)
指标 | 核心问题 | 极端案例 |
---|---|---|
召回率 | 该找的是不是都找到了? | 把图书馆所有书搬来 → 召回率100%,但精确率极低 |
精确率 | 找到的东西是不是想要的? | 只拿1本书且正确 → 精确率100%,但召回率极低 |
召回率:避免漏检(如癌症筛查宁可多查)
精确率:避免误判(如垃圾邮件过滤不能误杀)
🔢 召回率公式
Recall = 找对的数量(TP) 应该找的总数(TP + FN) {\text{Recall} = \frac{\text{找对的数量(TP)}}{\text{应该找的总数(TP + FN)}}} Recall=应该找的总数(TP + FN)找对的数量(TP)
- TP(真正例):正确找回的目标(比如抓逃犯时抓对了人)
- FN(假反例):漏掉的目标(逃犯从你眼皮下溜走)
计算案例:
- 池塘里有 100条鱼,你捞上来 70条目标鱼,漏了 30条
- 召回率 = 70 / (70+30) = 70%
⚙️ 如何提升召回?
扩大搜索范围
- 搜索引擎:从返回Top-10改为Top-100结果
- 人脸识别:降低相似度阈值(比如从0.9调到0.6)
多路召回策略
- 同时用 关键词搜索(如“苹果手机”)+ 语义搜索(如“iPhone”)
- 电商场景:搜“连衣裙”时,同时召回“裙子”“长裙”等同义词商品
优化数据覆盖
- 补充冷门同义词(如“新冠”→“新型冠状病毒”)
- 处理生僻词(如医学名词“幽门螺杆菌”)
💡 小结一下
召回是 “宁可多找,不能漏网” 的指标——
- 适合 医疗诊断、安全监控等怕遗漏的场景
- 需与 精确率 平衡:召回太高会混入垃圾信息,太低会漏掉关键内容!
RAG是什么?
RAG(Retrieval-Augmented Generation)检索增强生成技术,是一种让大语言模型“查资料再作答”的技术。它的核心思想是:当用户提问时,先去外部的知识库里查找相关内容,再把这些资料提供给大语言模型作为参考,从而帮助它给出更准确、更新、更靠谱的答案,避免“胡说八道”或凭空编造内容。
RAG的核心流程
1. 向量知识库构建(离线)
- 文档采集:收集多源数据,包括网页、PDF、数据库内容等,作为知识库原始素材。
- 数据清洗与标准化:统一文档格式,清除噪音信息,保留结构化或半结构化文本。
- 文档分块(Chunking):按段落、标题结构或语义边界,将长文档切分为较小片段,兼顾语义完整性与后续检索效率。
- 片段嵌入向量化:使用嵌入模型(如 Qwen、BGE、M3E、Chinese-Alpaca-2 等)将每个文本块编码为高维向量。
- 向量存储入库:将向量及其对应元信息存入向量数据库(如 FAISS、Milvus 等),建立可检索知识索引。
2. 问题检索与重排序(在线)
- 问题向量化:将用户提出的问题通过同一嵌入模型转换为查询向量。
- 向量检索:基于向量相似度,在数据库中召回 Top-K 最相关片段,作为候选参考内容。
- 重排序(可选):借助 Cross-Encoder 或其他 Reranker 模型,对初步检索结果进行语义精度更高的相关性排序,从中选取 Top-N 片段输入生成阶段。
3. 生成阶段
- 构造上下文 Prompt:将用户问题与筛选后的文本片段拼接,形成增强提示输入。
- 语言模型生成回答:大语言模型在增强上下文基础上生成自然语言回答,内容更具事实依据与上下文相关性。
向量数据库(Milvus)安装与使用
前提:电脑上有Docker,Docker安装地址: https://www.docker.com/
1. Milvus安装
# 1. 创建一个新目录用于存放 Milvus 配置和脚本
mkdir milvus-standalone
# 2. 进入该目录
cd milvus-standalone/
# 3. 从 Milvus 官方 GitHub 下载 standalone 启动脚本
curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
# 4. 运行该脚本以启动 Milvus Standalone 实例(基于 Docker)
bash standalone_embed.sh start
# 5. (可选)下载 v2.0.2 版本的官方 docker-compose.yml 文件,便于手动控制容器部署
curl -L -o docker-compose.yml \
https://github.com/milvus-io/milvus/releases/download/v2.0.2/milvus-standalone-docker-compose.yml
# 手动部署(可选)
# 使用 docker-compose 启动 Milvus(前提是你已经有 docker-compose.yml)
docker-compose up -d
# 停止并删除容器
docker-compose down
# 停止 standalone 模式下的 Milvus 实例
bash standalone_embed.sh stop
2. 使用
访问webUrl: http://127.0.0.1:9091/webui/
创建Collecttion, collectionName: vector_store
Schema(模式 / 架构 / 结构描述)
它规定:有哪些字段、字段叫什么、数据类型是什么、是否主键、长度/维度是多少、允许哪些约束,以及数据之间怎么关联。
最终配置如下:
加载Schema后就可以正常使用了
spring-ai-alibaba实现RAG
这是一个基于 Spring AI Alibaba 的检索增强生成(RAG)示例项目,展示如何使用阿里云通义千问模型和 Milvus 向量数据库构建企业级 AI 应用。实现流程图如下:
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ 文档输入 │ │ 智能分词器 │ │ 向量化存储 │
│ │───▶│ │───▶│ │
│ • 文本输入 │ │ • 语义分割 │ │ • Milvus │
│ • 文件上传 │ │ • 段落分割 │ │ • 向量嵌入 │
│ • URL抓取 │ │ • 句子分割 │ │ • 索引构建 │
└─────────────────┘ └─────────────────┘ └─────────────────┘
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ 智能问答 │ │ 结果重排序 │ │ 向量检索 │
│ │◀───│ │◀───│ │
│ • 通义千问 │ │ • DashScope │ │ • 语义搜索 │
│ • 上下文增强 │ │ • 相关性评分 │ │ • 相似度计算 │
│ • 结果生成 │ │ • 结果优化 │ │ • TopK检索 │
└─────────────────┘ └─────────────────┘ └─────────────────┘
核心依赖
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
<version>1.0.0.2</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-milvus-store</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-vector-store-milvus</artifactId>
</dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-advisors-vector-store</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-tika-document-reader</artifactId>
</dependency>
配置文件
spring:
ai:
dashscope:
api-key: ${DASHSCOPE_API_KEY:your-api-key-here}
embedding:
options:
model: text-embedding-v4
dimensions: 1536
chat:
options:
model: qwen-plus-latest
vectorstore:
milvus:
client:
host: ${MILVUS_HOST:localhost}
port: ${MILVUS_PORT:19530}
username: ${MILVUS_USERNAME:root}
password: ${MILVUS_PASSWORD:milvus}
collectionName: vector_store
embeddingDimension: 1536
环境变量
export DASHSCOPE_API_KEY="your-dashscope-api-key"
export MILVUS_HOST="localhost"
export MILVUS_PORT="19530"
项目结构
src/main/java/com/alibaba/cloud/ai/example/rag/
├── config/ # 配置类
│ ├── PromptConfiguration.java # 提示词配置
│ └── TextSplitterConfiguration.java # 分词器配置
├── controller/ # REST 控制器
│ └── RagController.java # RAG API 接口
├── service/ # 业务服务
│ └── EnhancedRagService.java # 增强 RAG 服务
├── utils/ # 工具类
│ ├── DocumentUtils.java # 文档处理工具
│ └── RerankerUtils.java # 重排序工具
└── RagMilvusExampleApplication.java # 启动类
Config
package com.alibaba.cloud.ai.example.rag.config;
import com.alibaba.cloud.ai.example.rag.utils.DocumentUtils;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* 文本分割器配置类
*/
@Configuration
public class TextSplitterConfiguration {
/**
* 智能文本分割器配置属性
*/
@ConfigurationProperties(prefix = "spring.ai.text-splitter")
public static class TextSplitterProperties {
private int chunkSize = 400;
private int overlap = 50;
private DocumentUtils.SplitStrategy strategy = DocumentUtils.SplitStrategy.SEMANTIC;
private boolean keepSeparator = true;
public int getChunkSize() {
return chunkSize;
}
public void setChunkSize(int chunkSize) {
this.chunkSize = chunkSize;
}
public int getOverlap() {
return overlap;
}
public void setOverlap(int overlap) {
this.overlap = overlap;
}
public DocumentUtils.SplitStrategy getStrategy() {
return strategy;
}
public void setStrategy(DocumentUtils.SplitStrategy strategy) {
this.strategy = strategy;
}
public boolean isKeepSeparator() {
return keepSeparator;
}
public void setKeepSeparator(boolean keepSeparator) {
this.keepSeparator = keepSeparator;
}
}
/**
* 注册智能文本分割器Bean
*/ @Bean
public DocumentUtils.SmartTextSplitter smartTextSplitter(TextSplitterProperties properties) {
return new DocumentUtils.SmartTextSplitter(
properties.getChunkSize(),
properties.getOverlap(),
properties.getStrategy(),
properties.isKeepSeparator()
);
}
/**
* 注册配置属性Bean
*/ @Bean
@ConfigurationProperties(prefix = "")
public TextSplitterProperties textSplitterProperties() {
return new TextSplitterProperties();
}
}
Controller
package com.alibaba.cloud.ai.example.rag.controller;
import com.alibaba.cloud.ai.example.rag.service.EnhancedRagService;
import com.alibaba.cloud.ai.example.rag.utils.DocumentUtils;
import com.alibaba.cloud.ai.example.rag.utils.RerankerUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.*;
import java.util.stream.Collectors;
/**
* RAG控制器
* 提供知识库构建和智能问答的REST API接口
*/
@RestController
@RequestMapping("/ai")
public class RagController {
private static final Logger logger = LoggerFactory.getLogger(RagController.class);
@Autowired
private ChatClient chatClient;
@Autowired
private EnhancedRagService enhancedRagService;
/**
* 简单聊天接口
*/
@PostMapping("/chat")
public Map<String, String> chat(@RequestParam(value = "message", defaultValue = "请介绍一下Spring AI Alibaba") String message) {
return Map.of("message", chatClient.prompt().user(message).call().content());
}
/**
* 从文本构建知识库
*/
@PostMapping("/build-from-text")
public ResponseEntity<Map<String, Object>> buildKnowledgeBaseFromText(
@RequestParam("content") String content,
@RequestParam(value = "title", required = false) String title) {
logger.info("接收到文本构建知识库请求,标题: {}, 内容长度: {}", title, content.length());
try {
Map<String, Object> result = enhancedRagService.buildKnowledgeBaseFromText(content, title);
return ResponseEntity.ok(result);
} catch (Exception e) {
logger.error("从文本构建知识库失败", e);
Map<String, Object> error = new HashMap<>();
error.put("success", false);
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 从上传文件构建知识库
*/
@PostMapping("/build-from-file")
public ResponseEntity<Map<String, Object>> buildKnowledgeBaseFromFile(
@RequestParam("file") MultipartFile file) {
logger.info("接收到文件构建知识库请求,文件: {}, 大小: {}",
file.getOriginalFilename(), file.getSize());
try {
if (file.isEmpty()) {
throw new RuntimeException("上传文件不能为空");
}
Map<String, Object> result = enhancedRagService.buildKnowledgeBaseFromFile(file);
return ResponseEntity.ok(result);
} catch (Exception e) {
logger.error("从文件构建知识库失败", e);
Map<String, Object> error = new HashMap<>();
error.put("success", false);
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 从URL构建知识库
*/
@PostMapping("/build-from-url")
public ResponseEntity<Map<String, Object>> buildKnowledgeBaseFromUrl(
@RequestParam("url") String url) {
logger.info("接收到URL构建知识库请求,URL: {}", url);
try {
if (url == null || url.trim().isEmpty()) {
throw new RuntimeException("URL不能为空");
}
Map<String, Object> result = enhancedRagService.buildKnowledgeBaseFromUrl(url);
return ResponseEntity.ok(result);
} catch (Exception e) {
logger.error("从URL构建知识库失败", e);
Map<String, Object> error = new HashMap<>();
error.put("success", false);
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 构建示例知识库
*/
@PostMapping("/build-sample-knowledge-base")
public ResponseEntity<Map<String, Object>> buildSampleKnowledgeBase() {
logger.info("接收到构建示例知识库请求");
try {
enhancedRagService.buildKnowledgeBase();
Map<String, Object> result = new HashMap<>();
result.put("success", true);
result.put("message", "示例知识库构建完成");
return ResponseEntity.ok(result);
} catch (Exception e) {
logger.error("构建示例知识库失败", e);
Map<String, Object> error = new HashMap<>();
error.put("success", false);
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 增强RAG问答(简化版)
*/
@PostMapping("/enhanced-chat")
public ResponseEntity<Map<String, Object>> enhancedChat(
@RequestParam("question") String question) {
return enhancedChatWithParams(question, 5, true);
}
/**
* 增强RAG问答(自定义参数)
*/
@PostMapping("/enhanced-chat-custom")
public ResponseEntity<Map<String, Object>> enhancedChatWithParams(
@RequestParam("question") String question,
@RequestParam(value = "topK", defaultValue = "5") Integer topK,
@RequestParam(value = "useReranker", defaultValue = "true") Boolean useReranker) {
logger.info("接收到增强RAG问答请求,问题: {}, topK: {}, 重排序: {}", question, topK, useReranker);
try {
if (question == null || question.trim().isEmpty()) {
throw new RuntimeException("问题不能为空");
}
Map<String, Object> result = enhancedRagService.enhancedChat(question, topK, useReranker);
return ResponseEntity.ok(result);
} catch (Exception e) {
logger.error("增强RAG问答失败", e);
Map<String, Object> error = new HashMap<>();
error.put("success", false);
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 获取知识库状态
*/
@GetMapping("/knowledge-base-stats")
public ResponseEntity<Map<String, Object>> getKnowledgeBaseStats() {
logger.info("接收到获取知识库状态请求");
try {
Map<String, Object> stats = enhancedRagService.getKnowledgeBaseStats();
return ResponseEntity.ok(stats);
} catch (Exception e) {
logger.error("获取知识库状态失败", e);
Map<String, Object> error = new HashMap<>();
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 获取支持的文件类型
*/
@GetMapping("/supported-file-types")
public ResponseEntity<Map<String, Object>> getSupportedFileTypes() {
try {
Map<String, Object> result = new HashMap<>();
result.put("supported_types", DocumentUtils.getSupportedContentTypes());
result.put("description", "支持的文件类型包括:TXT、PDF、Word、Excel、HTML、XML、JSON等");
return ResponseEntity.ok(result);
} catch (Exception e) {
logger.error("获取支持的文件类型失败", e);
Map<String, Object> error = new HashMap<>();
error.put("error", e.getMessage());
return ResponseEntity.badRequest().body(error);
}
}
/**
* 健康检查接口
*/
@GetMapping("/health")
public ResponseEntity<Map<String, Object>> health() {
Map<String, Object> health = new HashMap<>();
health.put("status", "UP");
health.put("service", "Spring AI Alibaba RAG");
health.put("timestamp", System.currentTimeMillis());
return ResponseEntity.ok(health);
}
}
Service
package com.alibaba.cloud.ai.example.rag.service;
import com.alibaba.cloud.ai.example.rag.utils.DocumentUtils;
import com.alibaba.cloud.ai.example.rag.utils.RerankerUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.util.*;
import java.util.stream.Collectors;
/**
* 增强RAG服务
* 实现完整的RAG流程:文档采集→分块→向量化→存储→检索→重排序→生成
*/
@Service
public class EnhancedRagService {
private static final Logger logger = LoggerFactory.getLogger(EnhancedRagService.class);
// DashScope API 批量限制
private static final int BATCH_SIZE = 10;
@Autowired
private VectorStore vectorStore;
@Autowired
private ChatClient chatClient;
/**
* 从文本构建知识库
*
* @param content 文本内容
* @param title 文档标题
* @return 构建结果
*/
public Map<String, Object> buildKnowledgeBaseFromText(String content, String title) {
logger.info("开始从文本构建知识库,标题: {}", title);
try {
// 1. 文档采集
Document document = DocumentUtils.createFromText(content, title);
// 2. 文档分块
List<Document> chunks = DocumentUtils.chunkDocuments(Arrays.asList(document));
// 3. 向量化存储(分批处理避免API限制)
addDocumentsInBatches(chunks);
logger.info("成功从文本构建知识库,文档块数: {}", chunks.size());
Map<String, Object> result = new HashMap<>();
result.put("success", true);
result.put("chunks_count", chunks.size());
result.put("total_chars", content.length());
result.put("source_type", "text");
result.put("title", title);
return result;
} catch (Exception e) {
logger.error("从文本构建知识库失败", e);
throw new RuntimeException("从文本构建知识库失败: " + e.getMessage());
}
}
/**
* 从上传文件构建知识库
*
* @param file 上传的文件
* @return 构建结果
*/
public Map<String, Object> buildKnowledgeBaseFromFile(MultipartFile file) {
logger.info("开始从上传文件构建知识库: {}", file.getOriginalFilename());
try {
// 1. 文档采集
Document document = DocumentUtils.createFromUploadedFile(file);
// 2. 文档分块
List<Document> chunks = DocumentUtils.chunkDocuments(Arrays.asList(document));
// 3. 向量化存储(分批处理避免API限制)
addDocumentsInBatches(chunks);
logger.info("成功从文件构建知识库,文档块数: {}", chunks.size());
Map<String, Object> result = new HashMap<>();
result.put("success", true);
result.put("chunks_count", chunks.size());
result.put("file_name", file.getOriginalFilename());
result.put("file_size", file.getSize());
result.put("source_type", "uploaded_file");
result.put("content_type", document.getMetadata().get("detected_content_type"));
return result;
} catch (Exception e) {
logger.error("从文件构建知识库失败: {}", file.getOriginalFilename(), e);
throw new RuntimeException("从文件构建知识库失败: " + e.getMessage());
}
}
/**
* 从URL构建知识库
*
* @param url URL地址
* @return 构建结果
*/
public Map<String, Object> buildKnowledgeBaseFromUrl(String url) {
logger.info("开始从URL构建知识库: {}", url);
try {
// 1. 文档采集
Document document = DocumentUtils.createFromUrl(url);
// 2. 文档分块
List<Document> chunks = DocumentUtils.chunkDocuments(Arrays.asList(document));
// 3. 向量化存储(分批处理避免API限制)
addDocumentsInBatches(chunks);
logger.info("成功从URL构建知识库,文档块数: {}", chunks.size());
Map<String, Object> result = new HashMap<>();
result.put("success", true);
result.put("chunks_count", chunks.size());
result.put("url", url);
result.put("source_type", "url");
result.put("content_type", document.getMetadata().get("content_type"));
result.put("char_count", document.getMetadata().get("char_count"));
return result;
} catch (Exception e) {
logger.error("从URL构建知识库失败: {}", url, e);
throw new RuntimeException("从URL构建知识库失败: " + e.getMessage());
}
}
/**
* 构建示例知识库(用于演示)
*/
public void buildKnowledgeBase() {
logger.info("开始构建示例知识库...");
try {
// 创建示例文档
List<Document> sampleDocuments = DocumentUtils.createSampleDocuments();
// 对所有文档进行分块
List<Document> allChunks = DocumentUtils.chunkDocuments(sampleDocuments);
// 向量化存储(分批处理避免API限制)
addDocumentsInBatches(allChunks);
logger.info("示例知识库构建完成,总文档块数: {}", allChunks.size());
} catch (Exception e) {
logger.error("构建示例知识库失败", e);
throw new RuntimeException("构建示例知识库失败: " + e.getMessage());
}
}
/**
* 增强RAG问答
*
* @param question 用户问题
* @param topK 检索数量
* @param useReranker 是否使用重排序
* @return 问答结果
*/
public Map<String, Object> enhancedChat(String question, Integer topK, Boolean useReranker) {
logger.info("开始增强RAG问答,问题: {}, topK: {}, 重排序: {}", question, topK, useReranker);
try {
// 1. 向量检索
SearchRequest searchRequest = SearchRequest.builder()
.query(question)
.topK(topK != null ? topK : 5)
.similarityThreshold(0.1)
.build();
List<Document> retrievedDocs = vectorStore.similaritySearch(searchRequest);
logger.info("向量检索到 {} 个相关文档", retrievedDocs.size());
// 2. 重排序(可选)
List<Document> finalDocs = retrievedDocs;
if (useReranker != null && useReranker && !retrievedDocs.isEmpty()) {
try {
finalDocs = RerankerUtils.rerank(question, retrievedDocs, topK != null ? topK : 5);
logger.info("重排序后保留 {} 个文档", finalDocs.size());
} catch (Exception e) {
logger.warn("重排序失败,使用原始检索结果: {}", e.getMessage());
}
}
// 3. 构造上下文
String context = finalDocs.stream()
.map(Document::getText)
.collect(Collectors.joining("\n\n"));
// 4. 构造提示
String promptText = String.format("""
你是一个专业的AI助手,请基于以下上下文信息回答用户问题。
上下文信息:
%s 用户问题:%s
请注意:
1. 只基于提供的上下文信息进行回答
2. 如果上下文中没有相关信息,请明确说明
3. 回答要准确、详细且有帮助
4. 使用中文回答
""", context, question);
// 5. 生成回答
String answer = chatClient.prompt()
.user(promptText)
.call()
.content();
logger.info("增强RAG问答完成");
// 6. 构造返回结果
Map<String, Object> result = new HashMap<>();
result.put("question", question);
result.put("answer", answer);
result.put("retrieved_count", retrievedDocs.size());
result.put("reranked_count", finalDocs.size());
result.put("used_reranker", useReranker);
result.put("context_length", context.length());
// 添加检索到的文档信息
List<Map<String, Object>> docInfos = finalDocs.stream()
.map(doc -> {
Map<String, Object> info = new HashMap<>();
String content = doc.getText();
info.put("content_preview", content.substring(0, Math.min(100, content.length())) + "...");
info.put("content", content);
info.put("metadata", doc.getMetadata());
return info;
})
.collect(Collectors.toList());
result.put("retrieved_documents", docInfos);
return result;
} catch (Exception e) {
logger.error("增强RAG问答失败", e);
throw new RuntimeException("增强RAG问答失败: " + e.getMessage());
}
}
/**
* 获取知识库统计信息
*
* @return 统计信息
*/
public Map<String, Object> getKnowledgeBaseStats() {
try {
// 通过搜索空查询来获取总文档数(这是一个简单的估算方法)
SearchRequest searchRequest = SearchRequest.builder()
.query("")
.topK(1000)
.similarityThreshold(0.0)
.build();
List<Document> allDocs = vectorStore.similaritySearch(searchRequest);
Map<String, Object> stats = new HashMap<>();
stats.put("total_documents", allDocs.size());
// 按来源统计
Map<String, Long> sourceStats = allDocs.stream()
.collect(Collectors.groupingBy(
doc -> doc.getMetadata().getOrDefault("source", "unknown").toString(),
Collectors.counting()
));
stats.put("by_source", sourceStats);
// 内容统计
int totalChars = allDocs.stream()
.mapToInt(doc -> doc.getText().length())
.sum();
stats.put("total_chars", totalChars);
stats.put("avg_chars_per_doc", allDocs.isEmpty() ? 0 : totalChars / allDocs.size());
return stats;
} catch (Exception e) {
logger.error("获取知识库统计信息失败", e);
Map<String, Object> errorStats = new HashMap<>();
errorStats.put("error", "无法获取统计信息: " + e.getMessage());
return errorStats;
}
}
/**
* 分批插入文档,避免超过 DashScope API 的批量限制
* 包含最终长度验证,确保不超过Milvus varchar(512)限制
*/
private void addDocumentsInBatches(List<Document> documents) {
if (documents.isEmpty()) {
return;
}
logger.info("批量插入开始,接收到 {} 个文档", documents.size());
// 验证输入文档的长度
for (int i = 0; i < documents.size(); i++) {
Document doc = documents.get(i);
logger.info("输入文档 {}: 长度={}, ID={}, 预览=[{}...]",
i, doc.getText().length(), doc.getMetadata().get("doc_id"),
doc.getText().substring(0, Math.min(30, doc.getText().length())));
if (doc.getText().length() > 500) {
logger.error("严重问题:输入的文档 {} 长度 {} 已经超过500字符!", i, doc.getText().length());
}
}
// 最终长度验证 - 确保所有文档都不超过varchar(500)限制
final int MAX_CONTENT_LENGTH = 500; // 严格限制
List<Document> validatedDocuments = new ArrayList<>(documents.size());
for (Document doc : documents) {
String content = doc.getText();
if (content.length() > MAX_CONTENT_LENGTH) {
logger.warn("发现超长文档,长度: {} > {}, 进行强制截断", content.length(), MAX_CONTENT_LENGTH);
// 强制截断
String truncatedContent = content.substring(0, MAX_CONTENT_LENGTH);
// 优化:只在必要时创建新的元数据对象
Map<String, Object> newMetadata = new HashMap<>(4); // 预设容量
newMetadata.put("doc_id", doc.getMetadata().get("doc_id"));
newMetadata.put("source", doc.getMetadata().get("source"));
newMetadata.put("emergency_truncated", true);
newMetadata.put("truncated_length", truncatedContent.length());
Document validatedDoc = new Document(truncatedContent, newMetadata);
validatedDocuments.add(validatedDoc);
logger.info("文档截断完成:{} -> {} 字符", content.length(), truncatedContent.length());
} else {
validatedDocuments.add(doc);
}
}
int totalBatches = (int) Math.ceil((double) validatedDocuments.size() / BATCH_SIZE);
for (int i = 0; i < totalBatches; i++) {
int startIndex = i * BATCH_SIZE;
int endIndex = Math.min(startIndex + BATCH_SIZE, validatedDocuments.size());
List<Document> batch = validatedDocuments.subList(startIndex, endIndex);
try {
logger.info("插入第 {}/{} 批,文档数: {}", i + 1, totalBatches, batch.size());
// 改为逐个插入,避免批量插入的潜在问题
for (int docIndex = 0; docIndex < batch.size(); docIndex++) {
Document doc = batch.get(docIndex);
logger.info("插入文档 {}/{}: 长度={}, 内容=[{}...]",
docIndex + 1, batch.size(),
doc.getText().length(),
doc.getText().substring(0, Math.min(30, doc.getText().length())));
try {
// 单个文档插入
vectorStore.add(List.of(doc));
logger.debug("文档插入成功");
// 短暂休息避免API限流
Thread.sleep(50);
} catch (Exception docException) {
logger.error("单个文档插入失败,文档长度: {}, 错误: {}",
doc.getText().length(), docException.getMessage());
// 如果是长度问题,记录详细信息
if (docException.getMessage().contains("exceeds max length")) {
logger.error("确认是长度问题!文档实际内容: [{}]", doc.getText());
}
throw docException;
}
}
} catch (Exception e) {
logger.error("第 {} 批插入失败: {}", i + 1, e.getMessage());
throw new RuntimeException("批量插入失败: " + e.getMessage(), e);
}
}
logger.info("批量插入完成,共处理 {} 个文档", validatedDocuments.size());
}
}
Utils
DocumentUtils
package com.alibaba.cloud.ai.example.rag.utils;
import org.apache.tika.Tika;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.metadata.TikaCoreProperties;
import org.apache.tika.parser.AutoDetectParser;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.sax.BodyContentHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.web.multipart.MultipartFile;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URL;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* 文档处理工具类
* 整合文档采集、分块等功能
*/
public class DocumentUtils {
private static final Logger logger = LoggerFactory.getLogger(DocumentUtils.class);
private static final Tika tika = new Tika();
private static final AutoDetectParser parser = new AutoDetectParser();
// 分块参数配置
private static final int DEFAULT_CHUNK_SIZE = 400;
private static final int DEFAULT_OVERLAP = 50;
private static final int MAX_CHUNK_SIZE = 500; // 数据库字段限制
private static final int MAX_DOCUMENT_LENGTH = 100000; // 10万字符限制
// 文本分割策略枚举
public enum SplitStrategy {
CHARACTER, // 按字符分割
PARAGRAPH, // 按段落分割
SENTENCE, // 按句子分割
SEMANTIC // 语义分割(智能分割)
}
/**
* 文档转换器接口 - 兼容Spring AI的转换器模式
*/
public interface DocumentTransformer {
/**
* 转换文档列表
* @param documents 输入文档列表
* @return 转换后的文档列表
*/
List<Document> transform(List<Document> documents);
}
// ===========================================
// 文档采集功能
// ===========================================
/**
* 从纯文本创建文档
*/
public static Document createFromText(String content, String title) {
logger.info("从文本创建文档,标题: {}", title);
Map<String, Object> metadata = new HashMap<>();
String docId = "text_" + System.currentTimeMillis() + "_" + UUID.randomUUID().toString().substring(0, 8);
metadata.put("doc_id", docId);
metadata.put("source", "text");
metadata.put("title", title != null ? title : "用户输入文本");
metadata.put("created_time", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
metadata.put("content_type", "text/plain");
metadata.put("char_count", content.length());
return new Document(content, metadata);
}
/**
* 从上传文件创建文档
*/
public static Document createFromUploadedFile(MultipartFile file) {
logger.info("从上传文件创建文档: {}, 类型: {}, 大小: {}",
file.getOriginalFilename(), file.getContentType(), file.getSize());
try {
InputStream inputStream = file.getInputStream();
Metadata metadata = new Metadata();
if (file.getOriginalFilename() != null) {
metadata.set(TikaCoreProperties.RESOURCE_NAME_KEY, file.getOriginalFilename());
}
if (file.getContentType() != null && !file.getContentType().isEmpty()) {
metadata.set(Metadata.CONTENT_TYPE, file.getContentType());
}
String content = parseContent(inputStream, metadata);
Map<String, Object> docMetadata = createFileMetadata(file, metadata, content);
logger.info("成功提取文件内容,字符数: {}", content.length());
return new Document(content, docMetadata);
} catch (Exception e) {
logger.error("文件解析失败: {}", file.getOriginalFilename(), e);
throw new RuntimeException("文件解析失败: " + e.getMessage(), e);
}
}
/**
* 从URL创建文档
*/
public static Document createFromUrl(String urlString) {
logger.info("从URL创建文档: {}", urlString);
try {
URL url = new URL(urlString);
String fileName = extractFileNameFromUrl(urlString);
String fileExtension = extractFileExtensionFromUrl(urlString);
InputStream inputStream = url.openStream();
Metadata metadata = new Metadata();
if (fileName != null && !fileName.isEmpty()) {
metadata.set(TikaCoreProperties.RESOURCE_NAME_KEY, fileName);
}
if (fileExtension != null) {
String predictedContentType = predictContentTypeByExtension(fileExtension);
if (predictedContentType != null) {
metadata.set(Metadata.CONTENT_TYPE, predictedContentType);
}
}
String content = parseContent(inputStream, metadata);
Map<String, Object> docMetadata = createUrlMetadata(urlString, fileName, fileExtension, metadata, content);
logger.info("成功提取URL内容,字符数: {}", content.length());
return new Document(content, docMetadata);
} catch (Exception e) {
logger.error("URL解析失败: {}", urlString, e);
throw new RuntimeException("URL解析失败: " + e.getMessage(), e);
}
}
// ===========================================
// 改进的文档分块功能 - 实现Spring AI标准
// ===========================================
/**
* 智能文档分割器 - 实现Spring AI的DocumentTransformer接口规范
*/
public static class SmartTextSplitter implements DocumentTransformer {
private final int chunkSize;
private final int overlap;
private final SplitStrategy strategy;
private final boolean keepSeparator;
public SmartTextSplitter() {
this(DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP, SplitStrategy.SEMANTIC, true);
}
public SmartTextSplitter(int chunkSize, int overlap) {
this(chunkSize, overlap, SplitStrategy.SEMANTIC, true);
}
public SmartTextSplitter(int chunkSize, int overlap, SplitStrategy strategy) {
this(chunkSize, overlap, strategy, true);
}
public SmartTextSplitter(int chunkSize, int overlap, SplitStrategy strategy, boolean keepSeparator) {
this.chunkSize = Math.min(chunkSize, MAX_CHUNK_SIZE);
this.overlap = Math.max(0, Math.min(overlap, chunkSize - 1));
this.strategy = strategy;
this.keepSeparator = keepSeparator;
}
@Override
public List<Document> transform(List<Document> documents) {
List<Document> chunkedDocuments = new ArrayList<>();
for (Document document : documents) {
List<Document> chunks = splitDocument(document);
chunkedDocuments.addAll(chunks);
}
logger.info("智能分块完成,原始文档: {}, 生成分块: {}, 策略: {}",
documents.size(), chunkedDocuments.size(), strategy);
return chunkedDocuments;
}
private List<Document> splitDocument(Document document) {
String content = document.getText();
// 预处理:截断超长文档
if (content.length() > MAX_DOCUMENT_LENGTH) {
logger.warn("文档过大({} 字符),截断到 {} 字符", content.length(), MAX_DOCUMENT_LENGTH);
content = content.substring(0, MAX_DOCUMENT_LENGTH);
}
if (content.length() <= chunkSize) {
return Collections.singletonList(document);
}
List<String> textChunks = switch (strategy) {
case PARAGRAPH -> splitByParagraphs(content);
case SENTENCE -> splitBySentences(content);
case SEMANTIC -> splitBySemantic(content);
default -> splitByCharacters(content);
};
return createDocumentChunks(document, textChunks);
}
/**
* 语义分割 - 综合考虑段落、句子和长度
*/
private List<String> splitBySemantic(String text) {
List<String> chunks = new ArrayList<>();
// 首先按段落分割
List<String> paragraphs = Arrays.stream(text.split("\n\\s*\n"))
.map(String::trim)
.filter(p -> !p.isEmpty())
.collect(Collectors.toList());
StringBuilder currentChunk = new StringBuilder();
for (String paragraph : paragraphs) {
// 如果单个段落就超过分块大小,需要进一步分割
if (paragraph.length() > chunkSize) {
// 先保存当前已有的内容
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString().trim());
currentChunk.setLength(0);
}
// 对长段落按句子分割
List<String> sentences = splitBySentences(paragraph);
chunks.addAll(mergeSmallChunks(sentences));
} else {
// 检查加入当前段落后是否超过限制
int futureLength = currentChunk.length() + paragraph.length() +
(currentChunk.length() > 0 ? 2 : 0); // 段落间的换行符
if (futureLength > chunkSize && currentChunk.length() > 0) {
// 保存当前块并开始新块
chunks.add(currentChunk.toString().trim());
currentChunk.setLength(0);
}
if (currentChunk.length() > 0) {
currentChunk.append("\n\n");
}
currentChunk.append(paragraph);
}
}
// 添加最后一个块
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString().trim());
}
return chunks.isEmpty() ? Collections.singletonList(text) : chunks;
}
/**
* 按段落分割
*/
private List<String> splitByParagraphs(String text) {
List<String> paragraphs = Arrays.stream(text.split("\n\\s*\n"))
.map(String::trim)
.filter(p -> !p.isEmpty())
.collect(Collectors.toList());
return mergeSmallChunks(paragraphs);
}
/**
* 按句子分割
*/
private List<String> splitBySentences(String text) {
// 中文句子分割正则表达式
Pattern sentencePattern = Pattern.compile("[。!?;\\.\\!\\?;]+");
String[] sentences = sentencePattern.split(text);
List<String> validSentences = Arrays.stream(sentences)
.map(String::trim)
.filter(s -> !s.isEmpty())
.collect(Collectors.toList());
return mergeSmallChunks(validSentences);
}
/**
* 按字符分割(带重叠)
*/
private List<String> splitByCharacters(String text) {
List<String> chunks = new ArrayList<>();
int start = 0;
while (start < text.length()) {
int end = Math.min(start + chunkSize, text.length());
String chunk = text.substring(start, end);
// 尝试在自然边界处分割(避免在词语中间分割)
if (end < text.length() && !Character.isWhitespace(text.charAt(end))) {
int lastSpace = chunk.lastIndexOf(' ');
int lastNewline = chunk.lastIndexOf('\n');
int naturalEnd = Math.max(lastSpace, lastNewline);
if (naturalEnd > start + chunkSize / 2) { // 确保不会切割过小
end = start + naturalEnd;
chunk = text.substring(start, end);
}
}
chunks.add(chunk);
start = Math.max(start + chunkSize - overlap, start + 1); // 避免无限循环
}
return chunks;
}
/**
* 合并小块,避免产生过小的分块
*/
private List<String> mergeSmallChunks(List<String> pieces) {
List<String> chunks = new ArrayList<>();
StringBuilder currentChunk = new StringBuilder();
for (String piece : pieces) {
int futureLength = currentChunk.length() + piece.length() +
(currentChunk.length() > 0 ? 1 : 0); // 分隔符
if (futureLength > chunkSize && currentChunk.length() > 0) {
chunks.add(currentChunk.toString().trim());
currentChunk.setLength(0);
}
if (currentChunk.length() > 0) {
currentChunk.append(keepSeparator ? "\n" : " ");
}
currentChunk.append(piece);
}
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString().trim());
}
return chunks;
}
/**
* 创建文档分块
*/
private List<Document> createDocumentChunks(Document originalDoc, List<String> textChunks) {
List<Document> chunks = new ArrayList<>();
for (int i = 0; i < textChunks.size(); i++) {
String chunkText = textChunks.get(i);
// 强制截断超长内容
if (chunkText.length() > MAX_CHUNK_SIZE) {
logger.warn("分块内容超过硬限制,强制截断: {} -> {}", chunkText.length(), MAX_CHUNK_SIZE);
chunkText = chunkText.substring(0, MAX_CHUNK_SIZE);
}
Document chunk = createChunk(originalDoc, chunkText, i, textChunks.size());
chunks.add(chunk);
}
return chunks;
}
/**
* 创建单个文档块
*/
private Document createChunk(Document originalDoc, String chunkContent, int chunkIndex, int totalChunks) {
Map<String, Object> chunkMetadata = new HashMap<>();
// 生成唯一的块ID
Object originalId = originalDoc.getMetadata().get("doc_id");
String chunkId;
if (originalId != null) {
chunkId = originalId + "_chunk_" + chunkIndex;
chunkMetadata.put("original_doc_id", originalId);
} else {
chunkId = "chunk_" + System.currentTimeMillis() + "_" + chunkIndex + "_" +
UUID.randomUUID().toString().substring(0, 8);
}
// 设置块元数据
chunkMetadata.put("doc_id", chunkId);
chunkMetadata.put("chunk_index", chunkIndex);
chunkMetadata.put("total_chunks", totalChunks);
chunkMetadata.put("chunk_size", chunkContent.length());
chunkMetadata.put("split_strategy", strategy.name());
chunkMetadata.put("created_time", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
// 继承关键的原始元数据
String source = (String) originalDoc.getMetadata().get("source");
if (source != null) {
chunkMetadata.put("source", source);
}
String filename = (String) originalDoc.getMetadata().get("filename");
if (filename != null) {
chunkMetadata.put("filename", filename);
}
String title = (String) originalDoc.getMetadata().get("title");
if (title != null) {
chunkMetadata.put("title", title);
}
return new Document(chunkContent, chunkMetadata);
}
}
/**
* 使用默认参数分块文档
*/
public static List<Document> chunkDocuments(List<Document> documents) {
SmartTextSplitter splitter = new SmartTextSplitter();
return splitter.transform(documents);
}
/**
* 使用自定义参数分块文档
*/
public static List<Document> chunkDocuments(List<Document> documents, int chunkSize, int overlap) {
SmartTextSplitter splitter = new SmartTextSplitter(chunkSize, overlap);
return splitter.transform(documents);
}
/**
* 使用指定策略分块文档
*/
public static List<Document> chunkDocuments(List<Document> documents, int chunkSize, int overlap, SplitStrategy strategy) {
SmartTextSplitter splitter = new SmartTextSplitter(chunkSize, overlap, strategy);
return splitter.transform(documents);
}
// ===========================================
// 私有辅助方法
// ===========================================
/**
* 解析内容(限制内容长度,避免内存溢出)
*/
private static String parseContent(InputStream inputStream, Metadata metadata) throws Exception {
// 限制解析内容的最大长度为5MB,避免内存溢出
final int MAX_CONTENT_LENGTH = 5 * 1024 * 1024; // 5MB
BodyContentHandler handler = new BodyContentHandler(MAX_CONTENT_LENGTH);
ParseContext parseContext = new ParseContext();
try {
parser.parse(inputStream, handler, metadata, parseContext);
} catch (Exception e) {
if (e.getMessage() != null && e.getMessage().contains("content length")) {
logger.warn("文档内容超过{}字符限制,已截断", MAX_CONTENT_LENGTH);
} else {
throw e;
}
}
String content = handler.toString();
if (content.trim().isEmpty()) {
throw new RuntimeException("内容为空或无法解析");
}
logger.info("文档解析完成,内容长度: {} 字符", content.length());
return content;
}
/**
* 创建文件元数据
*/
private static Map<String, Object> createFileMetadata(MultipartFile file, Metadata metadata, String content) {
Map<String, Object> docMetadata = new HashMap<>();
// 添加必需的 doc_id 字段(对应 Milvus collection schema)
String docId = "file_" + System.currentTimeMillis() + "_" + UUID.randomUUID().toString().substring(0, 8);
docMetadata.put("doc_id", docId);
docMetadata.put("source", "uploaded_file");
docMetadata.put("filename", file.getOriginalFilename());
docMetadata.put("file_size", file.getSize());
docMetadata.put("original_content_type", file.getContentType());
docMetadata.put("detected_content_type", metadata.get(Metadata.CONTENT_TYPE));
docMetadata.put("created_time", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
docMetadata.put("char_count", content.length());
// 添加Tika提取的元数据
addTikaMetadata(docMetadata, metadata);
return docMetadata;
}
/**
* 创建URL元数据
*/
private static Map<String, Object> createUrlMetadata(String url, String fileName, String fileExtension,
Metadata metadata, String content) {
Map<String, Object> docMetadata = new HashMap<>();
// 添加必需的 doc_id 字段(对应 Milvus collection schema)
String docId = "url_" + System.currentTimeMillis() + "_" + UUID.randomUUID().toString().substring(0, 8);
docMetadata.put("doc_id", docId);
docMetadata.put("source", "url");
docMetadata.put("url", url);
docMetadata.put("filename", fileName);
docMetadata.put("file_extension", fileExtension);
docMetadata.put("content_type", metadata.get(Metadata.CONTENT_TYPE));
docMetadata.put("created_time", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
docMetadata.put("char_count", content.length());
// 添加Tika提取的元数据
addTikaMetadata(docMetadata, metadata);
return docMetadata;
}
/**
* 添加Tika提取的元数据
*/
private static void addTikaMetadata(Map<String, Object> docMetadata, Metadata metadata) {
for (String name : metadata.names()) {
String value = metadata.get(name);
if (value != null && !value.trim().isEmpty()) {
docMetadata.put("tika_" + name.toLowerCase().replace(":", "_"), value);
}
}
}
/**
* 从URL中提取文件名
*/
private static String extractFileNameFromUrl(String urlString) {
try {
String path = new URL(urlString).getPath();
if (path != null && !path.isEmpty()) {
String[] parts = path.split("/");
return parts[parts.length - 1];
}
} catch (Exception e) {
logger.debug("无法从URL提取文件名: {}", urlString);
}
return null;
}
/**
* 从URL中提取文件扩展名
*/
private static String extractFileExtensionFromUrl(String urlString) {
String fileName = extractFileNameFromUrl(urlString);
return extractFileExtensionFromFilename(fileName);
}
/**
* 从文件名中提取扩展名
*/
private static String extractFileExtensionFromFilename(String filename) {
if (filename != null && filename.contains(".")) {
return filename.substring(filename.lastIndexOf(".") + 1).toLowerCase();
}
return null;
}
/**
* 根据文件扩展名预测Content-Type
*/ private static String predictContentTypeByExtension(String extension) {
if (extension == null) return null;
return switch (extension.toLowerCase()) {
case "pdf" -> "application/pdf";
case "doc" -> "application/msword";
case "docx" -> "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
case "xls" -> "application/vnd.ms-excel";
case "xlsx" -> "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet";
case "ppt" -> "application/vnd.ms-powerpoint";
case "pptx" -> "application/vnd.openxmlformats-officedocument.presentationml.presentation";
case "txt" -> "text/plain";
case "html", "htm" -> "text/html";
case "xml" -> "text/xml";
case "json" -> "application/json";
case "rtf" -> "application/rtf";
default -> null;
};
}
/**
* 检测文件类型
*/
public static String detectContentType(byte[] bytes) {
try {
return tika.detect(bytes);
} catch (Exception e) {
logger.warn("文件类型检测失败", e);
return "application/octet-stream";
}
}
/**
* 获取支持的文件类型列表
*/
public static List<String> getSupportedContentTypes() {
return Arrays.asList(
"text/plain", "text/html", "text/xml",
"application/pdf", "application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/rtf", "application/json"
);
}
RerankerUtils
package com.alibaba.cloud.ai.example.rag.utils;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.*;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.example.rag.enums.KeyEnum.DASHSCOPE_KEY;
/**
* 重排序工具类
* 提供文档重排序功能
*/
public class RerankerUtils {
private static final Logger logger = LoggerFactory.getLogger(RerankerUtils.class);
private static final String DEFAULT_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank";
private static final String DEFAULT_MODEL = "gte-rerank-v2";
private static final int DEFAULT_TIMEOUT_SECONDS = 30;
private static final ObjectMapper objectMapper = new ObjectMapper();
private static final HttpClient httpClient = HttpClient.newBuilder()
.connectTimeout(Duration.ofSeconds(DEFAULT_TIMEOUT_SECONDS))
.build();
/**
* 重排序文档列表
*
* @param query 查询文本
* @param documents 待重排序的文档列表
* @param topN 返回前N个文档
* @return 重排序后的文档列表
*/
public static List<Document> rerank(String query, List<Document> documents, int topN) {
return rerank(query, documents, topN, null, DEFAULT_MODEL);
}
/**
* 重排序文档列表(自定义API Key和模型)
*
* @param query 查询文本
* @param documents 待重排序的文档列表
* @param topN 返回前N个文档
* @param apiKey API密钥
* @param model 重排序模型
* @return 重排序后的文档列表
*/
public static List<Document> rerank(String query, List<Document> documents, int topN, String apiKey, String model) {
if (documents == null || documents.isEmpty()) {
logger.warn("文档列表为空,跳过重排序");
return new ArrayList<>();
}
if (topN <= 0) {
topN = documents.size();
}
logger.info("开始重排序,查询: {}, 文档数: {}, topN: {}", query, documents.size(), topN);
try {
// 获取API密钥
String effectiveApiKey = DASHSCOPE_KEY.getKey();
if (effectiveApiKey == null || effectiveApiKey.trim().isEmpty()) {
throw new RuntimeException("未找到有效的API密钥");
}
// 构建请求
RerankerRequest request = buildRerankerRequest(query, documents, Math.min(topN, documents.size()), model);
String requestJson = objectMapper.writeValueAsString(request);
// 发送HTTP请求
HttpRequest httpRequest = HttpRequest.newBuilder()
.uri(URI.create(DEFAULT_API_URL))
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + effectiveApiKey)
.timeout(Duration.ofSeconds(DEFAULT_TIMEOUT_SECONDS))
.POST(HttpRequest.BodyPublishers.ofString(requestJson))
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
if (response.statusCode() != 200) {
throw new RuntimeException("重排序API调用失败,状态码: " + response.statusCode() + ", 响应: " + response.body());
}
// 解析响应
RerankerResponse rerankerResponse = objectMapper.readValue(response.body(), RerankerResponse.class);
if (rerankerResponse.getOutput() == null || rerankerResponse.getOutput().getResults() == null) {
throw new RuntimeException("重排序响应格式错误");
}
// 根据重排序结果重新排列文档
List<Document> rerankedDocuments = processRerankerResults(documents, rerankerResponse.getOutput().getResults());
logger.info("重排序完成,返回 {} 个文档", rerankedDocuments.size());
return rerankedDocuments;
} catch (Exception e) {
logger.error("重排序失败,返回原始文档列表: {}", e.getMessage());
// 重排序失败时,返回原始文档的前topN个
return documents.subList(0, Math.min(topN, documents.size()));
}
}
/**
* 构建重排序请求
*/
private static RerankerRequest buildRerankerRequest(String query, List<Document> documents, int topN, String model) {
RerankerRequest request = new RerankerRequest();
request.setModel(model != null ? model : DEFAULT_MODEL);
RerankerInput input = new RerankerInput();
input.setQuery(query);
input.setDocuments(documents.stream()
.map(Document::getText)
.collect(Collectors.toList()));
request.setInput(input);
RerankerParameters parameters = new RerankerParameters();
parameters.setTopN(topN);
parameters.setReturnDocuments(true);
request.setParameters(parameters);
return request;
}
/**
* 处理重排序结果
*/
private static List<Document> processRerankerResults(List<Document> originalDocuments, List<RerankerResult> results) {
List<Document> rerankedDocuments = new ArrayList<>();
for (RerankerResult result : results) {
int index = result.getIndex();
if (index >= 0 && index < originalDocuments.size()) {
Document originalDoc = originalDocuments.get(index);
// 验证返回的文本是否与原始文档匹配
if (result.getDocument() != null && result.getDocument().getText() != null) {
String returnedText = result.getDocument().getText().trim();
String originalText = originalDoc.getText().trim();
if (!returnedText.equals(originalText)) {
logger.warn("重排序返回的文档文本与原始文档不匹配,索引: {}", index);
}
}
// 创建新的元数据,添加重排序信息
Map<String, Object> newMetadata = new HashMap<>(originalDoc.getMetadata());
newMetadata.put("rerank_score", result.getRelevanceScore());
newMetadata.put("rerank_index", result.getIndex());
newMetadata.put("rerank_time", System.currentTimeMillis());
Document rerankedDoc = new Document(originalDoc.getText(), newMetadata);
rerankedDocuments.add(rerankedDoc);
} else {
logger.warn("重排序结果索引超出范围: {}, 文档总数: {}", index, originalDocuments.size());
}
}
return rerankedDocuments;
}
// ===========================================
// 数据传输对象
// ===========================================
public static class RerankerRequest {
private String model;
private RerankerInput input;
private RerankerParameters parameters;
// Getters and Setters
public String getModel() { return model; }
public void setModel(String model) { this.model = model; }
public RerankerInput getInput() { return input; }
public void setInput(RerankerInput input) { this.input = input; }
public RerankerParameters getParameters() { return parameters; }
public void setParameters(RerankerParameters parameters) { this.parameters = parameters; }
}
public static class RerankerInput {
private String query;
private List<String> documents;
// Getters and Setters
public String getQuery() { return query; }
public void setQuery(String query) { this.query = query; }
public List<String> getDocuments() { return documents; }
public void setDocuments(List<String> documents) { this.documents = documents; }
}
public static class RerankerParameters {
@JsonProperty("top_n")
private int topN;
@JsonProperty("return_documents")
private boolean returnDocuments;
// Getters and Setters
public int getTopN() { return topN; }
public void setTopN(int topN) { this.topN = topN; }
public boolean isReturnDocuments() { return returnDocuments; }
public void setReturnDocuments(boolean returnDocuments) { this.returnDocuments = returnDocuments; }
}
public static class RerankerResponse {
private RerankerOutput output;
private RerankerUsage usage;
@JsonProperty("request_id")
private String requestId;
// Getters and Setters
public RerankerOutput getOutput() { return output; }
public void setOutput(RerankerOutput output) { this.output = output; }
public RerankerUsage getUsage() { return usage; }
public void setUsage(RerankerUsage usage) { this.usage = usage; }
public String getRequestId() { return requestId; }
public void setRequestId(String requestId) { this.requestId = requestId; }
}
public static class RerankerOutput {
private List<RerankerResult> results;
// Getters and Setters
public List<RerankerResult> getResults() { return results; }
public void setResults(List<RerankerResult> results) { this.results = results; }
}
public static class RerankerResult {
private RerankerDocument document;
private int index;
@JsonProperty("relevance_score")
private double relevanceScore;
// Getters and Setters
public RerankerDocument getDocument() { return document; }
public void setDocument(RerankerDocument document) { this.document = document; }
public int getIndex() { return index; }
public void setIndex(int index) { this.index = index; }
public double getRelevanceScore() { return relevanceScore; }
public void setRelevanceScore(double relevanceScore) { this.relevanceScore = relevanceScore; }
}
public static class RerankerDocument {
private String text;
// Getters and Setters
public String getText() { return text; }
public void setText(String text) { this.text = text; }
}
public static class RerankerUsage {
@JsonProperty("total_tokens")
private int totalTokens;
// Getters and Setters
public int getTotalTokens() { return totalTokens; }
public void setTotalTokens(int totalTokens) { this.totalTokens = totalTokens; }
}
}
测试
扩展:优化方法
由于分块可能会导致语言被截断,所以介绍以下两种优化方法:
知识图谱(GraphRAG)
Graph RAG就是一种先把文本转成知识图谱,再通过图结构精准找到答案相关知识,最后用大模型生成更准确回答的方法,这样就解决了传统RAG由于分块可能会导致语言被截断的难题,不过这个方法由于需要调用llm抽取关系,所以成本比较高,目前还没有普及。
Graph RAG 方法的核心链路分为索引、检索与生成三个阶段:
索引阶段
通过调用大语言模型(LLM)服务实现对原始文档的语义三元组抽取(头实体-关系-尾实体),并将提取的结构化信息存入图数据库,从而将非结构化文本转变为便于检索的结构化知识图谱。检索阶段
使用大语言模型对用户查询进行关键词提取和语义泛化处理(如大小写统一、别称识别、同义词扩展等),随后基于图数据库,以关键词为起点执行子图遍历(通常使用DFS或BFS算法),从而快速召回N跳以内与问题相关的局部知识子图。生成阶段
将召回的局部子图转化为适合LLM处理的文本上下文,并与原始问题一并输入大语言模型中进行回答生成,利用图结构上下文的信息增强模型生成回答的准确性和结构化表达能力。
GraphRAG的架构设计
和RAG对比
以下是我部署的LightRAG的实际效果,LightRAG其实就是GraphRAG的轻量简化版。
RAFT方法(Retrieval Augmented Fine Tuning)
RAFT 是一种用“逐词评分”方式进行训练的方法,通过对模型每个生成词打分,引导它学习更符合人类偏好的输出,同时结合强化学习和奖励权重逐步加大的机制,使模型生成更自然、符合指令。
简单来说就是一种让大模型“学会开卷考试”的微调方法,核心思想是:在训练时模拟真实检索环境,教模型如何利用检索到的文档回答问题。
举个例子🌰
如何最好地准备考试?
(a) 基于微调的方法通过“学习”来实现“记忆”输入文档或回答练习题而不参考文档。
(b) 或者,基于上下文检索的方法未能利用固定领域所提供的学习机会,相当于参加开卷考试但没有事先复习。
© 相比之下,我们的方法RAFT利用了微调与问答对,并在一个模拟的不完美检索环境中参考文档——从而有效地为开卷考试环境做准备。
RAFT方法(Retrieval Augmented Fine Tuning):
• 适应特定领域的LLMs对于许多新兴应用至关重要,但如何有效融入信息仍是一个开放问题。
• RAFT结合了检索增强生成(RAG)和监督微调(SFT),从而提高模型在特定领域内回答问题的能力。
• 训练模型识别并忽略那些不能帮助回答问题的干扰文档,只关注和引用相关的文档。
• 通过在训练中引入干扰文档,提高模型对干扰信息的鲁棒性,使其在测试时能更好地处理检索到的文档。
RAFT在所有专业领域的RAG性能上有所提升(在PubMed、HotPot、HuggingFace、Torch Hub和Tensorflow Hub等多个领域)
领域特定的微调提高了基础模型的性能,RAFT无论是在有RAG的情况下还是没有RAG的情况下,都持续优于现有的领域特定微调方法。这表明了需要在上下文中训练模型。