收藏 | 盘点语义分割小技巧

共 18696字,需浏览 38分钟

 ·

2024-07-31 10:22

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

仅作学术分享,不代表本公众号立场,侵权联系删除
转载于:作者丨博学的咖喱酱@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/599936122
编辑丨极市平台

语义分割技巧

选择模型

  • 使用小尺寸图片选择模型,因为小尺寸图片可以有大的batchsize,使得选择模型过程消耗的时间较少。而大尺寸图片和小尺寸图片训练相同的次数,loss下降的效果几乎相同。小尺寸约256x256.
  • 先确定合适的参数,如学习率、损失函数参数等,参数一样的情况下,训练不同的模型。好的模型前期会下降较快,在小尺寸图片下预训练尤其有效。大尺寸图片训练反而下降不快,看出不下降趋势。
  • 确定模型参数量,对于有些网络,模型参数量太大会导致loss曲线奇异。这一般是由于下采样过多、或attention通道数过多造成的,可以减少模型的参数量,确定合适的模型参数量。

训练模型

  • 使用粗略的损失函数参数训练到底后,可以稍微改变损失函数参数,如改变bwloss和dloss的比例,对于测试集精度可能会有提升。
  • 对于语义分割,focalloss等损失函数属于辅助损失函数,在训练前期起到加速训练的效果,在总损失函数中占比例较低,一般低于0.2的比例,约0.1。diceloss是主要损失函数,对于小前景分割有效果,能直接提升测试集精度,一般最后精训练只用diceloss损失,预训练是focalloss+diceloss。
  • 先使用小尺寸图片(约256x256)预训练,再使用中尺寸图片(约512x512)训练,最后使用大尺寸图片(约1024x1024)精训练。小尺寸图片可以有很大的batchsize,且使用小尺寸图片训练相同的次数收敛更快,估计小尺寸图片收敛速度为中尺寸图片的两倍,大尺寸最慢,大尺寸图片通常batchsize只能取1,无法直接训练,且收敛效果很差,一般用于精训练。(使用不同尺寸图片训练是语义分割任务的优点,像目标检测等任务一般很难这样做,所以输入尺寸只能很小)

模型改进

  • maxpool可以增强抗噪能力。avgpool或convstridepool可以保留位置信息。
  • 运用跨层连接结构可以不影响模型容量下加速训练。这些结构通常被封装成一个模块。现有大量的模型都是使用了优美的跨层连接结构,训练更容易,如DFANet。
  • 有时瓶颈结构会有更好的效果,猜测可能起到编码器-解码器的作用,起到降噪的作用。如卷积从inch->outch//2;outch//2->outch//2;outch//2->outch这样,其中outch//2可能小于inch,但这样同样有效。
  • 使用大量的分离卷积,模型参数会更小,容量会很大。分离卷积效果比非对称卷积、大卷积要好,非对称卷积一般是用在很小尺寸(约32x32)的经过下采样的图片上,而大卷积一般是用在网络一开始输入的地方。对于unet的结构,网络深度不是很深,因此使用大卷积可能较好。对于fcn式,网络可以堆叠很深,因此使用分离卷积效果较好。
  • 在语义分割中,attention的本质是降噪。让不需要的信息(噪声)置为零,跟relu的作用一样,所以relu在语义分割中尤为有效。使用attention模块能更好的开发一个模型的潜能,能更快收敛,可以用在模块block的后面,跟BN一样。而且用很多也不会有副作用。attention模块主要有通道型和空间型两种,通道型的attention可能会导致loss训练很难(异常),带空间型的可以缓解这种现象,使得训练容易。常用的有SE-module或CBAM模块。
  • 在编码器-解码器的网络中,主要工作都是由编码器做的,编码器参数要很大,相反,解码器要尽量简单,参数量要很小,如通常的无偏置1x1卷积来聚合编码器信息,或使用无参数的双线性插值来放大图片。使用反卷积等技术上采样反而训练更困难,因为其加大了解码器的参数量。

理解

  • 神经网络可以看成是降噪,即不断精炼信息的过程,这意味着需要不断去除冗余信息或噪声。attention、bn、relu等手段一定程度上可以看成是去除噪声的手段,因此bottleneck结构有效,因为它在瓶颈处精炼出有用信息然后处理这些有用信息。3x3卷积的作用是处理信息,那么1x1卷积的作用就是降维精炼信息,因此1x1升维=>3x3处理=>1x1精炼也是一种有效结构。
  • ottleneck结构本质是降噪。bottleneck结构一般用在网络前面较好,且bottle处卷积参数初始化为正,使用torch.nn.init.uniform_(conv.weight, a=0, b=1)初始化权重和nn.init.zeros_(conv.bias)置偏置为零。原因为在瓶颈处降噪的效果好。
  • 1x1卷积的主要作用是复制图片(增加通道数)给3x3卷积处理,因此增加通道数可以使用1x1卷积而不是3x3卷积,3x3卷积用于处理图片,只需要分离卷积即可。1x1卷积还可以将多张图片(多通道)合并成更少的图片(降通道),即将处理的结果组合起来。
########## DFANET模型定义 #############

##################################### DFANET模型定义 ####################################

# 纯粹的卷积,如Conv2d
class SeparableConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, bottle=False):
        super(SeparableConv2d, self).__init__()
        self.sconv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=5, padding=2, groups=in_ch, stride=stride),
            Attention(in_ch),
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        )
        if bottle:
            torch.nn.init.uniform_(self.sconv[0].weight, a=0, b=1)
            torch.nn.init.uniform_(self.sconv[2].weight, a=0, b=1)
            nn.init.zeros_(self.sconv[0].bias)

    def forward(self, x):
        return self.sconv(x)


# bottleneck结构
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, reduction = 2, bottle=False):
        super(Block, self).__init__()
        self.block = nn.Sequential(
            SeparableConv2d(in_ch, out_ch//reduction, stride=stride),
            nn.BatchNorm2d(out_ch//reduction),
            nn.ReLU(inplace=True),
            SeparableConv2d(out_ch//reduction, out_ch//reduction, bottle=bottle),
            nn.BatchNorm2d(out_ch//reduction),
            nn.ReLU(inplace=True),
            SeparableConv2d(out_ch//reduction, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.proj = nn.Sequential(
            nn.MaxPool2d(stride, stride=stride),
            nn.Conv2d(in_ch, out_ch, 1, bias=False)
        )
    def forward(self, x):
        out = self.block(x)
        identity = self.proj(x)
        return out + identity


class enc(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(enc, self).__init__()
        self.encblocks = nn.Sequential(
            Block(in_ch, out_ch, stride=2),
            Block(out_ch, out_ch),
            Block(out_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            Attention(out_ch),
        )

    def forward(self, x):
        return self.encblocks(x)


class Attention(nn.Module):
    def __init__(self, channels):
        super(Attention, self).__init__()
        mid_channels = int((channels/2)**0.5) # 1 4 9
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(channels, mid_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, channels, kernel_size=1),
        )

    def forward(self, x):
        avg = self.sharedMLP(self.avg_pool(x))
        x = x * torch.sigmoid(avg)
        return x


class SubBranch(nn.Module):
    def __init__(self, in_ch, chs, branch_index):
        super(SubBranch,self).__init__()
        self.encs = nn.ModuleList()
        current_ch = in_ch
        for i, ch in enumerate(chs):
            self.encs.append(enc(current_ch, ch))
            if branch_index != 0 and i < len(chs)-2:
                current_ch = chs[i]+chs[i+1]
            else:
                current_ch = ch
        self.branch_index = branch_index

    def forward(self, x0, *args):
        retlist = []
        if self.branch_index == 0:
            for i, enc in enumerate(self.encs):
                retlist.append(enc(x0 if i==0 else retlist[-1]))
        else:
            for i, enc in enumerate(self.encs):
                if i == 0:
                    retlist.append(enc(x0))
                elif i-1 < len(args):
                    retlist.append(enc(torch.cat([retlist[-1], args[i-1]], 1)))
                else:
                    retlist.append(enc(retlist[-1]))

        return retlist


class DFA_Encoder(nn.Module):
    def __init__(self, chs, m): # m个subbranch
        super(DFA_Encoder,self).__init__()
        self.branchs = nn.ModuleList()
        for i in range(m):
            current_inch = chs[0] if i==0 else chs[1]+chs[-1]
            self.branchs.append(SubBranch(current_inch, chs[1:], branch_index=i))
        self.n = len(chs) - 1
        self.m = m 

    def forward(self, x):
        lowfeatures = [None]*self.m
        highfeatures = [None]*self.m
        lastvariables = [None]*(self.n-2)# -2
        lasthighfeaturelist = []
        for i, branch in enumerate(self.branchs):
            tem = torch.cat([x if i==0 else lowfeatures[i-1]]+lasthighfeaturelist, 1)
            lowfeatures[i], *lastvariables, highfeatures[i] = branch(tem, *lastvariables)
            if i != self.m-1:
                lasthighfeaturelist = [F.interpolate(highfeatures[i], size=lowfeatures[i].shape[2:], mode='bilinear', align_corners=True)]
        return lowfeatures, highfeatures

class DFA_Decoder(nn.Module):
    def __init__(self, chs, out_ch, m):
        super(DFA_Decoder,self).__init__()
        self.lowconv = nn.Sequential(
            nn.Conv2d(chs[1], chs[0], kernel_size=1, bias=False),
            nn.BatchNorm2d(chs[0]),
        )
        self.highconv = nn.Sequential(
            nn.Conv2d(chs[-1], chs[0], kernel_size=1, bias=False),
            nn.BatchNorm2d(chs[0]),
        )

        self.shuffleconv = nn.Conv2d(chs[0], out_ch, kernel_size=1, bias=True)
        self.m = m

    def forward(self, lows, highs, proj):# proj没什么用
        for i in range(1, self.m):
            lows[i] = F.interpolate(lows[i], size=lows[i-1].shape[2:], mode='bilinear', align_corners=True)
        for i in range(self.m):
            highs[i] = F.interpolate(highs[i], size=lows[0].shape[2:], mode='bilinear', align_corners=True)

        x_low = self.lowconv(sum(lows))
        x_high = self.highconv(sum(highs))
        x_sf = self.shuffleconv(x_low + x_high)
        return F.interpolate(x_sf, scale_factor=2, mode='bilinear', align_corners=True) # 没有有效的上采样方式


################################# PreModule ########################################
class PreModule(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(PreModule, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=7, padding=3, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.block = Block(out_ch, out_ch, reduction=8, bottle=True) # 若下采样,val有极限0.53

    def forward(self, x):
        out = self.conv(x)
        return self.block(out)


class DFANet(nn.Module):
    def __init__(self, chs, in_ch, out_ch, m): # 改成chs=[32, 64, 128]
        super(DFANet, self).__init__()
        self.premodule = PreModule(in_ch, chs[0]) 
        self.encoder = DFA_Encoder(chs, m)
        self.decoder = DFA_Decoder(chs, out_ch, m)

    def forward(self, x):
        x = self.premodule(x)
        lows, highs = self.encoder(x)
        y = self.decoder(lows, highs, x)
        return torch.softmax(y, dim=1)
      
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 22
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报