A Solution for Face Training with PyTorch and Face Recognition in Java


Below is a solution that combines PyTorch for face training with Java for face recognition. It is split into two main parts: training a face feature extraction model with PyTorch, and loading that model from Java to perform face recognition.

Part 1: Face Feature Training with PyTorch

First we need to train a model that can extract face features. Here we use a ResNet-based architecture trained with a triplet loss:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np
from sklearn.model_selection import train_test_split

# 1. Dataset definition
class FaceDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# 2. Model definition - ResNet50-based feature extractor
class FaceFeatureExtractor(nn.Module):
    def __init__(self, embedding_size=128, num_classes=None):
        super(FaceFeatureExtractor, self).__init__()
        # torchvision >= 0.13 uses the weights enum (pretrained=True on older versions)
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Replace the final layer so the network outputs a feature embedding
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_size)
        self.embedding_size = embedding_size
        
        # Optional classification head (used as an auxiliary training signal)
        self.classifier = None
        if num_classes is not None:
            self.classifier = nn.Linear(embedding_size, num_classes)
        
    def forward(self, x):
        features = self.resnet(x)
        if self.classifier is not None:
            logits = self.classifier(features)
            return features, logits
        return features

# 3. Triplet loss definition
class TripletLoss(nn.Module):
    def __init__(self, margin=0.5):
        super(TripletLoss, self).__init__()
        self.margin = margin
        
    def forward(self, anchor, positive, negative):
        # Euclidean distances; the small epsilon keeps sqrt differentiable at zero
        distance_positive = torch.sqrt(torch.sum(torch.pow(anchor - positive, 2), dim=1) + 1e-8)
        distance_negative = torch.sqrt(torch.sum(torch.pow(anchor - negative, 2), dim=1) + 1e-8)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return torch.mean(losses)
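
# Optional sketch (not part of the original pipeline): batch-wise semi-hard
# triplet mining, a drop-in alternative to the naive random sampling used in
# the training loop below. For each anchor it picks the hardest positive and
# a semi-hard negative, i.e. one with d(a,p) < d(a,n) < d(a,p) + margin.
def mine_triplets(features, labels, margin=0.5):
    dists = torch.cdist(features, features)  # pairwise Euclidean distances
    anchors, positives, negatives = [], [], []
    for i in range(len(labels)):
        pos_mask = labels == labels[i]
        pos_mask[i] = False                  # exclude the anchor itself
        neg_mask = labels != labels[i]
        if not pos_mask.any() or not neg_mask.any():
            continue
        # Hardest positive: same identity, farthest away
        p = torch.where(pos_mask)[0][dists[i][pos_mask].argmax()]
        # Semi-hard negatives: farther than the positive but inside the margin;
        # fall back to all negatives if none qualifies
        semi = neg_mask & (dists[i] > dists[i][p]) & (dists[i] < dists[i][p] + margin)
        pool = torch.where(semi)[0] if semi.any() else torch.where(neg_mask)[0]
        n = pool[dists[i][pool].argmin()]    # closest negative in the pool
        anchors.append(i)
        positives.append(p.item())
        negatives.append(n.item())
    return anchors, positives, negatives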

# 4. Data preparation
def prepare_data(data_dir):
    image_paths = []
    labels = []
    label_map = {}
    current_label = 0
    
    for person_name in os.listdir(data_dir):
        person_dir = os.path.join(data_dir, person_name)
        if os.path.isdir(person_dir):
            label_map[person_name] = current_label
            for img_file in os.listdir(person_dir):
                if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                    image_paths.append(os.path.join(person_dir, img_file))
                    labels.append(current_label)
            current_label += 1
    
    return image_paths, labels, label_map, current_label

# 5. Training function
def train_model(data_dir, epochs=50, batch_size=32, embedding_size=128):
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Prepare the data
    image_paths, labels, label_map, num_classes = prepare_data(data_dir)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42
    )
    
    # Create datasets and data loaders
    train_dataset = FaceDataset(train_paths, train_labels, transform)
    val_dataset = FaceDataset(val_paths, val_labels, transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    
    # Initialize the model, loss functions, and optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = FaceFeatureExtractor(embedding_size=embedding_size, num_classes=num_classes).to(device)
    
    # Combined loss: triplet loss + cross-entropy (auxiliary)
    triplet_criterion = TripletLoss()
    ce_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            # Get embeddings and classification logits
            features, logits = model(images)
            
            # Build triplets (naive sampling: same identity = positive, different
            # identity = negative). Real applications should use a smarter mining
            # strategy, e.g. the mine_triplets() sketch above.
            anchor_idx, positive_idx, negative_idx = [], [], []
            for i in torch.randint(0, len(images), (len(images) // 2,)):
                pos_candidates = torch.where(labels == labels[i])[0]
                pos_candidates = pos_candidates[pos_candidates != i]  # exclude the anchor itself
                neg_candidates = torch.where(labels != labels[i])[0]
                if len(pos_candidates) == 0 or len(neg_candidates) == 0:
                    continue  # no valid positive/negative for this anchor in the batch
                anchor_idx.append(i)
                positive_idx.append(pos_candidates[torch.randint(0, len(pos_candidates), (1,))])
                negative_idx.append(neg_candidates[torch.randint(0, len(neg_candidates), (1,))])
            
            # Compute losses
            if anchor_idx:
                anchor = features[torch.stack(anchor_idx)]
                positive = features[torch.cat(positive_idx)]
                negative = features[torch.cat(negative_idx)]
                triplet_loss = triplet_criterion(anchor, positive, negative)
            else:
                triplet_loss = torch.tensor(0.0, device=device)
            ce_loss = ce_criterion(logits, labels)
            loss = triplet_loss + 0.5 * ce_loss  # combined loss
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * images.size(0)
        
        # Average training loss
        train_loss = train_loss / len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                
                features, logits = model(images)
                ce_loss = ce_criterion(logits, labels)
                val_loss += ce_loss.item() * images.size(0)
        
        val_loss = val_loss / len(val_loader.dataset)
        
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    
    # Save the model weights and label map
    torch.save(model.state_dict(), 'face_feature_extractor.pth')
    np.save('label_map.npy', label_map)
    
    return model, label_map

# 6. Export to ONNX format (for use from Java)
def export_to_onnx(model, input_size=(1, 3, 150, 150), output_path='face_model.onnx'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dummy_input = torch.randn(*input_size).to(device)
    
    # Drop the auxiliary classifier so the exported graph has a single output
    # (the embedding), matching the 'output' name expected on the Java side
    model.classifier = None
    model.eval()
    
    # Export the model
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"模型已导出为ONNX格式: {output_path}")

if __name__ == "__main__":
    # Assumes the dataset is organized as: data_dir/<person_name>/<image>.jpg
    data_dir = "path/to/your/face_dataset"
    model, label_map = train_model(data_dir, epochs=50)
    export_to_onnx(model)
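
The Java code in Part 2 loads its known-face gallery from a text file with one line per person: the name, a tab, then the comma-separated feature values. The original post does not show how to produce that file, so here is a minimal sketch; build_gallery and its one-image-per-person input are illustrative assumptions:

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def build_gallery(model, person_images, output_path='face_features.txt', device='cpu'):
    # person_images: dict mapping person name -> path of a representative photo
    preprocess = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model.classifier = None  # embedding output only, matching the ONNX export
    model.eval()
    with open(output_path, 'w') as f, torch.no_grad():
        for name, img_path in person_images.items():
            x = preprocess(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)
            emb = F.normalize(model(x), dim=1).squeeze(0)  # L2-normalize, as the Java side does
            f.write(name + '\t' + ','.join(f'{v:.6f}' for v in emb.tolist()) + '\n')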

Part 2: Face Recognition in Java

The Java side uses OpenCV for face detection and preprocessing, the ONNX Runtime Java API to run the exported model for feature extraction, and ND4J for vector math; faces are then recognized by comparing the cosine similarity between feature vectors. The known-face gallery file it loads can be produced with the sketch at the end of Part 1.

import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.objdetect.CascadeClassifier;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
// The ONNX Runtime Java API lives in the ai.onnxruntime package
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

import java.io.*;
import java.nio.FloatBuffer;
import java.util.*;

public class FaceRecognition {
    // Load the OpenCV native library
    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }
    
    private final CascadeClassifier faceDetector;
    private final OrtEnvironment env;
    private final OrtSession session;
    private final Map<String, float[]> faceFeatures;
    private final float THRESHOLD = 0.6f; // Similarity threshold; tune for your data
    
    public FaceRecognition(String modelPath, String featuresPath) throws Exception {
        // Initialize the face detector
        faceDetector = new CascadeClassifier("haarcascade_frontalface_default.xml");
        
        // Initialize the ONNX Runtime environment and session
        env = OrtEnvironment.getEnvironment();
        session = env.createSession(modelPath);
        
        // Load known face features
        faceFeatures = loadFaceFeatures(featuresPath);
    }
    
    // Load face features from a file (one line per person: name<TAB>comma-separated floats)
    private Map<String, float[]> loadFaceFeatures(String path) throws IOException {
        Map<String, float[]> features = new HashMap<>();
        
        File file = new File(path);
        if (!file.exists()) {
            return features;
        }
        
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] parts = line.split("\t");
                if (parts.length == 2) {
                    // Java streams have no mapToFloat, so parse the values manually
                    String[] values = parts[1].split(",");
                    float[] feature = new float[values.length];
                    for (int i = 0; i < values.length; i++) {
                        feature[i] = Float.parseFloat(values[i]);
                    }
                    features.put(parts[0], feature);
                }
            }
        }
        
        return features;
    }
    
    // Save face features to a file
    public void saveFaceFeatures(String path) throws IOException {
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(path))) {
            for (Map.Entry<String, float[]> entry : faceFeatures.entrySet()) {
                // Arrays.stream has no float[] overload, so join the values manually
                StringBuilder sb = new StringBuilder(entry.getKey()).append('\t');
                float[] feature = entry.getValue();
                for (int i = 0; i < feature.length; i++) {
                    if (i > 0) sb.append(',');
                    sb.append(feature[i]);
                }
                bw.write(sb.toString());
                bw.newLine();
            }
        }
    }
    
    // Detect a face and crop it out
    public Mat detectAndExtractFace(Mat image) {
        MatOfRect faceDetections = new MatOfRect();
        faceDetector.detectMultiScale(image, faceDetections);
        
        // If no face was detected, return null
        if (faceDetections.empty()) {
            return null;
        }
        
        Rect faceRect = faceDetections.toArray()[0];
        Mat face = new Mat(image, faceRect);
        
        // Preprocess: resize and convert BGR -> RGB (normalization happens in extractFeatures)
        Imgproc.resize(face, face, new Size(150, 150));
        Imgproc.cvtColor(face, face, Imgproc.COLOR_BGR2RGB);
        
        return face;
    }
    
    // Extract a feature vector from a cropped face
    public float[] extractFeatures(Mat face) throws Exception {
        if (face == null) {
            return null;
        }
        
        // Convert to the model's input layout (1, 3, 150, 150), channels first
        float[] data = new float[3 * 150 * 150];
        final float[] MEAN = {0.485f, 0.456f, 0.406f};
        final float[] STD = {0.229f, 0.224f, 0.225f};
        int idx = 0;
        
        for (int c = 0; c < 3; c++) {
            for (int h = 0; h < 150; h++) {
                for (int w = 0; w < 150; w++) {
                    double[] pixel = face.get(h, w);
                    // Per-channel ImageNet normalization, matching the training transform
                    data[idx++] = (float) ((pixel[c] / 255.0 - MEAN[c]) / STD[c]);
                }
            }
        }
        
        // Wrap the buffer in an ONNX input tensor
        OnnxTensor input = OnnxTensor.createTensor(
                env, FloatBuffer.wrap(data), new long[]{1, 3, 150, 150});
        
        // Run the model; the output is a (1, embedding_size) float matrix
        float[] features;
        try (OrtSession.Result result = session.run(Collections.singletonMap("input", input))) {
            float[][] output = (float[][]) result.get(0).getValue();
            features = output[0];
        }
        
        // L2-normalize the feature vector
        INDArray featureArray = Transforms.unitVec(Nd4j.create(features));
        return featureArray.data().asFloat();
    }
    
    // Compute the cosine similarity between two feature vectors
    private float calculateSimilarity(float[] feature1, float[] feature2) {
        float dotProduct = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        
        for (int i = 0; i < feature1.length; i++) {
            dotProduct += feature1[i] * feature2[i];
            norm1 += feature1[i] * feature1[i];
            norm2 += feature2[i] * feature2[i];
        }
        
        return dotProduct / (float)(Math.sqrt(norm1) * Math.sqrt(norm2));
    }
    
    // Recognize a face by comparing its features against the known gallery
    public String recognizeFace(float[] features) {
        if (features == null || faceFeatures.isEmpty()) {
            return "Unknown";
        }
        
        String bestMatch = "Unknown";
        float maxSimilarity = 0.0f;
        
        for (Map.Entry<String, float[]> entry : faceFeatures.entrySet()) {
            float similarity = calculateSimilarity(features, entry.getValue());
            if (similarity > maxSimilarity && similarity > THRESHOLD) {
                maxSimilarity = similarity;
                bestMatch = entry.getKey();
            }
        }
        
        return bestMatch;
    }
    
    // Register a new face
    public void registerFace(String name, float[] features) {
        if (name != null && features != null) {
            faceFeatures.put(name, features);
        }
    }
    
    // Example main method
    public static void main(String[] args) {
        try {
            // Initialize the recognizer
            FaceRecognition recognizer = new FaceRecognition(
                "face_model.onnx", 
                "face_features.txt"
            );
            
            // Process a test image
            Mat image = Imgcodecs.imread("test_face.jpg");
            Mat face = recognizer.detectAndExtractFace(image);
            
            if (face != null) {
                float[] features = recognizer.extractFeatures(face);
                
                // Recognize the face
                String result = recognizer.recognizeFace(features);
                System.out.println("Recognition result: " + result);
                
                // If the face is unknown, optionally register it
                if ("Unknown".equals(result)) {
                    // recognizer.registerFace("new_person", features);
                    // recognizer.saveFaceFeatures("face_features.txt");
                }
            } else {
                System.out.println("未检测到人脸");
            }
            
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

Implementation Notes

  1. Overall flow

    • Use PyTorch to train a face feature extraction model that maps a face image to a fixed-dimension feature vector
    • Export the trained model to ONNX format for cross-platform use
    • In Java, detect faces with OpenCV and extract features by running the ONNX model via the ONNX Runtime Java API
    • Recognize faces by computing the cosine similarity between their feature vectors
  2. Environment dependencies

    • Python: PyTorch, torchvision, Pillow, numpy, scikit-learn
    • Java: OpenCV, ONNX Runtime (Java API), ND4J (for vector math)
  3. How to use

    • Prepare a face dataset with images grouped into one folder per person
    • Run the PyTorch code to train the model and export it to ONNX
    • Add the dependencies to your Java project and configure OpenCV
    • Use the Java code for face detection, feature extraction, and recognition
  4. Caveats

    • Dataset quality has a large impact on recognition accuracy; aim for at least 5-10 photos per person covering different angles and lighting
    • The THRESHOLD value must be tuned against real test results (see the calibration sketch after this list)
    • In production, use a better triplet selection strategy (such as the semi-hard mining sketch in Part 1) to improve model performance
    • The Java code requires the OpenCV native library path to be configured correctly
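
On tuning THRESHOLD: a simple approach is to score same-person and different-person validation pairs with calculateSimilarity, then sweep candidate thresholds. A sketch (the function name and score lists are illustrative, not part of the original code):

import numpy as np

def choose_threshold(genuine_scores, impostor_scores):
    # Pick the cosine-similarity threshold that maximizes balanced accuracy
    genuine = np.asarray(genuine_scores)
    impostor = np.asarray(impostor_scores)
    best_t, best_acc = 0.5, 0.0
    for t in np.linspace(0.0, 1.0, 101):
        tpr = (genuine >= t).mean()   # same-person pairs accepted
        tnr = (impostor < t).mean()   # different-person pairs rejected
        acc = (tpr + tnr) / 2
        if acc > best_acc:
            best_t, best_acc = t, acc
    return best_t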

This solution covers the complete pipeline from model training to a working application, and it can be optimized and extended to fit specific requirements.

