Pytorch常见的坑汇总

小白学视觉

共 4507字,需浏览 10分钟

 ·

2021-09-23 16:50

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

本文转自|深度学习这件小事
最近刚开始用pytorch不久,陆陆续续踩了不少坑,记录一下,个人感觉应该都是一些很容易遇到的一些坑,也在此比较感谢帮我排坑的小伙伴,持续更新,也祝愿自己遇到的坑越来越少。
首先作为tensorflow的骨灰级玩家+轻微强迫症患者,一路打怪升级,从0.6版本用到1.2,再用到1.10,经历了tensorfow数个版本更迭,这里不得不说一下tf.data.dataset+tfrecord使用起来效率远比dataloader高的多。
tensorflow有一个比较好用的队列机制,tf.inputproducer + tfrecord, 但是inputproducer有一个bug,就是无法对每个epoch单独shuffle,它只能整体shuffle,也就意味着我们无法进行正常的训练流程(train几个epoch,在validation上测一个epoch,最终选一个validation上的最好的结果,进行test)。后来我当时给官方提了一个issue,官方当时的回答是,这个bug目前无法解决,但是他们在即将到来的tf1.2版本中, 推出的新型数据处理API tf.contrib.data.dataset(tf1.3版本将其合并到了tf.data.dataset)可以完美解决这个bug,并且将于tf2.0摒弃tf.input_producer。然后tf1.2版本刚出来以后,我就立马升级并且开始tf.data.dataset踩坑,踩了大概2周多的坑,(这个新版的API其实功能并不是非常强大,有不少局限性,在此就不展开)。
——————————————————————————
好像扯远了,回归pytorch,首先让我比较尴尬的是pytorch并没有一套属于自己的数据结构以及数据读取算法,dataloader个人感觉其实就是类似于tf中的feed,并没有任何速度以及性能上的提升。
先总结一下遇到的坑:
1.没有高效的数据存储,cv.imread在网络训练过程中效率低

解决方案:
当时看到了一个还不错的github链接,
https://github.com/Lyken17/Efficient-PyTorch
主要是讲如何使用lmdb,h5py,pth,lmdb,n5等数据存储方式皆可以。
个人的感受是,h5在数据调用上比较快,但是如果要使用多线程读写,就尽量不要使用h5,因为h5的多线程读写好像比较麻烦。
http://docs.h5py.org/en/stable/mpi.html
这里贴一下h5数据的读写代码(主要需要注意的是字符串的读写需要encode,decode,最好用create_dataset,直接写的话读的时候会报错):

imagenametotal_.append(os.path.join('images', imagenametotal).encode())
with h5py.File(outfile) as f:
f.create_dataset('imagename', data=imagenametotal_)
f['part'] = parts_
f['S'] = Ss_
f['image'] = cvimgs

with h5py.File(outfile) as f:
imagename = [x.decode() for x in f['imagename']]
kp2ds = np.array(f['part'])
kp3ds = np.array(f['S'])
cvimgs = np.array(f['image'])

2.gpu imbalance
张航学长Hang Zhang (张航)提了一个开源的gpu balance的工具--PyTorch-Encoding。
使用方法还是比较便捷的,如下所示:
from balanced_parallel import DataParallelModel, DataParallelCriterionmodel = DataParallelModel(model, device_ids=gpus).cuda()criterion = loss_fn().cuda()
这里其实有2个注意点,第一,测试的时候需要手动将gpu合并,代码如下:
from torch.nn.parallel.scatter_gather import gatherpreds = gather(preds, 0)
第二,当loss函数有多个组成的时候,比如 loss = loss1 + loss2 + loss3
那么需要把这三个loss写到一个class中,然后再forward里面将其加起来。
其次,我们还可以用另外一个函数distributedDataParallel来解决gpu imbalance的问题.
使用方法如下:(注:此方法好像无法和h5数据同时使用)
from torch.utils.data.distributed import DistributedSamplerfrom torch.nn.parallel import DistributedDataParallel
torch.distributed.init_process_group(backend="nccl")# 配置每个进程的gpulocal_rank = torch.distributed.get_rank()torch.cuda.set_device(local_rank)device = torch.device("cuda", local_rank)
#封装之前要把模型移到对应的gpumodel.to(device)model = torch.nn.parallel.DistributedDataParallel(model,device_ids=[local_rank], output_device=local_rank)
#原有的dataloader上面加一个数据sampletrain_loader = torch.utils.data.DataLoader( train_dataset, sampler=DistributedSampler(train_dataset) )
3.gpu利用率不高,+gpu现存占用浪费
常用配置:
1主函数前面加:(这个会牺牲一点点现存提高模型精度)
cudnn.benchmark = Truetorch.backends.cudnn.deterministic = Falsetorch.backends.cudnn.enabled = True
2训练时,epoch前面加:(定期清空模型,效果感觉不明显)
torch.cuda.empty_cache()
3无用变量前面加:(同上,效果某些操作上还挺明显的)
del xxx(变量名)
4dataloader的长度_len_设置:(dataloader会间歇式出现卡顿,设置成这样会避免不少)
def __len__(self): return self.images.shape[0]
5dataloader的预加载设置:(会在模型训练的时候加载数据,提高一点点gpu利用率)
train_loader = torch.utils.data.DataLoader( train_dataset, pin_memory=True, )
6网络设计很重要,外加不要初始化任何用不到的变量,因为pyroch的初始化和forward是分开的,他不会因为你不去使用,而不去初始化。
7最后放一张目前依旧困扰我的图片:
可以看到,每个epoch刚开始训练数据的时候,第一个iteration时间会占用的非常多,pytorch这里就做的很糟糕,并不是一个动态分配的过程,我也看到了一个看上去比较靠谱的解决方案,解决方案如下 @风车车

在深度学习中喂饱gpu

https://zhuanlan.zhihu.com/p/77633542

但是我看了下代码,可能需要重构dataloader,看了评论好像还有问题,有点懒,目前还没有踩坑,准备后面有时间踩一下。
暂且更新到这里,后续遇到什么坑陆续补充,也欢迎大家给我补充,pytorch初学者小白一枚。

更个新;顺便吐槽一下上面的dali,局限性很大,比较trick的数据预处理很难搞定。
8 apex混合单精度模型
事实证明,apex并没有官网说的那么玄乎,只能减低显存,并不能提速(12G显存大概可以降低到8G左右,效果还挺明显的,但是,速度降低了大概1/3,好像有点得不偿失)。
编译之后提速也很有限,再此留个坑,有小伙伴能解决的可以私信我哈,如果可以解决我会仔细罗列一遍。。



好消息,小白学视觉团队的知识星球开通啦,为了感谢大家的支持与厚爱,团队决定将价值149元的知识星球现时免费加入。各位小伙伴们要抓住机会哦!


下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 24
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报