llm数据处理之指代平均折叠法

发布于:2024-09-17 ⋅ 阅读:(130) ⋅ 点赞:(0)
import json
import pandas as pd
from glob import glob
from tqdm import tqdm


def get_replace_token(text):
    total_replace_count = len(text) - 127
    alm_list = []
    if total_replace_count < 1024:
        if total_replace_count >= 96:
            one_alm = total_replace_count // 32
            two_alm = total_replace_count % 32
            alm_list = [one_alm] * 32
            for i in range(two_alm):
                alm_list[i] += 1
        else:
            if total_replace_count <= 3:
                alm_list = [3]
            else:
                alm_list = [2, 2]
                for i in range(1, total_replace_count):
                    if 2 in alm_list:

                        for j, a in enumerate(alm_list):
                            if alm_list[j] != 3:
                                alm_list[j] += 1
                                break
                    else:
                        alm_list.pop()
                        alm_list += [2, 2]
    return alm_list

def get_replace_pos(al,one):
    fill_len = len(one) // len(al)
    pos_se = []
    start = 0

    for i in al:
        start += fill_len

        pos_se.append([start, start + i])
        start += i
    return pos_se


for one_path in glob("D:/Linly_AI/Chinese_pretraining_dataset/*"):
    with open(one_path, "r", encoding="utf-8") as f:
        data = f.readlines()
    data = [json.loads(i)["text"].replace('[正文]', "") for i in data]
    # 开始折叠 (平均指代法 两个token 起  指代token 最高32 )
    #
    for one in tqdm(data):
        if len(one) > 128:
            al = get_replace_token(one)
            al_pos= get_replace_pos(al,one)
            

这段代码进一步扩展了之前的逻辑,用于处理文本数据并确定每个折叠位置。以下是代码的详细解析:

  1. 导入必要的库:
    • json:用于处理JSON数据。
    • pandas:数据处理和分析库,但在这段代码中并未使用。
    • glob:用于查找符合特定规则的文件路径。
    • tqdm:用于在循环过程中显示进度条。
  2. 定义函数 get_replace_token
    • 功能:计算文本长度超过127个字符时,需要折叠的token数量,并生成一个列表 alm_list,表示每个token需要折叠的次数。
    • 逻辑与之前相同,不再重复解释。
  3. 定义函数 get_replace_pos
    • 输入参数:
      • al:由 get_replace_token 函数返回的折叠次数列表。
      • one:要处理的文本。
    • 功能:根据折叠次数列表 al 和文本 one 计算每个折叠的具体位置。
    • 逻辑:
      • 计算 fill_len,即文本 one 平均分配到每个折叠位置的基础长度。
      • 初始化一个空列表 pos_se,用于存储每个折叠的开始和结束位置。
      • 遍历 al 列表,计算每个折叠的开始位置 start
      • 对于每个折叠次数 i,计算折叠的开始和结束位置,并将它们作为一个列表添加到 pos_se 中。
  4. 主循环:
    • 使用 glob 函数查找 “D:/Linly_AI/Chinese_pretraining_dataset/” 目录下的所有文件。
    • 读取每个文件的内容,并处理:
      • 打开文件,读取所有行。
      • 将每行JSON数据解析为字典,并提取 “text” 字段的值,同时去除其中的 “[正文]” 字符串。
      • 遍历处理后的文本数据。
  5. 对每段文本进行处理:
    • 如果文本长度超过128个字符,调用 get_replace_token 函数计算折叠次数列表 al
    • 调用 get_replace_pos 函数计算每个折叠的具体位置 al_pos
      需要注意的是,代码中的 al_pos 变量计算出了每个折叠的开始和结束位置,但代码中没有进一步使用这些位置。如果需要进一步处理,例如实际折叠文本或进行其他操作,需要在代码中添加相应的逻辑。此外,代码中存在一些潜在的逻辑问题和可能的错误,例如在 get_replace_token 函数中的 else 分支可能不会正确处理所有情况。

网站公告

今日签到

点亮在社区的每一天
去签到