Pytorch 数据流中常见Trick总结

共 3744字,需浏览 8分钟

 ·

2021-12-25 00:56

点击下方AI算法与图像处理”,一起进步!

重磅干货,第一时间送达

仅作学术分享,不代表本公众号立场,侵权联系删除
转载于:作者丨zlhroughlove@知乎
来源丨https://zhuanlan.zhihu.com/p/441317369
编辑丨极市平台

前言

在使用Pytorch建模时,常见的流程为先写Model,再写Dataset,最后写Trainer。Dataset 是整个项目开发中投入时间第二多,也是中间关键的步骤。往往需要事先对于其设计有明确的思考,不然可能会因为Dataset的一些问题又要去调整Model,Trainer。本文将目前开发中的一些思考以及遇到的问题做一个总结,提供给各位读者一个比较通用的模版,抛砖引玉~

一、Dataset的定义

from torch.utils.data import Dataset, DataLoader, RandomSampler

对于不同类型的建模任务,模型的输入各不相同。自然语言,多模态,点击率预估,往往这些场景输入模型的数据并不是来自于单一文件,而且可能无法全部存入内存。Dataset需要整合项目的数据,对于单条样本涉及到的数据做一个提取与归纳。不但如此,项目可能还涉及到多种模型,任务的训练。Dataset需要为不同的模型以及训练任务提供不同的单条样本输入,作为一个数据生成器,把后续模型训练任务需要的所有基础数据,标签全返回了。所以往往我们可以定义一个BaseDataset类,继承torch.utils.data.Dataset,这个类可以初始化一些文件路径,配置等。后面不同的模型训练任务定义相应的Dataset类继承BaseDataset。

Dataset通用的结构为:

class BaseDataset(Dataset):

    def __init__(self, config):
        self.config = config
        if os.path.isfile(config.file_path) is False:
            raise ValueError(f"Input file path {config.file_path} not found")
        logger.info(f"Creating features from dataset file at {config.file_path}")
        # 一次性全读进内存
        self.data = joblib.load(config.file_path)
        self.nums = len(self.data)

    def __len__(self):
        return self.nums

    def __getitem__(self, i) -> Dict[str, tensor]:
        sample_i = self.data[i]
        return {"f1":torch.tensor(sample_i["f1"]).long(),"f2":torch.tensor(sample_i["f2"]).long(),torch.LongTensor([sample_i["label"]])}

如果无法全部读取进内存需要再__getitem__方法内构建数据,做自然语言则可以吧tokenizer初始化到该类中,在__getitem__方法内完成tokenizer。改方法的输出推荐做成字典形式。

对于不同的训练任务可以通过以下方法返回响应的数据生成器

def build_dataset(task_type, features, **kwargs):
    assert task_type in ['task1''task2'], 'task mismatch'

    if task_type == 'task1':
        dataset = task1Dataset(features))
    else:
        dataset = task2Dataset(features)

    return dataset

有时模型的训练任务需要做数据增强,对比学习,构造多种的预训练任务输入。Dataset的职能边界是提供一套基础的单样本数据输入生成器。如果是MLM任务,可以在Dataset内生成maskposition以及label。如果是在batch内的对比学习则应该在DataLoader生产batch数据后再进行。

二、DataLoader的定义

DataLoader的作用是对Dataset进行多进程高效地构建每个训练批次的数据。传入的数据可以认为是长度为batch大小的多个__getitem__ 方法返回的字典list。DataLoader的职能边界是根据Dataset提供的单条样本数据有选择的构建一个batch的模型输入数据。

其通常的结构为对Train,Valid,Test分别建立:

train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=None, # 一般不用设置
                              num_workers=4)

首先对于sampler 还有一种定义方式:

sampler = torch.utils.data.distributed.DistributedSampler(dataset)

至于batch内数据是否需要做shuffle也需要根据损失函数确定(对比学习慎用)

DataLoader会自动合并__getitem__ 方法返回的字典内每个key内每个tensor,在tensor的第0维度新增一个batch大小的维度。如果该方法返回的每条样本长度不同无法拼接,batchsize>1就会报错。但是又一些任务在还没有确定后续的批样本对应的任务时,Dataset可能返回的字典里每个key可能就是长度不同的tensor,甚至是list,这时候需要使用collate_fn参数告诉DataLoader如何取样。我们可以定义自己的函数来准确地实现想要的功能。

如果__getitem__方法返回的是tuple((list, list)) 可以使用:

def merge_sample(x):
    return zip(*x)

train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=merge_sample,
                              num_workers=4)

拼接数据,后续再做进一步处理。(此时list内数据还是不等长,无法转为tensor)

如果__getitem_方法返回的是Dict[str,tensor],自定义的collate_fn方法内需要实现:List[Dict[str,tensor(xx)]]->Dict[str,tensor(bs,xx)]的操作,pad_sequence过程也可以在自定义方法内实现。(总之collate_fn中不但可以处理不等长数据,还可以对一个batch的数据做精修。当然也可以在DataLoader之后再做修改batch内的数据。)

值得注意的是在cpu环境下,如果要自定义collate_fn,num_workers必须设置为0,不然就会有问题..

通过以下方式可以检查一下输入后续模型的数据是否已经是想要的格式

for step, batch_data in enumerate(train_loader):
    if step < 1:
        print(batch_data)
    else:
        break

之后数据将数据放入gpu device, 一个batch的数据进入device端后就与内存上的数据不再互相干扰。之后数据就可以喂给模型了:

for key in batch_data.keys():
    batch_data[key] = batch_data[key].to(device)
loss = model(**batch_data)



努力分享优质的计算机视觉相关内容,欢迎关注:

交流群


欢迎加入公众号读者群一起和同行交流,目前有美颜、三维视觉计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群


个人微信(如果没有备注不拉群!
请注明:地区+学校/企业+研究方向+昵称



下载1:何恺明顶会分享


AI算法与图像处理」公众号后台回复:何恺明,即可下载。总共有6份PDF,涉及 ResNet、Mask RCNN等经典工作的总结分析


下载2:终身受益的编程指南:Google编程风格指南


AI算法与图像处理」公众号后台回复:c++,即可下载。历经十年考验,最权威的编程规范!



下载3 CVPR2021

AI算法与图像处公众号后台回复:CVPR即可下载1467篇CVPR 2020论文 和 CVPR 2021 最新论文


浏览 37
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报