Rust实现FasterR-CNN目标检测全流程

发布于:2025-07-03 ⋅ 阅读:(21) ⋅ 点赞:(0)

使用 Rust 和 FasterR-CNN 进行目标检测 

FasterR-CNN 是目标检测领域广泛使用的深度学习模型。Rust 生态中可以通过 tch-rs(Torch 绑定)调用预训练的 PyTorch 模型实现。以下为完整实现步骤:


环境准备

安装 Rust 和必要的依赖:

cargo add tch
cargo add anyhow  # 错误处理

下载预训练的 FasterR-CNN 模型(需 PyTorch 格式 .pt 文件),或使用 TorchScript 格式模型。示例中使用 fasterrcnn_resnet50_fpn


加载预训练模型

use tch::{nn, Device, Tensor, Kind};

fn load_model(model_path: &str) -> anyhow::Result<nn::Module> {
    let device = Device::cuda_if_available();
    let model = nn::Module::load(model_path, device)?;
    Ok(model)
}

图像预处理

将输入图像转换为模型需要的格式(归一化 + 标准化):

use tch::vision::image;

fn preprocess_image(img_path: &str) -> anyhow::Result<Tensor> {
    let image = image::load(img_path)?;
    let resized = image.resize(800, 800);  // FasterR-CNN 典型输入尺寸
    let tensor = resized.to_kind(Kind::Float) / 255.0;
    let mean = Tensor::of_slice(&[0.485, 0.456, 0.406]).view([3, 1, 1]);
    let std = Tensor::of_slice(&[0.229, 0.224, 0.225]).view([3, 1, 1]);
    Ok((tensor - mean) / std)
}

运行推理

执行目标检测并获取结果:

fn run_detection(model: &nn::Module, input_tensor: &Tensor) -> anyhow::Result<(Tensor, Tensor)> {
    let output = model.forward_ts(&[input_tensor.unsqueeze(0)])?;
    let boxes = output.get(0).unwrap();
    let scores = output.get(1).unwrap();
    Ok((boxes, scores))
}

后处理与可视化

过滤低置信度检测结果并绘制边框:

use tch::IndexOp;

fn filter_results(bboxes: &Tensor, scores: &Tensor, threshold: f64) -> Vec<(Vec<f64>, f64)> {
    let mut detections = Vec::new();
    for i in 0..scores.size()[0] {
        if scores.double_value(&[i]) > threshold {
            let bbox = bboxes.i(i).to_kind(Kind::Double).to_vec::<f64>().unwrap();
            detections.push((bbox, scores.double_value(&[i])));
        }
    }
    detections
}

使用 imageprocopencv-rust 绘制检测框(需额外安装依赖)。


完整流程示例

fn main() -> anyhow::Result<()> {
    let model = load_model("fasterrcnn.pt")?;
    let input = preprocess_image("input.jpg")?;
    let (bboxes, scores) = run_detection(&model, &input)?;
    
    let detections = filter_results(&bboxes, &scores, 0.7);
    for (bbox, score) in detections {
        println!("Detected: {:?} with score {:.2}", bbox, score);
    }
    Ok(())
}

注意事项

  1. 模型需提前转换为 TorchScript 格式(通过 Python 的 torch.jit.script
  2. GPU 加速需配置 CUDA 环境
  3. 输入图像尺寸应与模型训练时一致
  4. COCO 数据集的类别标签需单独加载

Rust 生态的计算机视觉库(如 cv)可进一步简化图像操作,但 tch-rs 目前是调用 PyTorch 模型的最成熟方案。

Polars 支持各种文件格式

Polars 支持各种文件格式、包括 CSV、Parquet 和 JSON

use polars::prelude::*;

fn main() -> Result<()> {
    // Create a DataFrame with 4 names, ages, and cities
    let df = df![
        "name" => &["周杰伦", "力辣", "张慧费", "王菲"],
        "age" => &[55, 60, 70, 67],
        "city" => &["New York", "Los Angeles", "Chicago", "San Francisco"]
    ]?;

    // Display the DataFrame
    println!("{:?}", df);

    Ok(())
}

集成Polars和Pyo3构建

在Rust中集成Polars(数据框库)和Pyo3(Python绑定)构建Web服务,可以通过以下方法实现:

创建基础Rust项目

使用Cargo初始化新项目,添加必要的依赖。Cargo.toml需要包含以下依赖项:

[dependencies]
actix-web = "4"  # Web框架
polars = { version = "0.28", features = ["lazy"] }  # 数据处理
pyo3 = { version = "0.18", features = ["extension-module"] }  # Python集成
tokio = { version = "1", features = ["full"] }  # 异步运行时