SRGAN-超分辨率图像复原

共 4630字,需浏览 10分钟

 ·

2021-09-23 16:48

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

本文转自|机器学习算法工程师


github:

https://github.com/OUCMachineLearning/OUCML/blob/master/GAN/srgan_celebA/srgan.py

arxiv:

https://arxiv.org/abs/1609.04802


我的研究方向:GAN

大家好,我是中国海洋大学的陈扬。在遥远的九月份,我开始做了keras的系列教程,现在我主要的研究方向转到了生成对抗网络,生成对抗网络的代码实现和训练机制比分类模型都要复杂和难入门.之前一段时间时间一直在帮璇姐跑cvpr的实验代码,做了蛮多的对比实验,其中我就发现了,keras的代码实现和可阅读性很好,搭生成对抗网络网络GAN就好像搭乐高积木一样有趣哦。不只是demo哦,我还会在接下来的一系列 keras教程中教你搭建Alexnet,Vggnet,Resnet,DCGAN,ACGAN,CGAN,SRGAN,等等实际的模型并且教你如何在GPU服务器上运行。


前言


上个星期发了一篇有关GAN入门的文章,同学们都觉得挺有趣的,上一次我写了如何理解最基础的GAN的原理,今天我给大家带来的是如何运用强大的GAN做一些好玩的应用.

超分辨率复原一直是计算机视觉领域一个十分热门的研究方向,在商业上也有着很大的用武之地,随着2014年goodflew那篇惊世骇俗的GAN发表出来,GAN伴随着CNN一起,可谓是乘风破浪,衍生出来琳琅满目的各种应用.

简单的来说,就给定一个低分辨率图片作为噪声z的输入,通过生成器的变换把噪声的概率分布空间尽可能的去拟合真实数据的分布空间.


01 基本框架

在这里,我们把生成器看的目标看成是要以次充好,判别器的目标是要明辨真假.

我们可以的看到,在生成器的前6层网络中,我们运用了残差块,为什么要用残差块呢?

因为我们可以从上图看出来,当损失函数从判别器开始反向传播会生成器的时候,实际上进过来很多层,我们知道越深的网络隐藏参数越多,在反向传播的过程中也越容易梯度弥散.而且残差连接的方法,就有效的保证了我们梯度信息能够有效的传递而增强生成对抗网络的鲁棒性.(事实上沃瑟斯坦loss也可以增强GAN训练的鲁棒性,以后会写)


02 celebA

再来聊聊今天用的数据集,这是Celeb-A,里面有大量的带标注信息的明星人脸.在目前很多的GAN的应用中,都是用CelebA作为基础的数据集,这个数据集大概在1.2G左右,可以在kaggle上下载.

  • 浏览数据集

[https://www.kaggle.com/jessicali9530/celeba-dataset]

A popular component of computer vision and deep learning revolves around identifying faces for various applications from logging into your phone with your face or searching through surveillance images for a particular suspect. This dataset is great for training and testing models for face detection, particularly for recognising facial attributes such as finding people with brown hair, are smiling, or wearing glasses. Images cover large pose variations, background clutter, diverse people, supported by a large quantity of images and rich annotations. This data was originally collected by researchers at MMLAB, The Chinese University of Hong Kong (specific reference in Acknowledgment section).


03 Overall


202,599 number of face images of various celebrities

10,177 unique identities, but names of identities are not given

40 binary attribute annotations per image

5 landmark locations


04 Super-Resolution IMAGE


简单点说,就是给你一张模糊的图片,让你复原一张高清的图片.

05 我们如何用生成对抗网络来做呢?


这个时候,我们可以把LRimg看成是一个噪声z的输入,G生成的是一个FAKE-HRimg,我们让D分辨fake-HRimg and original HRimg.


06 定义一个目标函数


Our ultimate goal is to train a generating function G that estimates for a given LR input image its corresponding HR counterpart. To achieve this, we train a generator network as a feed-forward CNN GθG parametrized by θG. Here θG = {W1:L ; b1:L } denotes the weights and biases of a L-layer deep network and is obtained by optimizing a SR-specific

loss function lSR. For training images IHR , n = 1, . . . , N n

withcorrespondingILR,n=1,...,N,wesolve:


07 提出preceptual loss


作者认为这更接近人的主观感受,因为使用pixel-wise的MSE使得图像变得平滑,而如果先用VGG来抓取到高级特征(feature)表示,再对feature使用MSE,可以更好的抓取不变特征。


  • 核心公式

这个公式我们要分成两个部分来看,先看前半部分:

这个公式的意思是,先看加号前面,我们希望D最大,所以应该最大,意味着我的判别器可以很好的识别出,真实的高分辨率图像是"true",在看加号后面的,要让log尽可能的大,需要的是ΘD(ΘG(z))尽可能的小,意味着我们生成模型复原的图片应该尽可能的被判别模型视为"FALSE".

再看后半部分部分

我们应该让G尽可能的小,加号前面的式子并没有G,所以无关,在看加号后面的式子

,要让ΘG尽可能地小,就要ΘD(ΘG(Z))尽可能的大,也就是说本来就一张低分辨率生成的图片,判别器却被迷惑了,以为是一张原始的高分辨率图片.这就是所谓的以次充好.


08 网络设计

09 loss函数

###vgg用于提取特征
self.vgg.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])
###生成器
self.combined.compile(loss=['binary_crossentropy', 'mse'],
loss_weights=[1e-3, 1],
optimizer=optimizer)
###判别器
self.discriminator.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])


10 train
  • 训练判别器
d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


  • 训练生成器

image_features = self.vgg.predict(imgs_hr)
# Train the generators
g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])


11 实际结果

5000 batchsize

12 对比实验结果

13 致谢


嘻嘻嘻,大家要是喜欢这个系列的话就给我个小小的赞哦,等到我期末考考完试文件录视频来讲如何从零开始搭建生成对抗网络,emmmm现在大二学业压力确实大了,不过我们的创造会一如既往的用心做下去的,感谢你们的陪伴,也是我持续创造的动力源泉.


 

END


好消息,小白学视觉团队的知识星球开通啦,为了感谢大家的支持与厚爱,团队决定将价值149元的知识星球现时免费加入。各位小伙伴们要抓住机会哦!


下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 17
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报