DataLoader中sampler参数介绍
Sampler 决定了 Dataset 的采样顺序。
DataLoader | Sampler | DataSet 关系
Sampler: 提供数据集中元素的索引DataSet: 根据Sampler提供的索引来检索数据DataLoader: 批量加载数据用于后续的训练和测试
Sampler
1 | class Sampler(object): |
PyTorch官网已经实现了多种 Sampler :
SequentialSampler
若
shuffle=False,且未指定sampler,默认使用
1 | class SequentialSampler(Sampler): |
RandomSampler
若
shuffle=True,且未指定sampler,默认使用
1 | class RandomSampler(Sampler): |
BatchSampler
like
sampler, but returns a batch of indices at a time. Mutually exclusive withbatch_size,shuffle,sampler, anddrop_last
- 在
DataLoader中设置batch_sampler=batch_sampler的时候,上面四个参数都必须是默认值。也很好理解,每次采样返回一个batch,那么batch_size肯定为1
1 | class BatchSampler(Sampler): |
- 可以看到在构造
BatchSampler实例的时候,需要传入一个sampler作为实参
最佳实践
最近看到一篇推文,分享了一个使模型训练速度提升20%的Trick—BlockShuffle 。fork了原作者的代码,并自定义了 batch_sampler ,源码见:TransformersWsz/BlockShuffleTest