Best Practice in PyTorch: 如何控制dataloader的随机shuffle

共 3458字,需浏览 7分钟

 ·

2022-05-20 05:45

↑ 点击蓝字 关注极市平台

作者丨魏鸿鑫@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/515697362
编辑丨极市平台

极市导读

 

在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch。本文作者给出了一个简单方便的实现思路,附详解代码。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

问题背景:

在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch, 或者说,在一个epoch中按照固定的顺序进行多次的样本循环iteration。

现有Sampler:

默认的 RandomSampler 在生成iteration的时候会重新做一次random shuffle,所以无法直接实现这个需求。

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

上面的代码是RandomSampler中最重要的__iter__函数,我们可以看到每次调用这个函数或者新的iter时会得到一个新的随机顺序的iteration。

再看看另一个常用的sampler,也就是 SequentialSampler。我们在test的时候经常会设置shuffle=false,这时候就相当于使用了SequentialSampler:

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    "
""
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

在代码中可以看到,这个sampler就是简单地创造并返回一个range序列,无法对其进行shuffle操作。

解决方案:

结合上面两个现有的sampler,我们可以简单地自定义一个新的sampler来实现我们的需求。也就是说,我们希望能够手动控制何时进行shuffle操作,在没有shuffle时我们希望sampler按照前面的顺序返回iteration。

下面是我的实现:

class MySequentialSampler(SequentialSampler):
    def __init__(self, data_source, num_data=None):
        self.data_source = data_source
        self.my_list = list(range(len(self.data_source)))
        random.shuffle(self.my_list)
        if num_data is None:
            self.num_data = len(self.my_list)
        else:
            self.num_data = num_data
            self.my_list = self.my_list[:num_data]

    def __iter__(self):
        return iter(self.my_list)

    def __len__(self):
        return self.num_data

    def shuffle(self):
        self.my_list = list(range(len(self.data_source)))
        random.shuffle(self.my_list)
        self.my_list = self.my_list[:self.num_data]

这个实现非常简单而且使用方便。在默认情况下基本等同于SequentialSampler (去掉init函数中的shuffle即完全一致)。当我们需要重新shuffle序列的时候,只需要调用shuffle函数即可,比如:dataloader.sampler.shuffle(). 通过这个自定义sampler,我们就可以实现在指定的时候进行shuffle操作,而不是固定在每个iteration结束时进行shuffle。


ps: 理论上也可以直接通过对dataset进行shuffle,但这样操作的缺点是会改变对应的index,另外一般我们在train或者test函数中不会获取到dataset,而只能从loader进行操作(dataloader.dataset一般只能获取到length)。因此,修改sampler可以说是对原训练方法流程最少的方式。

公众号后台回复“目标检测竞赛”获取竞赛经验分享~

△点击卡片关注极市平台,获取最新CV干货
极市干货
数据集资源汇总:90+深度学习开源数据集整理|包括目标检测、工业缺陷、图像分割等多个方向
CVPR 2022:CVPR'22 最新132篇论文分方向整理CVPR'22 最新106篇论文分方向整理一文看尽 CVPR 2022 最新 20 篇 Oral 论文
极市动态:光大环保与极视角正式开启厂区智慧安防项目合作!极视角成为首批「青岛市人工智能产业链链主企业」!
最新竞赛:六大真实场景赛题!ECV2022极市计算机视觉开发者榜单大赛预报名开启

觉得有用麻烦给个在看啦~  
浏览 36
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报