candle-nn provides an embedding() helper that constructs an Embedding with randomly initialized weights:
pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
    // Fetch (or create) an (in_size, out_size) weight tensor named "weight",
    // initialized from a standard normal distribution.
    let embeddings = vb.get_with_hints(
        (in_size, out_size),
        "weight",
        crate::Init::Randn {
            mean: 0.,
            stdev: 1.,
        },
    )?;
    Ok(Embedding::new(embeddings, out_size))
}
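To see what the lookup itself does before wiring it into a pipeline: Embedding::forward simply selects rows of the (in_size, out_size) weight matrix by token id. A minimal sketch with hand-built toy weights (the values and the demo_lookup name are illustrative only):

use candle_core::{Device, Tensor};
use candle_nn::{Embedding, Module};

fn demo_lookup() -> candle_core::Result<()> {
    // A 4-token vocabulary with 3-dim embeddings: weight row r is the vector for token r.
    let weight = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((4, 3))?;
    let emb = Embedding::new(weight, 3);
    let ids = Tensor::new(&[2u32, 0, 2], &Device::Cpu)?;
    // Output shape (3, 3): rows 2, 0 and 2 of the weight matrix.
    let out = emb.forward(&ids)?;
    println!("{:?}", out.to_vec2::<f32>()?);
    Ok(())
}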
A first taste of candle's Embedding:
For how the Tokenizer and the dataset are constructed, see the earlier post "rust-candle学习笔记9-使用tokenizers加载qwen3分词,使用分词器处理文本" (loading the qwen3 tokenizer with the tokenizers crate and processing text with it).
use candle_core::{Device, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder, VarMap};
use tokenizers::Tokenizer;

// read_txt and TokenDataset come from the earlier notes in this series.
// The examples use a boxed-error Result alias:
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let vocab_size = tokenizer.get_vocab_size(true);
    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;
    // Window length 32; see note 9 for the dataset construction.
    let dataset = TokenDataset::new(text, tokenizer, 32, 16, device.clone())?;
    let (inputs, targets) = dataset.get_item(0)?;
    println!("inputs: {:?}\n", inputs);
    println!("targets: {:?}\n", targets);
    let len = dataset.len();
    println!("{:?}", len);
    // Trainable variables live in the VarMap; the VarBuilder hands them out by name.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    // Map each token id to a 5-dimensional vector (a tiny dim, just for inspection).
    let embedding = embedding(vocab_size, 5, vb)?;
    let x_embedding = embedding.forward(&inputs)?;
    let y_embedding = embedding.forward(&targets)?;
    println!("x_embedding: {:?}\n", x_embedding);
    println!("{:?}", x_embedding.to_vec2::<f32>()?);
    println!("y_embedding: {:?}\n", y_embedding);
    println!("{:?}", y_embedding.to_vec2::<f32>()?);
    Ok(())
}
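Since dataset.get_item(0) returns a single (seq_len,) pair of token-id tensors, forward yields a (32, 5) matrix here, which is why to_vec2::<f32>() can print it as nested Vecs. The weight also stays registered in the varmap under the name "weight", so the whole map can later be checkpointed with candle-nn's VarMap::save.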
Implementing sinusoidal (sine/cosine) positional encoding:
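The table built below is the fixed sinusoidal encoding from "Attention Is All You Need", with the sin/cos pair for each frequency i interleaved along the embedding dimension (d = embedding_dim):

PE(t, 2i)   = sin(t / 10000^(2i/d))
PE(t, 2i+1) = cos(t / 10000^(2i/d))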
struct PositionEmbedding {
    pos_embedding: Tensor,
    device: Device,
}

impl PositionEmbedding {
    fn new(seq_len: usize, embedding_dim: usize, device: Device) -> Result<Self> {
        // Each position needs sin/cos pairs, so the dimension must be even.
        if embedding_dim % 2 != 0 {
            return Err(Box::new(candle_core::Error::msg("embedding_dim must be even")));
        }
        let mut pos_embedding_vec: Vec<f32> = Vec::with_capacity(seq_len * embedding_dim);
        let w_const: f32 = 10000.0;
        for t in 0..seq_len {
            let i_max = embedding_dim / 2;
            for i in 0..i_max {
                // denominator = 10000^(2i/d); lower i means higher frequency
                let denominator = w_const.powf(2.0 * i as f32 / embedding_dim as f32);
                let pos_sin_i = (t as f32 / denominator).sin();
                let pos_cos_i = (t as f32 / denominator).cos();
                // interleave the (sin, cos) pair for this frequency
                pos_embedding_vec.push(pos_sin_i);
                pos_embedding_vec.push(pos_cos_i);
            }
        }
        let pos_embedding = Tensor::from_vec(pos_embedding_vec, (seq_len, embedding_dim), &device)?;
        Ok(Self { pos_embedding, device })
    }
}
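A quick sanity check (a minimal sketch, reusing the Result alias from above): at t = 0 every sine term is 0 and every cosine term is 1, so row 0 of the table must be the interleaved pattern [0, 1, 0, 1, ...].

#[test]
fn pos_embedding_row0_is_sin0_cos1() -> Result<()> {
    // seq_len = 4, embedding_dim = 6 is enough to see the interleaving.
    let pe = PositionEmbedding::new(4, 6, Device::Cpu)?;
    let rows = pe.pos_embedding.to_vec2::<f32>()?;
    // t = 0: sin(0) = 0 and cos(0) = 1 for every frequency.
    assert_eq!(rows[0], vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
    Ok(())
}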
Testing it:
Note: in candle, adding tensors of different ranks with a bare + raises an error; the broadcast add must be called explicitly, and it does not matter which tensor the method is called on.
fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let vocab_size = tokenizer.get_vocab_size(true);
    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;
    let seq_len = 32;
    let dataset = TokenDataset::new(text, tokenizer, seq_len, 16, device.clone())?;
    let batch_size: usize = 6;
    // DataLoader also comes from the earlier notes in this series.
    let mut loader = DataLoader::new(dataset, batch_size, true);
    loader.reset();
    let (x, y) = loader.next().unwrap()?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    let embedding_dim: usize = 256;
    let embedding = embedding(vocab_size, embedding_dim, vb)?;
    let x_embedding = embedding.forward(&x)?;
    let y_embedding = embedding.forward(&y)?;
    println!("x_embedding: {:?}\n", x_embedding);
    println!("y_embedding: {:?}\n", y_embedding);
    let pos_embedding = PositionEmbedding::new(seq_len, embedding_dim, device.clone())?;
    let pos_emb = pos_embedding.pos_embedding;
    // In candle, `+` on tensors of different ranks errors out;
    // the broadcast add has to be called explicitly.
    // Either direction works:
    let x_input = x_embedding.broadcast_add(&pos_emb)?;
    // let x_input = pos_emb.broadcast_add(&x_embedding)?;
    println!("x_input: {:?}\n", x_input);
    Ok(())
}
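Here x_embedding has shape (batch_size, seq_len, embedding_dim) = (6, 32, 256) while pos_emb is (32, 256); broadcast_add stretches the positional table across the batch dimension, so every sequence in the batch receives the same positional offsets.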