GitHub - bubbliiiing/arcface-pytorch: 这是一个arcface-pytorch的源码,可以用于训练自己的模型。
https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch
将 PyTorch 模型转换为 ONNX
import torch
import arcface
from nets.arcface import Arcface as arcface
from torch.onnx import export
import onnxruntime as ort
import numpy as np
def convert2onnx_demo():
    """Export the ArcFace iresnet50 checkpoint to ONNX and smoke-test it.

    Loads the weights from ``./model_data/arcface_iresnet50.pth``, exports the
    network to ``./model_data/arcface_iresnet50.onnx`` with the batch
    dimension marked dynamic, then runs the exported graph once through
    ONNX Runtime on random input and prints the first output and its shape.
    """
    # Other checkpoints this demo was written against:
    # model_path = './model_data/arcface_mobilefacenet.pth'
    # model_path = './model_data/arcface_mobilenet_v1.pth'
    model_path = './model_data/arcface_iresnet50.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Loading weights into state dict...')
    # Matching constructors for the other checkpoints:
    # net = arcface(backbone='mobilefacenet', mode="predict").eval()
    # net = arcface(backbone='mobilenetv1', mode="predict").eval()
    net = arcface(backbone='iresnet50', mode="predict").eval()
    net.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    net = net.to(device)
    print('{} model loaded.'.format(model_path))

    batch_size = 4
    dummy_input = torch.randn(batch_size, 3, 112, 112).to(device)

    # onnx_path = './model_data/arcface_mobilefacenet.onnx'
    # onnx_path = './model_data/arcface_mobilenet_v1.onnx'
    onnx_path = './model_data/arcface_iresnet50.onnx'
    opset = 10

    # Name the graph input explicitly instead of relying on the name the
    # exporter auto-generates. A dynamic_axes key that does not match a real
    # input/output name is ignored by the exporter, which would silently
    # leave the batch dimension fixed. 'input.1' is kept so the exported
    # file exposes the same input name as before.
    input_name = 'input.1'
    dynamic_axes = {input_name: {0: 'batch_size'}}  # dynamic batch size
    export(
        net,
        dummy_input,
        onnx_path,
        opset_version=opset,
        input_names=[input_name],
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
    )

    # Smoke test: run the exported model once on random data.
    ort_session = ort.InferenceSession(onnx_path)
    random_batch = np.random.randn(batch_size, 3, 112, 112).astype(np.float32)
    outputs = ort_session.run(None, {input_name: random_batch})
    print(outputs[0], outputs[0].shape)


convert2onnx_demo()
使用 ONNX 模型进行推理
import onnxruntime as ort
import numpy as np
import cv2
# Load the ONNX model.
# session = ort.InferenceSession("./model_data/arcface_iresnet50.onnx")
session = ort.InferenceSession("./model_data/arcface_mobilenet_v1.onnx")

# Read and preprocess the image.
image_path = "./img/1_001.jpg"
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising; fail here with a clear message rather than with an
    # opaque error inside cv2.resize.
    raise OSError("Failed to read image: {}".format(image_path))
# NOTE(review): cv2.imread yields channels in BGR order — confirm this matches
# the channel order the model was trained with; convert if it expects RGB.
image = cv2.resize(image, (112, 112))  # assumes the model takes 112x112 input
image = image.transpose(2, 0, 1)  # HxWxC -> CxHxW
image = image.astype(np.float32)
image = (image - 127.5) / 128.0  # normalize pixel values

# Add the batch dimension.
image = np.expand_dims(image, axis=0)

# Run the model, looking the input name up from the session.
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: image})

# 'outputs' is the list of model outputs; the first one is assumed to be the
# feature (embedding) vector.
features = outputs[0]
print(features)
print(features.shape)
参考博客