参考:about-pytorch, about-tokenizers
从魔搭社区(ModelScope)下载 Qwen3 的 tokenizer.json 文件,保存为 assets/qwen3/tokenizer.json(下文代码从该路径加载)。
添加依赖库(后文的 Dataset / DataLoader 部分还需要 candle-core 和 rand,其中 rand::rng() 需要 rand 0.9 及以上版本):
cargo add tokenizers
tokenizers库初体验:
use tokenizers::tokenizer::{self, Result, Tokenizer};
/// Tokenizer warm-up: encode a sample sentence, print its tokens and ids,
/// then decode the ids back into text.
fn main() -> Result<()> {
    // Load the Qwen3 tokenizer definition from disk.
    let tok = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let sample = "Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace.";

    // Encode with `add_special_tokens = false`, then show the tokens and their ids.
    let enc = tok.encode(sample, false)?;
    println!("{:?}\n", enc.get_tokens());
    let token_ids = enc.get_ids();
    println!("{:?}\n", token_ids);

    // Round-trip: decode with `skip_special_tokens = false`.
    let round_trip = tok.decode(token_ids, false)?;
    println!("{:?}\n", round_trip);

    Ok(())
}
定义一个 Dataset trait,包含数据集常用的方法:
/// A collection of `(input, target)` tensor pairs that can be read in contiguous batches.
trait Dataset {
    /// Number of samples (rows) in the dataset.
    fn len(&self) -> usize;

    /// Returns the input and target rows in the half-open range `start..end`.
    fn get_batch(&self, start: usize, end: usize) -> Result<(Tensor, Tensor)>;

    /// Reorders the samples in place, keeping each input paired with its target.
    fn shuffle(&mut self) -> Result<()>;
}
定义 TokenDataset 结构体:
/// Sliding-window token dataset for next-token prediction.
///
/// Row `i` of `target_ids` is row `i` of `inputs_ids` shifted forward by one token.
struct TokenDataset {
    /// Input token ids, one sample per row (`(num_samples, max_length)` as built by `new`).
    inputs_ids: Tensor,
    /// Target token ids, same shape as `inputs_ids`.
    target_ids: Tensor,
    /// Device holding the tensors; also used to build the shuffle index tensor.
    device: Device,
}
为TokenDataset实现Dataset的trait:
impl Dataset for TokenDataset {
    /// Returns rows `start..end` of the inputs and of the targets.
    fn get_batch(&self, start: usize, end: usize) -> Result<(Tensor, Tensor)> {
        let inputs = self.inputs_ids.i((start..end, ..))?;
        let targets = self.target_ids.i((start..end, ..))?;
        Ok((inputs, targets))
    }

    /// Number of samples, i.e. the size of the first dimension of `inputs_ids`.
    fn len(&self) -> usize {
        self.inputs_ids.shape().dims()[0]
    }

    /// Randomly permutes the rows, applying the same permutation to inputs and targets.
    ///
    /// Both reordered tensors are built before either field is replaced. If
    /// `index_select` fails, the dataset is left unchanged rather than half-shuffled
    /// with inputs and targets misaligned.
    fn shuffle(&mut self) -> Result<()> {
        let n = self.len();
        // The index tensor is u32; fail loudly instead of truncating with `as`.
        let upper = u32::try_from(n)?;
        let mut indices: Vec<u32> = (0..upper).collect();
        indices.shuffle(&mut rand::rng());

        let idx_tensor = Tensor::from_vec(indices, (n,), &self.device)?;
        let shuffled_inputs = self.inputs_ids.index_select(&idx_tensor, 0)?;
        let shuffled_targets = self.target_ids.index_select(&idx_tensor, 0)?;

        self.inputs_ids = shuffled_inputs;
        self.target_ids = shuffled_targets;
        Ok(())
    }
}
为TokenDataset定义new方法:
impl TokenDataset {
fn new(
txt: String,
tokenizer: Tokenizer,
max_length: usize,
stride: usize,
device: Device
) -> Result<Self> {
let tokens = tokenizer.encode(txt, true)?;
let tokens_id = tokens.get_ids();
let token_len = tokens_id.len();
if token_len <= max_length {
return Err(Box::new(candle_core::Error::msg("Text is too short for given max_length")));
}
let max_start_index = token_len - max_length;
let mut inputs_ids_vec: Vec<u32> = Vec::with_capacity(max_start_index * max_length);
let mut target_ids_vec: Vec<u32> = Vec::with_capacity(max_start_index * max_length);
for i in (0..max_start_index).step_by(stride) {
inputs_ids_vec.extend_from_slice(&tokens_id[i..i+max_length]);
target_ids_vec.extend_from_slice(&tokens_id[i+1..i+1+max_length]);
}
let total_samples = inputs_ids_vec.len() / max_length;
let inputs_ids = Tensor::from_vec(inputs_ids_vec, (total_samples, max_length), &device)?;
let target_ids = Tensor::from_vec(target_ids_vec, (total_samples, max_length), &device)?;
Ok(Self { inputs_ids, target_ids, device })
}
fn get_item(&self, idx: usize) -> Result<(Tensor, Tensor)>{
Ok((self.inputs_ids.i((idx, ..))?, self.target_ids.i((idx, ..))?))
}
}
定义 DataLoader:任何实现了 Dataset trait 的结构体都可以用它按批次加载。
/// Iterates over any [`Dataset`] in fixed-size batches, optionally reshuffling on reset.
struct DataLoader<'a> {
    /// The wrapped dataset (type-erased so any `Dataset` implementor can be loaded).
    dataset: Box<dyn Dataset + 'a>,
    /// Maximum number of rows per batch; the final batch may be smaller.
    batch_size: usize,
    /// Whether `reset` should reshuffle the dataset.
    shuffle: bool,
    /// Index of the next batch to yield.
    current_index: usize,
}
为 DataLoader 实现常用方法:
impl<'a> DataLoader<'a> {
pub fn new<D: Dataset + 'a>(dataset: D, batch_size: usize, shuffle: bool) -> Self {
Self {
dataset: Box::new(dataset),
batch_size,
shuffle,
current_index: 0,
}
}
pub fn reset(&mut self) {
self.current_index = 0;
if self.shuffle {
let _ = self.dataset.shuffle();
}
}
}
为 DataLoader 实现 Iterator trait:
impl<'a> Iterator for DataLoader<'a> {
    type Item = Result<(Tensor, Tensor)>;

    /// Yields the next `(inputs, targets)` batch, or `None` once the dataset is exhausted.
    ///
    /// The last batch may hold fewer than `batch_size` rows. If the dataset fails to
    /// produce a batch, the error is yielded as `Some(Err(..))`; it no longer silently
    /// ends iteration. Iteration does not rewind on its own; call `reset` to start over.
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.current_index * self.batch_size;
        let end = (start + self.batch_size).min(self.dataset.len());
        if start >= end {
            return None;
        }
        self.current_index += 1;
        Some(self.dataset.get_batch(start, end))
    }
}
测试 DataLoader:
use tokenizers::tokenizer::{self, Result, Tokenizer};
#[allow(unused)]
mod learn_tokenizer;
use learn_tokenizer::read_txt;
use candle_core::{Device, Tensor, IndexOp};
use rand::seq::SliceRandom;
/// End-to-end check of the data pipeline.
///
/// Builds a `TokenDataset` from a text file, prints its first sample and its length,
/// then walks every shuffled batch produced by a `DataLoader`.
fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;

    // Windows of 512 tokens, starting every 256 tokens.
    let dataset = TokenDataset::new(text, tokenizer, 512, 256, device.clone())?;
    let (inputs, targets) = dataset.get_item(0)?;
    println!("{:?}\n", inputs);
    println!("{:?}\n", targets);
    let len = dataset.len();
    println!("{:?}", len);

    let mut loader = DataLoader::new(dataset, 6, true);
    loader.reset();
    for batch in &mut loader {
        // Propagate batch errors through `main`'s Result instead of panicking.
        let (x, y) = batch?;
        println!("input: {:?}", x);
        println!("target: {:?}", y);
    }
    Ok(())
}