Best Practice in PyTorch: 如何控制dataloader的随机shuffle
极市导读
在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch。本文作者给出了一个简单方便的实现思路,附详解代码。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
问题背景:
在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch, 或者说,在一个epoch中按照固定的顺序进行多次的样本循环iteration。
现有Sampler:
默认的 RandomSampler 在生成iteration的时候会重新做一次random shuffle,所以无法直接实现这个需求。
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
for _ in range(self.num_samples // n):
yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
上面的代码是RandomSampler中最重要的__iter__函数,我们可以看到每次调用这个函数或者新的iter时会得到一个新的随机顺序的iteration。
再看看另一个常用的sampler,也就是 SequentialSampler。我们在test的时候经常会设置shuffle=false,这时候就相当于使用了SequentialSampler:
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Args:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
在代码中可以看到,这个sampler就是简单地创造并返回一个range序列,无法对其进行shuffle操作。
解决方案:
结合上面两个现有的sampler,我们可以简单地自定义一个新的sampler来实现我们的需求。也就是说,我们希望能够手动控制何时进行shuffle操作,在没有shuffle时我们希望sampler按照前面的顺序返回iteration。
下面是我的实现:
class MySequentialSampler(SequentialSampler):
def __init__(self, data_source, num_data=None):
self.data_source = data_source
self.my_list = list(range(len(self.data_source)))
random.shuffle(self.my_list)
if num_data is None:
self.num_data = len(self.my_list)
else:
self.num_data = num_data
self.my_list = self.my_list[:num_data]
def __iter__(self):
return iter(self.my_list)
def __len__(self):
return self.num_data
def shuffle(self):
self.my_list = list(range(len(self.data_source)))
random.shuffle(self.my_list)
self.my_list = self.my_list[:self.num_data]
这个实现非常简单而且使用方便。在默认情况下基本等同于SequentialSampler (去掉init函数中的shuffle即完全一致)。当我们需要重新shuffle序列的时候,只需要调用shuffle函数即可,比如:dataloader.sampler.shuffle(). 通过这个自定义sampler,我们就可以实现在指定的时候进行shuffle操作,而不是固定在每个iteration结束时进行shuffle。
ps: 理论上也可以直接通过对dataset进行shuffle,但这样操作的缺点是会改变对应的index,另外一般我们在train或者test函数中不会获取到dataset,而只能从loader进行操作(dataloader.dataset一般只能获取到length)。因此,修改sampler可以说是对原训练方法流程最少的方式。
公众号后台回复“目标检测竞赛”获取竞赛经验分享~