DataLoader中sampler参数介绍
Sampler
决定了 Dataset
的采样顺序。
DataLoader
| Sampler
| DataSet
关系
Sampler
: 提供数据集中元素的索引DataSet
: 根据Sampler
提供的索引来检索数据DataLoader
: 批量加载数据用于后续的训练和测试
Sampler
1 | class Sampler(object): |
PyTorch官网已经实现了多种 Sampler
:
SequentialSampler
若
shuffle=False
,且未指定sampler
,默认使用
1 | class SequentialSampler(Sampler): |
RandomSampler
若
shuffle=True
,且未指定sampler
,默认使用
1 | class RandomSampler(Sampler): |
BatchSampler
like
sampler
, but returns a batch of indices at a time. Mutually exclusive withbatch_size
,shuffle
,sampler
, anddrop_last
- 在
DataLoader
中设置batch_sampler=batch_sampler
的时候,上面四个参数都必须是默认值。也很好理解,每次采样返回一个batch,那么batch_size
肯定为1
1 | class BatchSampler(Sampler): |
- 可以看到在构造
BatchSampler
实例的时候,需要传入一个sampler作为实参
最佳实践
最近看到一篇推文,分享了一个使模型训练速度提升20%的Trick—BlockShuffle 。fork了原作者的代码,并自定义了 batch_sampler
,源码见:TransformersWsz/BlockShuffleTest