人物属性模型移动端实验记录

GiantPandaCV

共 17514字,需浏览 36分钟

 ·

2021-03-18 22:47

【GiantPandaCV导语】最近项目有需求,需要把人物属性用在移动端上,需要输出性别,颜值和年龄三个维度的标签, 用来做数据分析收集使用,对速度和精度有一定的需求,做了一些实验,记录如下。

一、模型

模型结构,这里考虑了两种形式,一种是多头的,一种是单头的,具体如下:

  • SingleHead
    1. backbone+avgpool后面 接一个卷积,卷积核为(inp, (gender_class+beauty_class+age_class), 3, 3)
    2. backbone+avgpool后面 接入一个channel shuff层, 再接入一个卷积,和第一种一样。
  • MutilHead
    1. backbone+avgpool后面,接入三个FC,每个FC对应一个维度的任务。
    2. backbone+avgpool后面,先接入一个SE模块后,接三个FC,每个FC对应一个维度的任务。
    3. backbone+avgpool后面,接入一个512维度的FC,后接入三个FC,每个FC对应一个维度的任务。
    4. backbone+avgpool后面,接入三个512维度的FC来做embeeding,后接入三个FC,每个FC对应一个维度的任务。

如下图所示:

   
   

图1-不同模型结构

训练, 训练数据总计35w,每张图片都带有三个维度的标签,使用Horovod分布式框架进行训练,采用SGD优化器,warmup5个epoch,使用cosine进行衰减学习率,总计训练60个epoch,训练代码可以参考https://github.com/FlyEgle/cub_baseline。

实验对比,对于SingleHead模型,MutilHead的1,2模型,采用的是mobilenetv2作为backbone,对于MutilHead的3,4模型,采用的是mobilenetv2x0.5作为backbone。这里对比的baseline为resnest50的结果,结果如下:

   
   

图2-结果对比

结论,出于性能和速度的考虑,确定了以mbv2x0.5作为backbone,模型结构为mutilhead-4的模型。

模型SIZEFLOPsPARAMsgender_accbeauty_accage_acc
baseline(rs50)2565.7G31M0.9709821430.8973214290.790178571
mbv2x0.5(mutil_head)256127M2.66M0.9040178570.8348214290.725446429

二、蒸馏

mobilenetv2与resnest50在imagenet上的baseline大概相差8个点左右,所以我们自身的实验跑出来的结果也是在合理的范围内。为了进一步提升小模型的精度,选择用resnest50的模型来蒸馏mbv2x0.5的模型(ps:这里尝试过训练一个mbv2x2的模型,不过没有训的比resnest50高,所以还是使用resnest50)。蒸馏,采用的是传统的蒸馏方法,KL散度来作为损失,由于head相同,所以只需要考虑对logits蒸馏即可,KL散度代码如下:

class KLSoftLoss(nn.Module):
    r"""Apply softtarget for kl loss

    Arguments:
        reduction (str): "
batchmean" for the mean loss with the p(x)*(log(p(x)) - log(q(x)))
    "
""
    def __init__(self, temperature=1, reduction="batchmean"):
        super(KLSoftLoss, self).__init__()
        self.reduction = reduction
        self.eps = 1e-7
        self.temperature = temperature
        self.klloss = nn.KLDivLoss(reduction=self.reduction)

    def forward(self, s_logits, t_logits):
        s_prob = F.log_softmax(s_logits / self.temperature, 1)
        t_prob = F.softmax(t_logits / self.temperature, 1)
        loss = self.klloss(s_prob, t_prob) * self.temperature * self.temperature
        return loss

训练, 对于分类的问题,一般情况只是蒸馏输出的logits即可,由于多任务有多个head,所以会有多个logits,分别蒸馏即可,整体框架如下:

   
   

图4-蒸馏训练框架

蒸馏训练代码如下,由于学生和教师的网络差异性较大同时精度相差甚远,所以采用1:1的比例来进行训练,蒸馏的温度为25(T=5):

   
   

图5-蒸馏训练代码

结论,采用了3中不同的分辨率进行蒸馏实验,其中训练的size为224,推理为256的时候效果最好。

模型sizeteachergender_accbeauty_accage_acc
mbv2x0.5224->256resnest500.9665178570.895089290.75446429
mbv2x0.5192->224resnest500.9508928570.897321430.765625
mbv2x0.5160->224resnest500.9598214290.8906250.734375

三、剪枝

Slimming Prune,实验采用的剪枝方法是来自于Learning Efficient Convolutional Networks through Network Slimming,通过对BN的channel进行稀疏化来达到剪枝的效果(个人喜欢用比较简单稳定的方法,便于debug和修改)。

   
   

图5-Slimming

训练和剪枝

  • 训练,训练代码很简单,只需要再更新权重之前进行稀疏化处理即可,sr是超参,一般设置为1e-4,代码如下:

      optimizer.zero_grad()
      loss.backward()

      # use the slimming prune for training
      if args.prune and args.use_sr:
          for m in model.modules():
              if isinstance(m, nn.BatchNorm2d):
                  m.weight.grad.data.add_(args.sr * torch.sign(m.weight.data))

      optimizer.step()
  • 剪枝, 由于模型结构是mobilenetv2的结构,有DW存在,所以,在剪枝的时候需要注意groups的数量和channel需要保持一致,同时,为了方便移动端优化加速,要保证channel是8的倍数,剪枝代码逻辑如下:

    1. 先设置一定的剪枝比例p,如0.1,0.2,0.3...,按BN的channel总数从小到大来进行过滤。
    2. 保留最大比例的最小阈值,防止prune过大,导致模型崩溃。
    3. 对于不满足8的倍数的channel,按8的倍数补齐,补齐的方法是对prune过的channel排序,从大到小按差值补齐。
    4. 保存除了第一个InvertedResidual模块以外的所有模块剪枝后的channel数量,重构模型。
    5. 测试结果,考虑是否进行finetune训练。

剪枝部分代码如下:

def prune_only_res_hidden(percent, model, keep_channel=True, channel_ratio=8, cuda=True):
    """only prune the inverResidual module first bn layer
    "
""
    total = 0
    highest_thre = []
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            # only prune the 3 conv layer
            if len(m.conv) > 5:
                for i in range(len(m.conv)):
                    if i == 1:
                        if isinstance(m.conv[i], nn.BatchNorm2d):
                            total += m.conv[i].weight.data.shape[0]
                            highest_thre.append(m.conv[i].weight.data.abs().max().item())
                            total += m.conv[i+3].weight.data.shape[0]
                            highest_thre.append(m.conv[i+3].weight.data.abs().max().item())



    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            # only prune the 3 conv layer
            if len(m.conv) > 5:
                for i in range(len(m.conv)):
                    if i != len(m.conv) - 1:
                        if isinstance(m.conv[i], nn.BatchNorm2d):
                            size = m.conv[i].weight.data.shape[0]
                            bn[index:(index+size)] = m.conv[i].weight.data.abs().clone()
                            index += size

    print(bn.size())
    y, i = torch.sort(bn)
    thre_index = int(total * percent)
    thre = y[thre_index]
    highest_thre = min(highest_thre)

    # 判断阈值
    if thre > highest_thre:
        thre = highest_thre

    print("the min thre is {}, the max thre is {}!!!!".format(thre, highest_thre))
    pruned = 0
    c = {}
    cfg_mask = []
    idx = 0

    for m in model.modules():
        if isinstance(m, InvertedResidual):
            # only prune the 3 conv layer
            if len(m.conv) > 5:
                for i in range(len(m.conv)):
                    if i == 1:
                        if isinstance(m.conv[i], nn.BatchNorm2d):
                            weight_copy = m.conv[i].weight.data.clone()
                            if cuda:
                                mask = weight_copy.abs().gt(thre).float().cuda()
                            else:
                                mask = weight_copy.abs().gt(thre).float()

                            if keep_channel:
                                keep_channel_number = get_min_number(torch.sum(mask), channel_ratio)
                                if torch.sum(mask) < keep_channel_number:
                                    n = int(keep_channel_number - torch.sum(mask))
                                    mask_index = torch.where(mask==0)[0]
                                    new_weight = weight_copy.abs()[mask_index]
                                    _, weight_index = torch.sort(new_weight)
                                    w_index = mask_index[weight_index[-n: ]]
                                    mask[w_index] = 1.0

                            pruned = pruned + mask.shape[0] - torch.sum(mask)
                            # first conv + bn
                            m.conv[i].weight.data.mul_(mask)
                            m.conv[i].bias.data.mul_(mask)
                            # second conv + bn
                            m.conv[i+3].weight.data.mul_(mask)
                            m.conv[i+3].bias.data.mul_(mask)
                            c[idx] = int(torch.sum(mask))
                            cfg_mask.append(mask.clone())

                            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(idx, mask.shape[0], int(torch.sum(mask))))
                            idx += 1
    print(c)
    print(len(c))
    print(len(cfg_mask))
    # pruned_ratio = pruned / total
    print('Pre-processing Successful!!!')
    return model, cfg_mask, c

直接保存模型后测试,对比结果如下:

模型ratioFLOPsParamsgenderbeautyage
mobilenetv2x0.50.24111.95M2.54M0.9575892860.8928571430.741071429
mobilenetv2x0.50.3107.51M2.51M0.9598214290.8928571430.741071429
mobilenetv2x0.50.479.57M2.46M0.6093750.5334821430.098214286
mobilenetv2x0.50.579.56M2.46M0.6093750.5334821430.098214286

再次使用resnest50进行蒸馏后,对比结果如下:

模型ratioFLOPsParamsgenderbeautyage
mobilenetv2x0.50.24111.95M2.54M0.968750.9017857140.75
mobilenetv2x0.50.3107.51M2.51M0.9575892860.8839285710.738839286
mobilenetv2x0.50.479.57M2.46M0.9575892860.8816964290.741071429

添加2w标注的业务数据,总计训练数据37w,蒸馏后的结果如下:

模型ratioFLOPsParamsgenderbeautyage
mobilenetv2x0.50.24111.95M2.54M0.9754464290.9017857140.752232143
mobilenetv2x0.50.3107.51M2.51M0.968750.8906250.761160714
mobilenetv2x0.50.479.57M2.46M0.9665178570.8928571430.723214286

针对性能的需求,考虑用0.3的版本,如果速度要求更快的话,考虑0.4的版本。

四、TODO

  1. 训练一个基于BYOL的pretrain模型。
  2. 把没有标注的数据,用模型打上伪标签后参与训练。
  3. 训练一个更大的teacher模型。
  4. 使用百度的JSDivLoss作为蒸馏损失。

五、结论

  • 对于移动端的任务来说,蒸馏和剪枝是必不可少的,尤其是要去训练一个比较好的teacher,这里的teacher可以同结构也可以异结构,只要最后logits一致即可。
  • 由于移动端会根据X8或者X4的倍数优化,所以剪枝的时候尽量保持channel的倍数,建议常备一种便于修改的剪枝代码。
  • 小模型具备成长为大模型的潜质,只要训练方法适当。

结束语

本人才疏学浅,以上都是自己在做项目中的一些方法和实验,以及一些粗浅的思考,并不一定完全正确,只是个人的理解,欢迎大家指正,留言评论。

参考文献

  • mobilenetv2 https://export.arxiv.org/pdf/1801.04381
  • resnest https://export.arxiv.org/pdf/2004.08955
  • Slimming prune https://arxiv.org/pdf/1708.06519.pdf

欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧

有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:

二维码

为了方便读者获取资料以及我们公众号的作者发布一些Github工程的更新,我们成立了一个QQ群,二维码如下,感兴趣可以加入。

公众号QQ交流群


浏览 49
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报