1. Official README
TrOCR
Project: https://github.com/microsoft/unilm
Introduction
TrOCR is an end-to-end text recognition approach that combines pre-trained image Transformer and text Transformer models, leveraging the Transformer architecture for both image understanding and wordpiece-level text generation.
Paper: TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei, AAAI 2023.
TrOCR models are also available in the Hugging Face format. [Documentation][Models]
Model | #Params | Test set | Score |
---|---|---|---|
TrOCR-Small | 62M | IAM | 4.22 (case-sensitive CER) |
TrOCR-Base | 334M | IAM | 3.42 (case-sensitive CER) |
TrOCR-Large | 558M | IAM | 2.89 (case-sensitive CER) |
TrOCR-Small | 62M | SROIE | 95.86 (F1) |
TrOCR-Base | 334M | SROIE | 96.34 (F1) |
TrOCR-Large | 558M | SROIE | 96.60 (F1) |
Model | IIIT5K-3000 | SVT-647 | ICDAR2013-857 | ICDAR2013-1015 | ICDAR2015-1811 | ICDAR2015-2077 | SVTP-645 | CT80-288 |
---|---|---|---|---|---|---|---|---|
TrOCR-Base (word accuracy) | 93.4 | 95.2 | 98.4 | 97.4 | 86.9 | 81.2 | 92.1 | 90.6 |
TrOCR-Large (word accuracy) | 94.1 | 96.1 | 98.4 | 97.3 | 88.1 | 84.1 | 93.0 | 95.1 |
Installation
conda create -n trocr python=3.7
conda activate trocr
git clone https://github.com/microsoft/unilm.git
cd unilm
cd trocr
pip install pybind11
pip install -r requirements.txt
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" 'git+https://github.com/NVIDIA/apex.git'
Fine-tuning and Evaluation
Model Download
Model | Download link |
---|---|
TrOCR-Small-IAM | trocr-small-handwritten.pt |
TrOCR-Base-IAM | trocr-base-handwritten.pt |
TrOCR-Large-IAM | trocr-large-handwritten.pt |
TrOCR-Small-SROIE | trocr-small-printed.pt |
TrOCR-Base-SROIE | trocr-base-printed.pt |
TrOCR-Large-SROIE | trocr-large-printed.pt |
TrOCR-Small-Stage1 | trocr-small-stage1.pt |
TrOCR-Base-Stage1 | trocr-base-stage1.pt |
TrOCR-Large-Stage1 | trocr-large-stage1.pt |
TrOCR-Base-STR | trocr-base-str.pt |
TrOCR-Large-STR | trocr-large-str.pt |
Test Set Download
Test set | Download link |
---|---|
IAM | IAM.tar.gz |
SROIE | SROIE_Task2_Original.tar.gz |
STR基准测试集 | STR_BENCHMARKS.zip |
IAM is mainly used for handwritten text recognition.
SROIE is mainly used for printed text recognition, especially in document processing and information extraction.
The STR benchmarks are used to benchmark text recognition models and cover text images in a variety of scenes and styles.
If any file on this page fails to download, append the following string to the URL as a suffix.
Suffix string: ?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z&st=2023-06-08T08:48:15Z&spr=https&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D
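If several files need the suffix, it can be appended programmatically before downloading. The sketch below is illustrative only (the helper name `with_sas` is mine; the dictionary URL is the one reused from the evaluation commands further down), using just the standard library:

```python
import urllib.request

# SAS suffix from the README; append it to any file URL on this page.
SAS_SUFFIX = ("?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z"
              "&st=2023-06-08T08:48:15Z&spr=https"
              "&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D")

def with_sas(url: str) -> str:
    """Append the SAS token unless the URL already carries a query string."""
    return url if "?" in url else url + SAS_SUFFIX

# Example with the dictionary URL used in the evaluation commands:
url = with_sas("https://layoutlm.blob.core.windows.net/trocr/dictionaries/gpt2_with_mask.dict.txt")
# To fetch: urllib.request.urlretrieve(url, "gpt2_with_mask.dict.txt")
```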
Fine-tuning on IAM
export MODEL_NAME=ft_iam
export SAVE_PATH=/path/to/save/${MODEL_NAME}
export LOG_DIR=log_${MODEL_NAME}
export DATA=/path/to/data
mkdir ${LOG_DIR}
export BSZ=8
export valid_BSZ=16
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 \
$(which fairseq-train) \
--data-type STR --user-dir ./ --task text_recognition --input-size 384 \
--arch trocr_large \ # or trocr_base
--seed 1111 --optimizer adam --lr 2e-05 --lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-8 --warmup-updates 500 --weight-decay 0.0001 --log-format tqdm \
--log-interval 10 --batch-size ${BSZ} --batch-size-valid ${valid_BSZ} --save-dir ${SAVE_PATH} \
--tensorboard-logdir ${LOG_DIR} --max-epoch 300 --patience 20 --ddp-backend legacy_ddp \
--num-workers 8 --preprocess DA2 --update-freq 1 \
--bpe gpt2 --decoder-pretrained roberta2 \ # --bpe sentencepiece --sentencepiece-model ./unilm3-cased.model --decoder-pretrained unilm ## for the small model
${DATA}
Fine-tuning on SROIE
export MODEL_NAME=ft_SROIE
export SAVE_PATH=/path/to/save/${MODEL_NAME}
export LOG_DIR=log_${MODEL_NAME}
export DATA=/path/to/data
mkdir ${LOG_DIR}
export BSZ=16
export valid_BSZ=16
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 \
$(which fairseq-train) \
--data-type SROIE --user-dir ./ --task text_recognition --input-size 384 \
--arch trocr_large \ # or trocr_base
--seed 1111 --optimizer adam --lr 5e-05 --lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-8 --warmup-updates 800 --weight-decay 0.0001 --log-format tqdm \
--log-interval 10 --batch-size ${BSZ} --batch-size-valid ${valid_BSZ} \
--save-dir ${SAVE_PATH} --tensorboard-logdir ${LOG_DIR} --max-epoch 300 \
--patience 10 --ddp-backend legacy_ddp --num-workers 10 --preprocess DA2 \
--bpe gpt2 --decoder-pretrained roberta2 \ # --bpe sentencepiece --sentencepiece-model ./unilm3-cased.model --decoder-pretrained unilm ## for the small model
${DATA}
Fine-tuning on the STR benchmarks
--preprocess RandAugment --update-freq 1 --ddp-backend legacy_ddp \
--num-workers 8 --finetune-from-model /path/to/model \
--bpe gpt2 --decoder-pretrained roberta2 \
${DATA}
Evaluation on SROIE
export DATA=/path/to/data
export MODEL=/path/to/model
export RESULT_PATH=/path/to/result
export BSZ=16
$(which fairseq-generate) \
--data-type SROIE --user-dir ./ --task text_recognition --input-size 384 \
--beam 10 --nbest 1 --scoring sroie --gen-subset test \
--batch-size ${BSZ} --path ${MODEL} --results-path ${RESULT_PATH} \
--bpe gpt2 --dict-path-or-url https://layoutlm.blob.core.windows.net/trocr/dictionaries/gpt2_with_mask.dict.txt \ # --bpe sentencepiece --sentencepiece-model ./unilm3-cased.model --dict-path-or-url https://layoutlm.blob.core.windows.net/trocr/dictionaries/unilm3.dict.txt ## for the small model
--preprocess DA2 \
--fp16 \
${DATA}
Evaluation on the STR benchmarks
export DATA=/path/to/data
export MODEL=/path/to/model
export RESULT_PATH=/path/to/result
export BSZ=16
$(which fairseq-generate) \
--data-type Receipt53K --user-dir ./ --task text_recognition \
--input-size 384 --beam 10 --nbest 1 --scoring wpa \
--gen-subset test --batch-size ${BSZ} --bpe gpt2 \
--dict-path-or-url https://layoutlm.blob.core.windows.net/trocr/dictionaries/gpt2_with_mask.dict.txt \
--path ${MODEL} --results-path ${RESULT_PATH} \
--preprocess RandAugment \
${DATA}
Use "convert_to_sroie_format.py" to convert the output file into a zip and submit it on the website to get the score.
Inference example
See pic_inference.py for details.
Citation
If you want to cite TrOCR in your research, please cite the following paper:
@misc{li2021trocr,
title={TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models},
author={Minghao Li and Tengchao Lv and Lei Cui and Yijuan Lu and Dinei Florencio and Cha Zhang and Zhoujun Li and Furu Wei},
year={2021},
eprint={2109.10282},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree. Portions of the source code are based on the fairseq project. Microsoft Open Source Code of Conduct
Contact Information
For help or issues using TrOCR, please submit a GitHub issue.
For other communications related to TrOCR, please contact Lei Cui (lecu@microsoft.com), Furu Wei (fuwei@microsoft.com).
Training Code
- The TrOCR project code comes from: https://github.com/microsoft/unilm/tree/master/trocr
- Local path: /mnt/Virgil/TrOCR/unilm-master/trocr
- Training image folder: /mnt/Virgil/TrOCR/ChineseDataset/val_images
- Training label file: /mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt
- Validation image folder: /mnt/Virgil/TrOCR/ChineseDataset/9w_images
- Validation label file: /mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt
- Pre-trained weights: /mnt/Virgil/TrOCR/trocr-base-printed

Training code adapted from the TrOCR fine-tuning notebook, Transformers-Tutorials/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from evaluate import load as load_metric
from torch.optim import AdamW
from tqdm import tqdm
import os

# Dataset class
class ORCDataset(Dataset):
    def __init__(self, image_dir, label_file, processor, max_target_length=256):
        self.image_dir = image_dir
        self.processor = processor
        self.max_target_length = max_target_length
        self.data = []
        # Read the label file
        with open(label_file, 'r', encoding='utf-8') as f:
            valid_count = 0
            for line in f:
                parts = line.strip().split()  # split filename and text on whitespace
                if len(parts) == 2:
                    file_name, text = parts
                    self.data.append((file_name, text))
                    valid_count += 1
        print(f"Loaded {valid_count} samples from {label_file}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Get the filename and text
        file_name, text = self.data[idx]
        image_path = f'{self.image_dir}/{file_name}'
        # Make sure the image file exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        # Open and convert the image
        image = Image.open(image_path).convert("RGB")
        # Process the image into pixel values
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # Encode the text into labels
        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length
        ).input_ids
        # Make sure PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        return {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels)
        }

# Dataset paths
train_image_dir = '/mnt/Virgil/TrOCR/ChineseDataset/val_images'
train_label_file = '/mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt'
eval_image_dir = '/mnt/Virgil/TrOCR/ChineseDataset/9w_images'
eval_label_file = '/mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt'
pretrained_weights = '/mnt/Virgil/TrOCR/trocr-base-printed'

# Load the processor
processor = TrOCRProcessor.from_pretrained(pretrained_weights)

# Create the datasets
print("Loading training dataset...")
train_dataset = ORCDataset(
    image_dir=train_image_dir,
    label_file=train_label_file,
    processor=processor
)
print("Loading validation dataset...")
eval_dataset = ORCDataset(
    image_dir=eval_image_dir,
    label_file=eval_label_file,
    processor=processor
)
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")

# Create the data loaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=16)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionEncoderDecoderModel.from_pretrained(pretrained_weights)
model.to(device)
print("Model loaded; encoder structure:")
print(model.encoder)

# Configure special tokens and generation settings
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# Load the evaluation metric
cer_metric = load_metric("cer")

def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    return cer_metric.compute(predictions=pred_str, references=label_str)

# Native PyTorch training loop
print("Starting training...")
optimizer = AdamW(model.parameters(), lr=5e-5)
for epoch in range(100):
    # Train
    model.train()
    train_loss = 0.0
    print(f"Epoch {epoch+1}/{100} [training]")
    for batch in tqdm(train_dataloader):
        # Move the batch to the device
        batch = {k: v.to(device) for k, v in batch.items()}
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    # Report the training loss
    avg_train_loss = train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{100} training loss: {avg_train_loss:.4f}")
    # Evaluate
    model.eval()
    valid_cer = 0.0
    print(f"Epoch {epoch+1}/{100} [validating]")
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # Generate predictions
            outputs = model.generate(batch["pixel_values"].to(device))
            # Compute the metric
            cer = compute_cer(
                pred_ids=outputs,
                label_ids=batch["labels"].to(device)
            )
            valid_cer += cer
    # Report the validation CER
    avg_valid_cer = valid_cer / len(eval_dataloader)
    print(f"Epoch {epoch+1}/{100} validation CER: {avg_valid_cer:.4f}")
    # Save the model
    model.save_pretrained(".")
print("Training finished; model saved")
First debugging run, which errored:
(tiacai) root@5de27e9cb8c1:/mnt/Virgil/TrOCR# python train2.py
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Loading training dataset...
Loaded 364400 samples from /mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt
Loading validation dataset...
Loaded 90000 samples from /mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt
Training samples: 364400
Validation samples: 90000
Traceback (most recent call last):
File "/mnt/Virgil/TrOCR/train2.py", line 95, in <module>
model = VisionEncoderDecoderModel.from_pretrained(pretrained_weights)
File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py", line 378, in from_pretrained
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/transformers/modeling_utils.py", line 272, in _wrapper
return func(*args, **kwargs)
File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4345, in from_pretrained
with safe_open(checkpoint_files[0], framework="pt") as f:
safetensors_rust.SafetensorError: Error while deserializing header: MetadataIncompleteBuffer
The error message safetensors_rust.SafetensorError: Error while deserializing header: MetadataIncompleteBuffer
indicates that the header metadata of the safetensors file was incomplete when loading the pre-trained weights, which is usually caused by a corrupted or incompletely downloaded weight file.
Re-downloading the safetensors file fixes it.
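Before re-downloading, it is easy to confirm locally that the file really is truncated. The sketch below relies only on the documented safetensors layout (an 8-byte little-endian header length followed by that many bytes of JSON); the function name is my own:

```python
import json
import struct

def safetensors_header_ok(path: str) -> bool:
    """Return True if the file's safetensors header parses.

    Layout: bytes 0-7 are a little-endian u64 giving the JSON header length;
    the next that-many bytes must be valid JSON. A partially downloaded file
    usually fails one of these checks, producing errors like
    MetadataIncompleteBuffer at load time.
    """
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) < 8:
            return False
        (header_len,) = struct.unpack("<Q", prefix)
        header = f.read(header_len)
        if len(header) < header_len:
            return False
        try:
            json.loads(header)
        except ValueError:
            return False
    return True
```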
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Loading training dataset...
Loaded 364400 samples from /mnt/Virgil/TrOCR/ChineseDataset/val_lables.txt
Loading validation dataset...
Loaded 90000 samples from /mnt/Virgil/TrOCR/ChineseDataset/9w_lables.txt
Training samples: 364400
Validation samples: 90000
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
"attention_probs_dropout_prob": 0.0,
"encoder_stride": 16,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"image_size": 384,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"model_type": "vit",
"num_attention_heads": 12,
"num_channels": 3,
"num_hidden_layers": 12,
"patch_size": 16,
"pooler_act": "tanh",
"pooler_output_size": 768,
"qkv_bias": false,
"torch_dtype": "float32",
"transformers_version": "4.50.3"
}
Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"attention_dropout": 0.0,
"bos_token_id": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": 768,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"dropout": 0.1,
"eos_token_id": 2,
"init_std": 0.02,
"is_decoder": true,
"layernorm_embedding": true,
"max_position_embeddings": 512,
"model_type": "trocr",
"pad_token_id": 1,
"scale_embedding": false,
"torch_dtype": "float32",
"transformers_version": "4.50.3",
"use_cache": false,
"use_learned_position_embeddings": true,
"vocab_size": 50265
}
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at /mnt/Virgil/TrOCR/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Model loaded; encoder structure:
ViTModel(
(embeddings): ViTEmbeddings(
(patch_embeddings): ViTPatchEmbeddings(
(projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
)
(dropout): Dropout(p=0.0, inplace=False)
)
(encoder): ViTEncoder(
(layer): ModuleList(
(0-11): 12 x ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=False)
(key): Linear(in_features=768, out_features=768, bias=False)
(value): Linear(in_features=768, out_features=768, bias=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
)
)
(layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(pooler): ViTPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
However, after training for one epoch and finishing validation, it errored again when it came time to save the model.
Traceback (most recent call last):
File "/mnt/Virgil/TrOCR/train2.py", line 168, in <module>
cer_metric = load_metric("cer")
File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/evaluate/loading.py", line 748, in load
evaluation_module = evaluation_module_factory(
File "/opt/anaconda3/envs/tiacai/lib/python3.10/site-packages/evaluate/loading.py", line 681, in evaluation_module_factory
raise FileNotFoundError(
FileNotFoundError: Couldn't find a module script at /mnt/Virgil/TrOCR/cer/cer.py. Module 'cer' doesn't exist on the Hugging Face Hub either.
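This traceback arises because evaluate tries to fetch the "cer" metric script from the Hugging Face Hub at runtime (and that metric additionally depends on the jiwer package), so it fails on an offline machine. One option is to install jiwer and retry with network access; another is to drop the dependency and compute CER locally. The function below is a hypothetical stand-in of my own, not evaluate's implementation: a plain character-level Levenshtein distance divided by total reference length:

```python
def local_cer(predictions, references):
    """Character error rate: total character edit distance / total reference chars.

    A dependency-free stand-in for evaluate's "cer" metric, usable offline.
    """
    def edit_distance(a: str, b: str) -> int:
        # Classic dynamic-programming Levenshtein distance over characters.
        prev = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            cur = [i]
            for j, cb in enumerate(b, 1):
                cur.append(min(prev[j] + 1,               # deletion
                               cur[j - 1] + 1,            # insertion
                               prev[j - 1] + (ca != cb))) # substitution
            prev = cur
        return prev[-1]

    total_edits = sum(edit_distance(p, r) for p, r in zip(predictions, references))
    total_chars = sum(len(r) for r in references)
    return total_edits / total_chars
```

With this, the `cer_metric.compute(predictions=pred_str, references=label_str)` call in compute_cer could be replaced by `local_cer(pred_str, label_str)`, removing the Hub lookup entirely.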