[TrOCR] Training Code

Published: 2025-07-07

1. Official README

TrOCR

Project repository: https://github.com/microsoft/unilm

Introduction

TrOCR is an end-to-end text recognition approach that combines a pre-trained image Transformer with a pre-trained text Transformer, using the Transformer architecture for both image understanding and wordpiece-level text generation.

Paper: TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei, AAAI 2023

TrOCR models are also available in Hugging Face format. [Documentation][Models]

| Model       | #Params | Test set | Score             |
|-------------|---------|----------|-------------------|
| TrOCR-Small | 62M     | IAM      | 4.22 (Cased CER)  |
| TrOCR-Base  | 334M    | IAM      | 3.42 (Cased CER)  |
| TrOCR-Large | 558M    | IAM      | 2.89 (Cased CER)  |
| TrOCR-Small | 62M     | SROIE    | 95.86 (F1)        |
| TrOCR-Base  | 334M    | SROIE    | 96.34 (F1)        |
| TrOCR-Large | 558M    | SROIE    | 96.60 (F1)        |
| Model (word accuracy) | IIIT5K-3000 | SVT-647 | ICDAR2013-857 | ICDAR2013-1015 | ICDAR2015-1811 | ICDAR2015-2077 | SVTP-645 | CT80-288 |
|-----------------------|-------------|---------|---------------|----------------|----------------|----------------|----------|----------|
| TrOCR-Base            | 93.4        | 95.2    | 98.4          | 97.4           | 86.9           | 81.2           | 92.1     | 90.6     |
| TrOCR-Large           | 94.1        | 96.1    | 98.4          | 97.3           | 88.1           | 84.1           | 93.0     | 95.1     |
Installation
conda create -n trocr python=3.7
conda activate trocr
git clone https://github.com/microsoft/unilm.git
cd unilm
cd trocr
pip install pybind11
pip install -r requirements.txt
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" 'git+https://github.com/NVIDIA/apex.git'
Fine-tuning and Evaluation
Model download
| Model             | Download                   |
|-------------------|----------------------------|
| TrOCR-Small-IAM   | trocr-small-handwritten.pt |
| TrOCR-Base-IAM    | trocr-base-handwritten.pt  |
| TrOCR-Large-IAM   | trocr-large-handwritten.pt |
| TrOCR-Small-SROIE | trocr-small-printed.pt     |
| TrOCR-Base-SROIE  | trocr-base-printed.pt      |
| TrOCR-Large-SROIE | trocr-large-printed.pt     |
| TrOCR-Small-Stage1 | trocr-small-stage1.pt     |
| TrOCR-Base-Stage1 | trocr-base-stage1.pt       |
| TrOCR-Large-Stage1 | trocr-large-stage1.pt     |
| TrOCR-Base-STR    | trocr-base-str.pt          |
| TrOCR-Large-STR   | trocr-large-str.pt         |
Test set download

| Test set       | Download                     |
|----------------|------------------------------|
| IAM            | IAM.tar.gz                   |
| SROIE          | SROIE_Task2_Original.tar.gz  |
| STR benchmarks | STR_BENCHMARKS.zip           |

IAM is mainly used for handwritten text recognition.
SROIE is mainly used for printed text recognition, especially document processing and information extraction.
The STR benchmarks are used to benchmark text recognition models; they cover text images from a wide range of scenes and styles.

If any file on this page fails to download, append the following string to the URL as a suffix.

Suffix string: ?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z&st=2023-06-08T08:48:15Z&spr=https&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D
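For scripted downloads, the suffix can be appended programmatically. A trivial helper sketch (`with_sas` is a hypothetical name; the token is the SAS string quoted above):

```python
# SAS token quoted above, used when a direct download fails.
SAS_SUFFIX = ("?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z"
              "&st=2023-06-08T08:48:15Z&spr=https"
              "&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D")

def with_sas(url):
    """Append the SAS query string to a URL that has no query string yet."""
    return url if "?" in url else url + SAS_SUFFIX
```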

Fine-tuning on IAM
export MODEL_NAME=ft_iam
export SAVE_PATH=/path/to/save/${MODEL_NAME}
export LOG_DIR=log_${MODEL_NAME}
export DATA=/path/to/data
mkdir ${LOG_DIR}
export BSZ=8
export valid_BSZ=16

# Use --arch trocr_base for the base model.
# For small models, replace "--bpe gpt2 --decoder-pretrained roberta2" with
# "--bpe sentencepiece --sentencepiece-model ./unilm3-cased.model --decoder-pretrained unilm".
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 \
    $(which fairseq-train) \
    --data-type STR --user-dir ./ --task text_recognition --input-size 384 \
    --arch trocr_large \
    --seed 1111 --optimizer adam --lr 2e-05 --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-8 --warmup-updates 500 --weight-decay 0.0001 --log-format tqdm \
    --log-interval 10 --batch-size ${BSZ} --batch-size-valid ${valid_BSZ} --save-dir ${SAVE_PATH} \
    --tensorboard-logdir ${LOG_DIR} --max-epoch 300 --patience 20 --ddp-backend legacy_ddp \
    --num-workers 8 --preprocess DA2 --update-freq 1 \
    --bpe gpt2 --decoder-pretrained roberta2 \
    ${DATA}
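The --lr-scheduler inverse_sqrt option warms the learning rate up linearly from --warmup-init-lr to --lr over --warmup-updates steps, then decays it with the inverse square root of the update number. A rough sketch of that schedule in plain Python (an illustration, not fairseq's actual implementation; defaults match the IAM command above):

```python
def inverse_sqrt_lr(step, peak_lr=2e-5, warmup_init_lr=1e-8, warmup_updates=500):
    """Linear warmup to peak_lr, then inverse-square-root decay."""
    if step < warmup_updates:
        # Linear interpolation from warmup_init_lr up to peak_lr.
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # After warmup: lr = peak_lr * sqrt(warmup_updates / step).
    return peak_lr * (warmup_updates / step) ** 0.5
```

At step 2000 with these defaults the rate has halved to 1e-5, which is the characteristic slow decay of this scheduler.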
Fine-tuning on SROIE
export MODEL_NAME=ft_SROIE
export SAVE_PATH=/path/to/save/${MODEL_NAME}
export LOG_DIR=log_${MODEL_NAME}
export DATA=/path/to/data
mkdir ${LOG_DIR}
export BSZ=16
export valid_BSZ=16

# Use --arch trocr_base for the base model.
# For small models, replace "--bpe gpt2 --decoder-pretrained roberta2" with
# "--bpe sentencepiece --sentencepiece-model ./unilm3-cased.model --decoder-pretrained unilm".
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 \
    $(which fairseq-train) \
    --data-type SROIE --user-dir ./ --task text_recognition --input-size 384 \
    --arch trocr_large \
    --seed 1111 --optimizer adam --lr 5e-05 --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-8 --warmup-updates 800 --weight-decay 0.0001 --log-format tqdm \
    --log-interval 10 --batch-size ${BSZ} --batch-size-valid ${valid_BSZ} \
    --save-dir ${SAVE_PATH} --tensorboard-logdir ${LOG_DIR} --max-epoch 300 \
    --patience 10 --ddp-backend legacy_ddp --num-workers 10 --preprocess DA2 \
    --bpe gpt2 --decoder-pretrained roberta2 \
    ${DATA}
Fine-tuning on the STR benchmarks
    --preprocess RandAugment  --update-freq 1  --ddp-backend legacy_ddp \
    --num-workers 8  --finetune-from-model /path/to/model  \
    --bpe gpt2  --decoder-pretrained roberta2 \
    ${DATA} 
Evaluation on SROIE
export DATA=/path/to/data
export MODEL=/path/to/model
export RESULT_PATH=/path/to/result
export BSZ=16
# For small models, use "--bpe sentencepiece --sentencepiece-model ./unilm3-cased.model" and
# "--dict-path-or-url https://layoutlm.blob.core.windows.net/trocr/dictionaries/unilm3.dict.txt" instead.
$(which fairseq-generate) \
        --data-type SROIE --user-dir ./ --task text_recognition --input-size 384 \
        --beam 10 --nbest 1 --scoring sroie --gen-subset test \
        --batch-size ${BSZ} --path ${MODEL} --results-path ${RESULT_PATH} \
        --bpe gpt2 --dict-path-or-url https://layoutlm.blob.core.windows.net/trocr/dictionaries/gpt2_with_mask.dict.txt \
        --preprocess DA2 \
        --fp16 \
        ${DATA}
Evaluation on the STR benchmarks
export DATA=/path/to/data
export MODEL=/path/to/model
export RESULT_PATH=/path/to/result
export BSZ=16
$(which fairseq-generate) \
        --data-type Receipt53K --user-dir ./ --task text_recognition \
        --input-size 384 --beam 10 --nbest 1 --scoring wpa \
        --gen-subset test --batch-size ${BSZ} --bpe gpt2 \
        --dict-path-or-url https://layoutlm.blob.core.windows.net/trocr/dictionaries/gpt2_with_mask.dict.txt \
        --path ${MODEL} --results-path ${RESULT_PATH} \
        --preprocess RandAugment \
        ${DATA}

Please use "convert_to_sroie_format.py" to convert the output file into zip format and submit it on the website to get the score.

Inference example

See pic_inference.py for details.

Citation

If you want to cite TrOCR in your research, please cite the following paper:

@misc{li2021trocr,
      title={TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models}, 
      author={Minghao Li and Tengchao Lv and Lei Cui and Yijuan Lu and Dinei Florencio and Cha Zhang and Zhoujun Li and Furu Wei},
      year={2021},
      eprint={2109.10282},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree. Portions of the source code are based on the fairseq project. Microsoft Open Source Code of Conduct

Contact information

For help or issues using TrOCR, please submit a GitHub issue.
For other communications related to TrOCR, please contact Lei Cui (lecu@microsoft.com) or Furu Wei (fuwei@microsoft.com).

2. Training Code

  • The TrOCR project code comes from: https://github.com/microsoft/unilm/tree/master/trocr
  • Local path: /mnt/Virgil/TrOCR/unilm-master/trocr
  • Training image folder: /mnt/Virgil/TrOCR/ChineseDataset/val_images
  • Training label file: /mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt
  • Validation image folder: /mnt/Virgil/TrOCR/ChineseDataset/9w_images
  • Validation label file: /mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt
  • Pre-trained weights: /mnt/Virgil/TrOCR/trocr-base-printed
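The label files are assumed to hold one sample per line in the form `filename<space>text`. A minimal parser sketch (a hypothetical helper; it uses `split(maxsplit=1)` so text containing spaces survives, whereas the training script below uses a plain `split()` and silently skips such lines):

```python
def parse_label_line(line):
    """Split one 'filename text' label line; return None for malformed lines."""
    parts = line.strip().split(maxsplit=1)
    return tuple(parts) if len(parts) == 2 else None
```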

Training code adapted from the TrOCR fine-tuning notebook Transformers-Tutorials/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb:

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from evaluate import load as load_metric
from torch.optim import AdamW
from tqdm import tqdm
import os

# Dataset class
class ORCDataset(Dataset):
    def __init__(self, image_dir, label_file, processor, max_target_length=256):
        self.image_dir = image_dir
        self.processor = processor
        self.max_target_length = max_target_length
        self.data = []
        # Read the label file: one "filename text" pair per line.
        # Note: lines whose text contains spaces split into >2 parts and are skipped.
        with open(label_file, 'r', encoding='utf-8') as f:
            valid_count = 0
            for line in f:
                parts = line.strip().split()  # split filename and text on whitespace
                if len(parts) == 2:
                    file_name, text = parts
                    self.data.append((file_name, text))
                    valid_count += 1
            print(f"Loaded {valid_count} samples from {label_file}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Get the file name and text
        file_name, text = self.data[idx]
        image_path = f'{self.image_dir}/{file_name}'

        # Make sure the image file exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        # Open the image and convert it to RGB
        image = Image.open(image_path).convert("RGB")

        # Run the image through the processor to get pixel values
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Encode the text to get the labels
        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length
        ).input_ids

        # Make sure PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        return {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels)
        }

# Dataset paths
train_image_dir = '/mnt/Virgil/TrOCR/ChineseDataset/val_images'
train_label_file = '/mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt'
eval_image_dir = '/mnt/Virgil/TrOCR/ChineseDataset/9w_images'
eval_label_file = '/mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt'
pretrained_weights = '/mnt/Virgil/TrOCR/trocr-base-printed'

# Load the processor
processor = TrOCRProcessor.from_pretrained(pretrained_weights)

# Build the datasets
print("Loading training dataset...")
train_dataset = ORCDataset(
    image_dir=train_image_dir,
    label_file=train_label_file,
    processor=processor
)

print("Loading validation dataset...")
eval_dataset = ORCDataset(
    image_dir=eval_image_dir,
    label_file=eval_label_file,
    processor=processor
)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(eval_dataset)}")

# Create the data loaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=16)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionEncoderDecoderModel.from_pretrained(pretrained_weights)
model.to(device)
print("Model loaded; encoder structure:")
print(model.encoder)

# Configure special tokens and generation parameters
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# Load the evaluation metric
cer_metric = load_metric("cer")

def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    return cer_metric.compute(predictions=pred_str, references=label_str)

# Native PyTorch training loop
print("Starting training...")
optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(100):
    # Train
    model.train()
    train_loss = 0.0

    print(f"Epoch {epoch+1}/100 [training]")
    for batch in tqdm(train_dataloader):
        # Move the batch to the device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()

    # Report the average training loss
    avg_train_loss = train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/100 training loss: {avg_train_loss:.4f}")

    # Evaluate
    model.eval()
    valid_cer = 0.0

    print(f"Epoch {epoch+1}/100 [validating]")
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # Generate predictions
            outputs = model.generate(batch["pixel_values"].to(device))

            # Compute the metric
            cer = compute_cer(
                pred_ids=outputs, 
                label_ids=batch["labels"].to(device)
            )
            valid_cer += cer

    # Report the average validation CER
    avg_valid_cer = valid_cer / len(eval_dataloader)
    print(f"Epoch {epoch+1}/100 validation CER: {avg_valid_cer:.4f}")

# Save the model
model.save_pretrained(".")
print("Training finished; model saved")
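The PAD-to--100 masking used in the dataset's __getitem__ can be checked in isolation; a minimal sketch of the same list comprehension (the pad id used below is just an example value):

```python
def mask_pad_tokens(labels, pad_token_id):
    """Replace pad token ids with -100 so CrossEntropyLoss ignores them."""
    return [lab if lab != pad_token_id else -100 for lab in labels]
```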

Running it produced the first error:

(tiacai) root@5de27e9cb8c1:/mnt/Virgil/TrOCR# python train2.py
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Loading training dataset...
Loaded 364400 samples from /mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt
Loading validation dataset...
Loaded 90000 samples from /mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt
Number of training samples: 364400
Number of validation samples: 90000
Traceback (most recent call last):
  File "/mnt/Virgil/TrOCR/train2.py", line 95, in <module>
    model = VisionEncoderDecoderModel.from_pretrained(pretrained_weights)
  File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py", line 378, in from_pretrained
    return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/transformers/modeling_utils.py", line 272, in _wrapper
    return func(*args, **kwargs)
  File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4345, in from_pretrained
    with safe_open(checkpoint_files[0], framework="pt") as f:
safetensors_rust.SafetensorError: Error while deserializing header: MetadataIncompleteBuffer


The error message safetensors_rust.SafetensorError: Error while deserializing header: MetadataIncompleteBuffer means the safetensors file's header metadata was incomplete when loading the pre-trained weights. This is usually caused by a corrupted or partially downloaded weight file.
Re-downloading the safetensors file fixes it.
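A truncated download can also be detected up front: per the published safetensors layout, a file begins with an 8-byte little-endian header length followed by a JSON header of exactly that length. A minimal integrity check based on that layout (a sketch; it validates only the header, not the tensor data):

```python
import json
import struct

def safetensors_header_ok(path):
    """Return True if the file's JSON header is complete and parseable."""
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) < 8:
            return False
        (header_len,) = struct.unpack("<Q", prefix)  # little-endian u64
        header = f.read(header_len)
        if len(header) < header_len:
            return False  # truncated download ("MetadataIncompleteBuffer")
        try:
            json.loads(header)
            return True
        except ValueError:
            return False
```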

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Loading training dataset...
Loaded 364400 samples from /mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt
Loading validation dataset...
Loaded 90000 samples from /mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt
Number of training samples: 364400
Number of validation samples: 90000
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "pooler_act": "tanh",
  "pooler_output_size": 768,
  "qkv_bias": false,
  "torch_dtype": "float32",
  "transformers_version": "4.50.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "eos_token_id": 2,
  "init_std": 0.02,
  "is_decoder": true,
  "layernorm_embedding": true,
  "max_position_embeddings": 512,
  "model_type": "trocr",
  "pad_token_id": 1,
  "scale_embedding": false,
  "torch_dtype": "float32",
  "transformers_version": "4.50.3",
  "use_cache": false,
  "use_learned_position_embeddings": true,
  "vocab_size": 50265
}

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at /mnt/Virgil/TrOCR/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Model loaded; encoder structure:
ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=False)
            (key): Linear(in_features=768, out_features=768, bias=False)
            (value): Linear(in_features=768, out_features=768, bias=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
  (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (pooler): ViTPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

But after training for one epoch and finishing validation, the script errored again when it was about to save the model.

Traceback (most recent call last):
  File "/mnt/Virgil/TrOCR/train2.py", line 168, in <module>
    cer_metric = load_metric("cer")
  File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/evaluate/loading.py", line 748, in load
    evaluation_module = evaluation_module_factory(
  File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/evaluate/loading.py", line 681, in evaluation_module_factory
    raise FileNotFoundError(
FileNotFoundError: Couldn't find a module script at /mnt/Virgil/TrOCR/cer/cer.py. Module 'cer' doesn't exist on the Hugging Face Hub either.
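evaluate.load("cer") resolves the metric by fetching a script from the Hugging Face Hub (or a local cer/cer.py), so it fails when neither is available. As a fallback, CER is just character-level edit distance divided by total reference length; a minimal pure-Python sketch (not the evaluate implementation):

```python
def levenshtein(a, b):
    """Character-level edit distance via dynamic programming."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            # min of deletion, insertion, substitution/match
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[len(b)]

def char_error_rate(predictions, references):
    """Total edit distance over total reference length, like CER."""
    edits = sum(levenshtein(p, r) for p, r in zip(predictions, references))
    total = sum(len(r) for r in references)
    return edits / total if total else 0.0
```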
