视觉底层任务优秀开源工作:BasicSR 库使用方法
极市导读
本文整理自 BasicSR 原作者的关于 BasicSR 库的文档介绍,github 介绍以及相关知乎讲解,旨在对 BasicSR 库的特性和使用方法做一次汇总和梳理。>>加入极市CV技术交流群,走在计算机视觉的最前沿
目录
1 什么是 BasicSR 库
2 安装 BasicSR
3 如何使用 BasicSR 开发自己的项目
1 什么是 BasicSR 库
BasicSR 是全称 Basic Super-Resolution 的缩写,它是一个基于 PyTorch 的开源图像视频复原工具箱 (Open-Source Image and Video Restoration Toolbox)。它适配多种视觉底层任务,比如超分辨率,去噪,去模糊,去 JPEG 压缩噪声等。旨在将各种 Super Restoration 模型整合在一起,形成一个复现视觉底层任务模型结果的统一框架。
作者:Xintao Wang,博士毕业于香港中文大学信息工程专业。主要研究计算机底层视觉,特别是图像和视频的超分辨率。曾在 NTIRE, PIRM 等国际超分辨率比赛中多次获得冠军,提出了具有影响力的 ESRGAN,EDVR 等工作。Google Scholar 引用2900余次。目前专注于研究迈向实际应用的图像和视频的复原与增强。首先致敬大佬!
本文整理自 BasicSR 原作者的关于 BasicSR 库的文档介绍,github 介绍以及相关知乎讲解,旨在对 BasicSR 库的特性和使用方法做一次汇总和梳理,算是个引子。更多关于 1. BasicSR 代码解读,2. 如何使用,3. 作者的个人经验等等的更多内容也欢迎大家参考作者的官方指南~ (持续更新)。
作者个人主页:https://xinntao.github.io/
作者 github 链接:https://github.com/xinntao
BasicSR 库链接:https://github.com/xinntao/BasicSR
Gitee码云地址:https://gitee.com/xinntao/BasicSR
作者官方指南:https://www.zhihu.com/column/c_1295528110138163200
BasicSR 库提供了一套非常全面的图像/视频复原的代码框架。因为在实际问题中,超分往往是不会单独出现的,它往往是和去噪,去模糊,去压缩等复杂的降质 (降低质量) 的过程一起存在的。从这一点来看,BasicSR 提供了一个很棒的分布式训练和评估的代码框架,方便后人开发。更难能可贵的是它还在不断地更新迭代新的训练方法,新的复原模型和优化代码。
并且,除了上面的 BasicSR 库的链接,作者还专门撰写了 BasicSR 库的训练示例的 github,帮助新手更方便地了解和使用 BasicSR 库。
https://github.com/xinntao/BasicSR-examples
BasicSR 库的目标:
按照作者的讲述,BasicSR 是为了方便研究者的。因此之前很多写法,设计都是以方便做实验来做的。作者平时是使用 BasicSR 来开发新算法,做实验的,因此会把自己觉得好用的,便捷的设计放进去。这些改动也都是经过自己试验的,每个改动都会有个来由,为解决某个小问题而改动的。
作者现在依旧做着图像和视频的复原工作,而且会朝着实际应用的方向走。加上他平时也用 BasicSR 来开发,因此这个代码库的发展路径会不可避免地受到作者研究的关注点的影响。作者会将自己正在使用的,复现的,尝试的一些代码逐渐 merge到 BasicSR 里面。
总之,BasicSR 一方面以研究者为中心,另一方面会提供方便用户使用的脚本和说明文档,朝着 AdvancedSR 的目标前进。
2 安装 BasicSR
有两种方式安装 BasicSR。
如果你想研究 BasicSR 的细节或者开发它以满足你的需求,建议选择方案一。 如果你只是想把 BasicSR 作为一个包使用,建议选择方案二。
第一种安装方式是先把 BasicSR 代码 local clone,就是下载到本地的方式安装:
克隆代码:
git clone https://github.com/xinntao/BasicSR.git
安装依赖库:
cd BasicSR
pip install -r requirements.txt
安装 BasicSR:在 BasicSR 根目录下执行以下命令以安装 BasicSR:
情况1:不需要 C++ extensions:
python setup.py develop
情况2:需要 JIT mode 的 C++ extensions,且在安装过程中不需要编译它们:
python setup.py develop
情况3:在安装过程中需要编译 C++ extensions:
BASICSR_EXT=True python setup.py develop
情况3.5:如果还需要指定 CUDA 路径:
CUDA_HOME=/usr/local/cuda \
CUDNN_INCLUDE_DIR=/usr/local/cuda \
CUDNN_LIB_DIR=/usr/local/cuda \
BASICSR_EXT=True python setup.py develop
第二种安装方式是把 BasicSR 当作一个额外的 python package - basicsr,即可以通过 pip 安装:
情况1:不需要 C++ extensions:
pip install basicsr
情况2:需要 JIT mode 的 C++ extensions,且在安装过程中不需要编译它们:
pip install basicsr
情况3:在安装过程中需要编译 C++ extensions:
BASICSR_EXT=True pip install basicsr
如果遇到运行错误,如 ImportError: cannot import name 'deform_conv_ext' | 'fused_act_ext' | 'upfirdn2d_ext',你可以通过重新安装检查编译过程。下面的命令将打印详细的日志。
BASICSR_EXT=True pip install basicsr -vvv
情况3.5:如果还需要指定 CUDA 路径:
CUDA_HOME=/usr/local/cuda \
CUDNN_INCLUDE_DIR=/usr/local/cuda \
CUDNN_LIB_DIR=/usr/local/cuda \
BASICSR_EXT=True pip install basicsr
一些 PyTorch C++ extensions 的使用:
deformable convolution:href="https://github.com/xinntao/BasicSR/blob/master/basicsr/ops">dcnfor EDVR (officialtorchvision.ops.deform_conv2d
instead)
StyleGAN customized operators:upfirdn2dandfused_actfor StyleGAN2
有2种方案:
compile the PyTorch C++ extensions during installation OR load the PyTorch C++ extensions just-in-time (JIT)
可以根据需要选择合适的方案 (下面是这两种方案的优劣对比):
如果你不需要使用 PyTorch C++ extensions,就跳过这一步,也不需要设置BASICSR_EXT
或BASICSR_JIT
环境变量。
3 如何使用 BasicSR 开发自己的项目
这里原作者给大家提供了一个使用 BasicSR 的教程,这一小节译自作者的 example 教程,包含的所有代码都在下面的链接里面:
当我们开发一个新的方法时,我们往往在改进: data, arch, model;而很多流程、基础的功能其实是共用的。那么,我们希望可以专注于主要功能的开发,而不要重复造轮子。BasicSR 能够把很多相似的功能都独立出来,我们只要关心 data, arch, model 的开发即可。
首先通过第二种安装方式安装 BasicSR:
pip install basicsr
git clone https://github.com/xinntao/BasicSR-examples.git
cd BasicSR-examples
接下来原作者通过一个简单的例子展示 BasicSR 如何使用:
训练集:BSDS100(https%3A//github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) 验证集:Set5(https%3A//github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip)
下载这些数据集:
python scripts/prepare_example_data.py
示例数据下载完以后在 datasets/example
文件夹中。
原作者的这个示例使用的是一个超分任务,它以低分辨率图像为输入,输出高分辨率图像。低分辨率图像包含:1) CV2 X4 降采样。2) JPEG压缩 (quality=70)。
原作者这里的网络架构使用一个类似 SRCNN 的网络结构,并且在训练中同时使用 L1 和 L2 (MSE) 损失。这个示例分以下几步来完成:
① data
首先原作者的例子里面写了一个新的data/example_dataset.py。实现一个 dataset 一般需要以下几步。
读取 Ground-Truth (GT) 图像。BasicSR 提供了 FileClient 这个功能来从文件夹中方便地读取 LMDB 文件和 meta_info txt。在这个例子中,作者使用了 folder mode。 合成低分辨率的图像。通过 __getitem__(self, index)
这个函数我们能够轻易完成数据读取的流程,比如说 downsampling 或者加 JPEG compression。更多的基本操作在 [basicsr/data/degradations], [basicsr/data/tranforms] ,和 [basicsr/data/data_util] 里面。把图片转化成 PyTorch 的 tensor 类型并返回。
有两点需要注意:
在 ExampleDataset
之前加上@DATASET_REGISTRY.register()
。新的 dataset 在命名时要以 _dataset.py
为结尾,比如example_dataset.py
,这样程序能够自动 import 这个类。
在 option 里面加上:
datasets:
train: # training dataset
name: ExampleBSDS100
type: ExampleDataset # the class name
# ----- the followings are the arguments of ExampleDataset ----- #
dataroot_gt: datasets/example/BSDS100
io_backend:
type: disk
gt_size: 128
use_flip: true
use_rot: true
# ----- arguments of data loader ----- #
use_shuffle: true
num_worker_per_gpu: 3
batch_size_per_gpu: 16
dataset_enlarge_ratio: 10
prefetch_mode: ~
val: # validation dataset
name: ExampleSet5
type: ExampleDataset
dataroot_gt: datasets/example/Set5
io_backend:
type: disk
② arch
原作者的例子里面写了一个新的archs/example_arch.py。
有两点需要注意:
在 ExampleArch
之前加上@ARCH_REGISTRY.register()
。新的 arch 在命名时要以 _arch.py
为结尾,比如example_arch.py
,这样程序能够自动 import 这个类。
在 option 里面加上:
# network structures
network_g:
type: ExampleArch # the class name
# ----- the followings are the arguments of ExampleArch ----- #
num_in_ch: 3
num_out_ch: 3
num_feat: 64
upscale: 4
③ model
原作者的例子里面写了一个新的models/example_model.py,包含模型的训练过程。
在这个文件中,通过 SRModel
import 模型。许多模型都有类似的操作,所以你可以从 basicsr/models 继承和修改。损失函数使用 L1 and L2 (MSE) loss。 其他内容,如 setup_optimizers
,validation
,save
是从SRModel
里面 import 的。
有两点需要注意:
在 ExampleModel
之前加上@MODEL_REGISTRY.register()
。新的 model 在命名时要以 _model.py
为结尾,比如example_model.py
,这样程序能够自动 import 这个类。
在 option 里面加上:
# training settings
train:
optim_g:
type: Adam
lr: !!float 2e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [50000]
gamma: 0.5
total_iter: 100000
warmup_iter: -1 # no warm up
# ----- the followings are the configurations for two losses ----- #
# losses
l1_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
l2_opt:
type: MSELoss
loss_weight: 1.0
reduction: mean
④ training pipeline
整体的 training pipeline 可以复用 BasicSR库的 basicsr/train.py。基于它,train.py 文件可以非常简洁:
import os.path as osp
import archs # noqa: F401
import data # noqa: F401
import models # noqa: F401
from basicsr.train import train_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir))
train_pipeline(root_path)
以上四步就写完了 example_option.yml,接下来开始运行:
⑤ debug 模式
运行命令:
python train.py -opt options/example_option.yml --debug
⑥ normal training
运行命令:
python train.py -opt options/example_option.yml
如果训练过程被意外中断,需要恢复。请在命令中使用 --auto_resume。
python train.py -opt options/example_option.yml --auto_resume
到目前为止,你已经完成了使用 BasicSR 开发自己的项目。
最后要强调的是,本文只是对 BasicSR 库的一个简单介绍,算是个引子。更多关于 1. BasicSR 代码解读,2. 如何使用,3. 作者的个人经验等等的更多内容也欢迎大家参考作者的官方指南~ (持续更新)。
参考:
https://zhuanlan.zhihu.com/p/261223409
https://github.com/xinntao/BasicSR-examples
https://github.com/xinntao/BasicSR
公众号后台回复“数据集”获取30+深度学习数据集下载~
# 极市平台签约作者#
科技猛兽
知乎:科技猛兽
清华大学自动化系19级硕士
研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。
作品精选