深度学习中Dataset类通用的架构思路
Dataset 类设计的必备部分
1. 初始化 __init__
- 配置和路径管理:保存
config
,区分train/val/test
路径。 - 加载原始数据:CSV、JSON、Numpy、Parquet 等。
- 预处理器/归一化器:如
StandardScaler
,或者 Tokenizer(在 NLP 任务里)。 - 准备辅助信息:比如 meta 特征、文本 embedding。
- 构造样本列表(self.samples):保证后面取样时直接
O(1)
访问。
2. 数据预处理
- normalize / inverse_transform:数值数据标准化和反变换。
- tokenize / pad:文本分词、对齐。
- feature engineering:特征拼接、缺失值处理。
3. 核心接口
__len__
: 返回数据集样本数。__getitem__
: 返回一个样本(通常是(features, label)
的 tuple 或 dict)。
4. 可选接口
get_scaler()
: 返回归一化器。get_vocab()
: NLP 任务里返回词表。collate_fn
: 定义 batch 内如何拼接(特别是变长序列)。save_cache
/load_cache
: 大数据集可以存缓存,避免每次都重新处理。
5. 继承关系
BaseDataset:负责
- 通用逻辑(加载文件、归一化、拼装 sample)。
- 提供钩子函数,比如
load_paths(flag)
、process_sample(sample)
。
子类:只需要实现 路径差异 或 样本加工方式差异。
通用代码结构示意
class BaseDataset(Dataset):
def __init__(self, config, flag="train", scaler=None):
self.config = config
self.flag = flag
self.scaler = scaler or StandardScaler()
self.samples = []
self._load_data()
self._build_samples()
def _load_data(self):
"""子类可重写,加载原始数据"""
raise NotImplementedError
def _build_samples(self):
"""子类可重写,拼装每个样本的x, y, feats"""
raise NotImplementedError
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
def get_scaler(self):
return self.scaler
def inverse_transform(self, x):
return x * self.std + self.mean
子类只管:
class ElectricityDataset(BaseDataset):
def _load_data(self):
# 只写路径和文件加载逻辑
pass
def _build_samples(self):
# 根据任务需要定义样本结构
pass
调用示例
data_config = {
"root": "data/electricity/",
"train_file": "train.json",
"train_meta_file": "train_meta.npy",
"train_news_file": "train_news.npy"
}
train_config = {
"batch_size": 64,
"learning_rate": 1e-3,
"epochs": 20
}
train_ds = ElectricityDataset(data_config, flag="train")
train_loader = DataLoader(
train_ds,
batch_size=train_config["batch_size"],
shuffle=True,
collate_fn=custom_collate_fn
)
)