rust-candle Learning Notes 10: Using Embedding

Published: 2025-05-10

Reference: about-pytorch

candle-nn provides an embedding() function for initializing an Embedding layer:

pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
    // Create (or fetch) an (in_size, out_size) weight matrix named "weight",
    // initialized from a standard normal distribution N(0, 1).
    let embeddings = vb.get_with_hints(
        (in_size, out_size),
        "weight",
        crate::Init::Randn {
            mean: 0.,
            stdev: 1.,
        },
    )?;
    Ok(Embedding::new(embeddings, out_size))
}
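
So the weights are drawn from N(0, 1), and forward is just a row lookup: each token id selects one row of the (in_size, out_size) weight matrix. A minimal sketch of that lookup behavior (toy hand-built weights, constructing Embedding::new directly instead of going through a VarBuilder):

use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Module};

fn lookup_demo() -> Result<()> {
    let device = Device::Cpu;
    // Toy vocabulary of 4 tokens, embedding dimension 3: rows 0..=3.
    let weight = Tensor::arange(0f32, 12.0, &device)?.reshape((4, 3))?;
    let emb = Embedding::new(weight, 3);
    // Each id picks the corresponding row of the weight matrix.
    let ids = Tensor::new(&[2u32, 0, 2], &device)?;
    let out = emb.forward(&ids)?; // shape (3, 3): rows 2, 0, 2
    println!("{:?}", out.to_vec2::<f32>()?);
    Ok(())
}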

A first look at candle's Embedding:

For details on how the Tokenizer and dataset are constructed, see rust-candle Learning Notes 9: loading the qwen3 tokenizer with the tokenizers crate and using it to process text.

use candle_core::Device;
use candle_nn::{embedding, Embedding, Module, VarBuilder, VarMap};
use tokenizers::Tokenizer;

// read_txt and TokenDataset come from Learning Notes 9.
// These notes use a boxed-error Result alias throughout:
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let vocab_size = tokenizer.get_vocab_size(true);

    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;
    let dataset = TokenDataset::new(text, tokenizer, 32, 16, device.clone())?;
    let (inputs, targets) = dataset.get_item(0)?;
    println!(" inputs: {:?}\n", inputs);
    println!(" targets: {:?}\n", targets);
    let len = dataset.len();
    println!("{:?}", len);

    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    let embedding = embedding(vocab_size, 5, vb)?;
    let x_embedding = embedding.forward(&inputs)?;
    let y_embedding = embedding.forward(&targets)?;
    println!(" x_embedding: {:?}\n", x_embedding);
    println!("{:?}", x_embedding.to_vec2::<f32>()?);
    println!(" y_embedding: {:?}\n", y_embedding);
    println!("{:?}", y_embedding.to_vec2::<f32>()?);
    
    Ok(())
}
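
Since the dataset was created with a sequence length of 32, inputs and targets should each be a rank-1 tensor of 32 token ids (per Learning Notes 9), so x_embedding and y_embedding come out with shape (32, 5): one 5-dimensional row per token. That is why to_vec2::<f32>() is the right accessor above.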

Implementing sine/cosine (sinusoidal) positional encoding:
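
For reference, this is the sinusoidal encoding from the original Transformer paper. With d = embedding_dim, position t, and pair index i in 0..d/2:

PE(t, 2i) = sin(t / 10000^(2i/d)),  PE(t, 2i+1) = cos(t / 10000^(2i/d))

The code below pushes each sin/cos pair in exactly this interleaved order.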

struct PositionEmbedding {
    // (seq_len, embedding_dim) table of fixed sinusoidal encodings.
    pos_embedding: Tensor,
    device: Device,
}
impl PositionEmbedding {
    fn new(seq_len: usize, embedding_dim: usize, device: Device) -> Result<Self> {
        // sin/cos come in pairs, so the dimension must be even.
        if embedding_dim % 2 != 0 {
            return Err("embedding_dim must be even".into());
        }
        let mut pos_embedding_vec: Vec<f32> = Vec::with_capacity(seq_len * embedding_dim);
        let w_const: f32 = 10000.0;
        for t in 0..seq_len {
            let i_max = embedding_dim / 2;
            for i in 0..i_max {
                // One frequency per sin/cos pair: denominator = 10000^(2i/d).
                let denominator = w_const.powf(2.0 * i as f32 / embedding_dim as f32);
                let pos_sin_i = (t as f32 / denominator).sin();
                let pos_cos_i = (t as f32 / denominator).cos();
                // Interleave: even index gets sin, odd index gets cos.
                pos_embedding_vec.push(pos_sin_i);
                pos_embedding_vec.push(pos_cos_i);
            }
        }
        let pos_embedding = Tensor::from_vec(pos_embedding_vec, (seq_len, embedding_dim), &device)?;
        Ok(Self { pos_embedding, device })
    }
}

Test:

Note: in candle, adding tensors of different shapes with the + operator returns an error; the broadcast add must be called explicitly. Either operand can be the one the method is called on.

// Same imports as above; DataLoader is also from Learning Notes 9.
fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let vocab_size = tokenizer.get_vocab_size(true);

    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;
    let seq_len = 32;
    let dataset = TokenDataset::new(text, tokenizer, seq_len, 16, device.clone())?;
    let batch_size: usize = 6;
    let mut loader = DataLoader::new(dataset, batch_size, true);
    loader.reset();
    let (x, y) = loader.next().unwrap()?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    let embedding_dim: usize = 256;
    let embedding = embedding(vocab_size, embedding_dim, vb)?;
    let x_embedding = embedding.forward(&x)?;
    let y_embedding = embedding.forward(&y)?;
    println!(" x_embedding: {:?}\n", x_embedding);
    println!(" y_embedding: {:?}\n", y_embedding);
    let pos_embedding = PositionEmbedding::new(seq_len, embedding_dim, device.clone())?;
    let pos_emb = pos_embedding.pos_embedding;
    // In candle, adding tensors of different shapes with `+` is an error;
    // broadcasting must be requested explicitly via broadcast_add.
    // Either of the two calls below works:
    let x_input = x_embedding.broadcast_add(&pos_emb)?;
    // let x_input = pos_emb.broadcast_add(&x_embedding)?;
    println!(" x_input: {:?}\n", x_input);
    Ok(())
}
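
As an optional refinement, the broadcast add can be folded into PositionEmbedding by implementing candle's Module trait, so the positional table is applied the same way the token embedding is. A minimal sketch:

use candle_core::Tensor;
use candle_nn::Module;

impl Module for PositionEmbedding {
    // xs: (batch, seq_len, embedding_dim); self.pos_embedding: (seq_len, embedding_dim).
    // broadcast_add repeats the positional table across the batch dimension.
    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        xs.broadcast_add(&self.pos_embedding)
    }
}

Usage then becomes `let x_input = pos_embedding.forward(&x_embedding)?;`.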

