Spring AI(14)——文本分块优化

发布于:2025-07-27 ⋅ 阅读:(13) ⋅ 点赞:(0)

RAG时,检索效果的优劣,和文本的分块的情况有很大关系。

SpringAI中通过TokenTextSplitter对文本分块。本文对SpringAI提供的TokenTextSplitter源码进行了分析,并给出一些自己的想法,欢迎大家互相探讨。

查看了TokenTextSplitter的源码,其进行文本分块的核心代码如下:

protected List<String> doSplit(String text, int chunkSize) {
        if (text != null && !text.trim().isEmpty()) {
            // 将分割的内容转为对应token的列表
            List<Integer> tokens = this.getEncodedTokens(text);
            List<String> chunks = new ArrayList();
            int num_chunks = 0;

            while(!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
                // 根据token列表,按照chunkSize或者token列表长度的最小值进行截取
                List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));
                // 将token转为字符串
                String chunkText = this.decodeTokens(chunk);
                if (chunkText.trim().isEmpty()) {
                    tokens = tokens.subList(chunk.size(), tokens.size());
                } else {
                    // 从文本最后开始,获取英文的.!?和换行符的索引
                    int lastPunctuation = Math.max(chunkText.lastIndexOf(46), Math.max(chunkText.lastIndexOf(63), Math.max(chunkText.lastIndexOf(33), chunkText.lastIndexOf(10))));
                    // 如果索引值不是-1,并且索引大于分块的最小的字符数,对分块内容进行截取
                    if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
                        chunkText = chunkText.substring(0, lastPunctuation + 1);
                    }
                    // 如果keepSeparator是false,将本文中的换行符替换为空格
                    String chunkTextToAppend = this.keepSeparator ? chunkText.trim() : chunkText.replace(System.lineSeparator(), " ").trim();
                    if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
                        // 将分块内容添加到分块列表中
                        chunks.add(chunkTextToAppend);
                    }
                    // 对原来的token列表进行截取,用于排除已经分块的内容
                    tokens = tokens.subList(this.getEncodedTokens(chunkText).size(), tokens.size());
                    ++num_chunks;
                }
            }

            if (!tokens.isEmpty()) {
                String remaining_text = this.decodeTokens(tokens).replace(System.lineSeparator(), " ").trim();
                if (remaining_text.length() > this.minChunkLengthToEmbed) {
                    chunks.add(remaining_text);
                }
            }

            return chunks;
        } else {
            return new ArrayList();
        }
    }

参数说明: 

chunkSize: 每个文本块以 token 为单位的目标大小(默认值:800)。
minChunkSizeChars: 每个文本块以字符为单位的最小大小(默认值:350)。
minChunkLengthToEmbed: 文本块去除空白字符或者处理分隔符后,用于嵌入处理的文本的最小长度(默认值:5)。
maxNumChunks: 从文本生成的最大块数(默认值:10000)。
keepSeparator: 是否在块中保留分隔符(例如换行符)(默认值:true)。


TokenTextSplitter拆分文档的逻辑

1.使用 CL100K_BASE 编码将输入文本编码为 token列表

2.根据 chunkSize 对编码后的token列表进行截取分块

3.对于分块:

        (1)将token分块再解码为文本字符串

        (2)尝试从后向前找到一个合适的截断点(默认是英文的句号、问号、感叹号或换行符)。

        (3)如果找到合适的截断点,并且截断点所在的index大于minChunkSizeChars,则将在该点截断该块

        (4)对分块去除两边的空白字符,并根据 keepSeparator 设置,如果为false,则移除换行符

        (5)如果处理后的分块长度大于 minChunkLengthToEmbed,则将其添加到分块列表中

4.持续执行第2步和第3步,直到所有 token 都被处理完或达到 maxNumChunks

5.如果还有剩余的token没有处理,并且剩余的token进行编码和转换处理后,长度大于 minChunkLengthToEmbed,则将其作为最终块添加
 

源码中,是根据英文的逗号,叹号,问号和换行符进行文本的截取。这显然不太符合中文文档的语法习惯。为此,我们对源码进行修改,增加分割符的列表,用户可以根据文档的中英文情况,自行设置分割符。自定义的分割类代码如下:

package com.renr.springainew.controller;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.IntArrayList;
import org.springframework.ai.transformer.splitter.TextSplitter;
import org.springframework.util.Assert;

import java.util.*;

/**
 * @Classname MyTextSplit
 * @Description TODO
 * @Date 2025-07-26 9:46
 * @Created by 老任与码
 */
public class MyTextSplit extends TextSplitter {

    private static final int DEFAULT_CHUNK_SIZE = 800;
    private static final int MIN_CHUNK_SIZE_CHARS = 350;
    private static final int MIN_CHUNK_LENGTH_TO_EMBED = 5;
    private static final int MAX_NUM_CHUNKS = 10000;
    private static final boolean KEEP_SEPARATOR = true;
    private final EncodingRegistry registry;
    private final Encoding encoding;
    private final int chunkSize;
    private final int minChunkSizeChars;
    private final int minChunkLengthToEmbed;
    private final int maxNumChunks;
    private final boolean keepSeparator;
    private final List<String> splitList;

    public MyTextSplit() {
        this(800, 350, 5, 10000, true, Arrays.asList(".", "!", "?", "\n"));
    }

    public MyTextSplit(boolean keepSeparator) {
        this(800, 350, 5, 10000, keepSeparator, Arrays.asList(".", "!", "?", "\n"));
    }

    public MyTextSplit(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, boolean keepSeparator, List<String> splitList) {
        this.registry = Encodings.newLazyEncodingRegistry();
        this.encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
        this.chunkSize = chunkSize;
        this.minChunkSizeChars = minChunkSizeChars;
        this.minChunkLengthToEmbed = minChunkLengthToEmbed;
        this.maxNumChunks = maxNumChunks;
        this.keepSeparator = keepSeparator;
        if (splitList == null || splitList.isEmpty()) {
            this.splitList = Arrays.asList(".", "!", "?", "\n");
        } else {
            this.splitList = splitList;
        }
    }

    protected List<String> splitText(String text) {
        return this.doSplit(text, this.chunkSize);
    }

    protected List<String> doSplit(String text, int chunkSize) {
        if (text != null && !text.trim().isEmpty()) {
            List<Integer> tokens = this.getEncodedTokens(text);
            List<String> chunks = new ArrayList();
            int num_chunks = 0;

            while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
                List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));
                String chunkText = this.decodeTokens(chunk);
                if (chunkText.trim().isEmpty()) {
                    tokens = tokens.subList(chunk.size(), tokens.size());
                } else {
                    int lastPunctuation = splitList.stream()
                            .mapToInt(chunkText::lastIndexOf)
                            .max().orElse(-1);
                    // 46 .  63 ?  33 !   10换行
                    // int lastPunctuation = Math.max(chunkText.lastIndexOf(46), Math.max(chunkText.lastIndexOf(63), Math.max(chunkText.lastIndexOf(33), chunkText.lastIndexOf(10))));
                    if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
                        chunkText = chunkText.substring(0, lastPunctuation + 1);
                    }

                    String chunkTextToAppend = this.keepSeparator ? chunkText.trim() : chunkText.replace(System.lineSeparator(), " ").trim();
                    if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
                        chunks.add(chunkTextToAppend);
                    }

                    tokens = tokens.subList(this.getEncodedTokens(chunkText).size(), tokens.size());
                    ++num_chunks;
                }
            }

            if (!tokens.isEmpty()) {
                String remaining_text = this.decodeTokens(tokens).replace(System.lineSeparator(), " ").trim();
                if (remaining_text.length() > this.minChunkLengthToEmbed) {
                    chunks.add(remaining_text);
                }
            }

            return chunks;
        } else {
            return new ArrayList();
        }
    }

    private List<Integer> getEncodedTokens(String text) {
        Assert.notNull(text, "Text must not be null");
        return this.encoding.encode(text).boxed();
    }

    private String decodeTokens(List<Integer> tokens) {
        Assert.notNull(tokens, "Tokens must not be null");
        IntArrayList tokensIntArray = new IntArrayList(tokens.size());
        Objects.requireNonNull(tokensIntArray);
        tokens.forEach(tokensIntArray::add);
        return this.encoding.decode(tokensIntArray);
    }

}

测试代码:

    public void init2() {
        // 读取文本文件
        TextReader textReader = new TextReader(this.resource);
        // 元数据中增加文件名
        textReader.getCustomMetadata().put("filename", "医院.txt");
        // 获取Document对象,只有一个记录
        List<Document> docList = textReader.read();
        // 指定分割符
        List<String> splitList = Arrays.asList("。", "!", "?", System.lineSeparator());
        MyTextSplit splitter = new MyTextSplit(300, 100, 5, 10000, true, splitList);
        List<Document> splitDocuments = splitter.apply(docList);
        System.out.println(splitDocuments);
    }

另外,根据源码,minChunkSizeChars的值要小于chunkSize的值才有意义。

根据CL100K_BASE编码,300长度的token转为本文内容后,文本内容的长度在220-250之间(根据本例的中文文档测试,实际存在误差),转换比例在70%到80%多,为了根据特定的字符进行分割,所以minChunkSize的值最好小于210。

根据源码的逻辑,分割文本时,可能出现如果分隔符的索引小于minChunkSizeChars,就不会对文本进行分割,于是,就会出现句子被断开的情况。

针对该现象,可以增加分割的字符种类;或者干脆将minChunkSizeChars设置为0(解决方案有点简单粗暴哈O(∩_∩)O哈哈~);还可以根据分割后的内容,进行手动修改,然后再进行向量化处理。

该代码存在的问题:

使用由于是先转为token列表;再转为字符串后,根据分割符进行截取;截取后转为token,再根据token长度截取token列表,索引多次转换后,使用CL100K_BASE编码会存在一些中文数据的丢失或者乱码情况。

经过测试,可以将编码方式修改为O200K_BASE编码。使用该编码后,中文转换的token列表长度小于文本本身的长度,所以分块时,需要重置chunkSize和minChunkSizeChars的值。

this.encoding = this.registry.getEncoding(EncodingType.O200K_BASE);


网站公告

今日签到

点亮在社区的每一天
去签到