数据增强(上):我真的分不清AutoAugment和RandAugment!
点蓝色字关注“机器学习算法工程师”
设为星标,干货直达!
一个模型的性能除了和网络结构本身有关,还非常依赖具体的训练策略,比如优化器,数据增强以及正则化策略等(当然也很训练数据强相关,训练数据量往往决定模型性能的上线)。近年来,图像分类模型在ImageNet数据集的top1 acc已经由原来的56.5(AlexNet,2012)提升至90.88(CoAtNet,2021,用了额外的数据集JFT-3B),这进步除了主要归功于模型,算力和数据的提升,也与训练策略的提升紧密相关。最近刚兴起的vision transformer相比CNN模型往往也需要更heavy的数据增强和正则化策略。这里简单介绍图像分类领域比较常用的训练技巧中数据增强。
数据增强
baseline
ImageNet数据集训练常用的数据增强策略如下,训练过程的数据增强包括随机缩放裁剪(RandomResizedCrop,这种处理方式源自谷歌的Inception,所以称为 Inception-style pre-processing)和水平翻转(RandomHorizontalFlip),而测试阶段是执行缩放和中心裁剪。这其实是一种轻量级的策略,这里称之为baseline。torchvision的实现的ResNet50训练采用的策略就是这个,在ImageNet上的top1 acc可以达到76.1。
from torchvision import transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
# 训练
train_transform = transforms.Compose([
# 这里的scale指的是面积,ratio是宽高比
# 具体实现每次先随机确定scale和ratio,可以生成w和h,然后随机确定裁剪位置进行crop
# 最后是resize到target size
transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
# 测试
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
AutoAugment
谷歌在2018年提出通过AutoML来自动搜索数据增强策略,称之为AutoAugment(算是自动数据增强开山之作)。搜索方法采用强化学习,和NAS类似,只不过搜索空间是数据增强策略,而不是网络架构。在搜索空间里,一个policy包含5个sub-policies,每个sub-policy包含两个串行的图像增强操作,每个增强操作有两个超参数:进行该操作的概率和图像增强的幅度(magnitude,这个表示数据增强的强度,比如对于旋转,旋转的角度就是增强幅度,旋转角度越大,增强越大)。每个policy在执行时,首先随机从5个策略中随机选择一个sub-policy,然后序列执行两个图像操作。搜索空间一共有16种图像增强类型,具体如下所示,大部分操作都定义了图像增强的幅度范围,在搜索时需要将幅度值离散化,具体地是将幅度值在定义范围内均匀地取10个值。论文在不同的数据集上( CIFAR-10 , SVHN, ImageNet)做了实验,这里给出在ImageNet数据集上搜索得到的最优policy(最后实际上是将搜索得到的前5个最好的policies合成了一个policy,所以这里包含25个sub-policies):
# operation, probability, magnitude
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None))
基于搜索得到的AutoAugment训练可以将ResNet50在ImageNet数据集上的top1 acc从76.3提升至77.6。一个比较重要的问题,这些从某一个数据集搜索得到的策略是否只对固定的数据集有效,论文也通过具体实验证明了AutoAugment的迁移能力,比如将ImageNet数据集上得到的策略用在5个 FGVC数据集(与ImageNet图像输入大小相似)也均有提升。
目前torchvision库已经实现了AutoAugment,具体使用如下所示(注意AutoAug前也需要包括一个RandomResizedCrop):
from torchvision.transforms import autoaugment, transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
transforms.RandomHorizontalFlip(hflip_prob),
# 这里policy属于torchvision.transforms.autoaugment.AutoAugmentPolicy,
# 对于ImageNet就是 AutoAugmentPolicy.IMAGENET
# 此时aa_policy = autoaugment.AutoAugmentPolicy('imagenet')
autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std)
])
RandAugment
AutoAugment存在的一个问题是搜索空间巨大,这使得搜索只能在代理任务中进行:使用小的模型在ImageNet的一个小的子集( 120类和6000图片)搜索。谷歌在2019年又提出了一个更简单的数据增强策略:RandAugment。这篇论文首先发现AutoAugment这样在小数据集上搜索出来的策略在大的数据集上应用会存在问题,这主要是因为数据增强策略和模型大小和数据量大小存在强相关,如下图所示可以看到模型或者训练数据量越大,其最优的数据增强的幅度越大,这说明AutoAugment得到的结果应该是次优的。另外,Population Based Augmentation这篇论文发现最优的数据增强幅度是随训练过程增加,而且不同的增强操作遵循类似的规律,这启发作者采用固定的增强幅度而不是去搜索。RandAugment相比AutoAugment的策略空间很小( vs ),所以它不需要采用代理任务,甚至直接采用简单的网格搜索。具体地,RandAugment共包含两个超参数:图像增强操作的数量N和一个全局的增强幅度M,其实现代码如下所示,每次从候选操作集合(共14种策略)随机选择N个操作(等概率),然后串行执行(这里没有判断概率,是一定执行)。这里的M取值范围为{0, . . . , 30}(每个图像增强操作归一化到同样的幅度范围),而N取值范围一般为 {1, 2, 3}。
# Identity是恒等变换,不做任何增强
transforms = ['Identity', 'AutoContrast', 'Equalize', 'Rotate', 'Solarize', 'Color', 'Posterize',
'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY']
def randaugment(N, M):
"""Generate a set of distortions.
Args:
N: Number of augmentation transformations to
apply sequentially.
M: Magnitude for all the transformations.
"""
sampled_ops = np.random.choice(transforms, N)
return [(op, M) for op in sampled_ops]
对于ResNet50,其搜索得到的N=2,M=9,RandAugment相比AutoAugment可以在ImageNet得到相似的效果(77.6),不过DeiT中发现使用RandAugment效果更好一些( DeiT-B:81.8 vs 81.2)。目前torchvision库也已经实现了RandAugment,具体使用如下所示:
from torchvision.transforms import autoaugment, transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
transforms.RandomHorizontalFlip(hflip_prob),
autoaugment.RandAugment(interpolation=interpolation),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std)
])
TrivialAugment
虽然RandAugment的搜索空间极小,但是对于不同的数据集还是需要确定最优的N和M,这依然有较大的实验成本。RandAugment后,华为提出了UniformAugment,这种策略不需要搜索也能取得较好的结果。不过这里我们介绍一项更新的工作:TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation。TrivialAugment也不需要任何搜索,整个方法非常简单:每次随机选择一个图像增强操作,然后随机确定它的增强幅度,并对图像进行增强。由于没有任何超参数,所以不需要任何搜索。从实验结果上看,TA可以在多个数据集上取得更好的结果,如在ImageNet数据集上,ResNet50的top1 acc可以达到78.1,超过RandAugment。TrivialAugment的图像增强集合和RandAugment基本一样,不过TA也定义了一套更宽的增强幅度,目前torchvision中已经实现了TrivialAugmentWide,具体使用代码如下所示:
from torchvision.transforms import autoaugment, transforms
augmentation_space = {
# op_name: (magnitudes, signed)
"Identity": (torch.tensor(0.0), False),
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
"Color": (torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
"Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
}
train_transform = transforms.Compose([
transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
transforms.RandomHorizontalFlip(hflip_prob),
autoaugment.TrivialAugmentWide(interpolation=interpolation),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std)
])
小结
这里简单介绍了几种常用且有效的数据增强策略,包括AutoAugment和RandAugment等等,下期我们会讲述其它的策略。
参考
Training data-efficient image transformers & distillation through attention AutoAugment: Learning Augmentation Policies from Data RandAugment: Practical automated data augmentation with a reduced search space TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation
推荐阅读
PyTorch1.10发布:ZeroRedundancyOptimizer和Join
谷歌AI用30亿数据训练了一个20亿参数Vision Transformer模型,在ImageNet上达到新的SOTA!
"未来"的经典之作ViT:transformer is all you need!
PVT:可用于密集任务backbone的金字塔视觉transformer!
涨点神器FixRes:两次超越ImageNet数据集上的SOTA
不妨试试MoCo,来替换ImageNet上pretrain模型!
机器学习算法工程师
一个用心的公众号