Hugging Face发布PyTorch新库「Accelerate」:适用于多GPU、TPU、混合精度训练
极市导读
多数 PyTorch 高级库都支持分布式训练和混合精度训练,但是它们引入的抽象化往往需要用户学习新的 API 来定制训练循环。许多 PyTorch 用户希望完全控制自己的训练循环,但不想编写和维护训练所需的样板代码。Hugging Face 最近发布的新库 Accelerate 解决了这个问题。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
import torch
import torch.nn.functional as F
from datasets import load_dataset
+ from accelerate import Accelerator
+ accelerator = Accelerator()
- device = 'cpu'
+ device = accelerator.device
model = torch.nn.Transformer().to(device)
optim = torch.optim.Adam(model.parameters())
dataset = load_dataset('my_dataset')
data = torch.utils.data.DataLoader(dataset, shuffle=True)
+ model, optim, data = accelerator.prepare(model, optim, data)
model.train()
for epoch in range(10):
for source, targets in data:
source = source.to(device)
targets = targets.to(device)
optimizer.zero_grad()
output = model(source)
loss = F.cross_entropy(output, targets)
+ accelerator.backward(loss)
- loss.backward()
optimizer.step()
import torch
import torch.nn.functional as F
from datasets import load_dataset
+ from accelerate import Accelerator
+ accelerator = Accelerator()
- device = 'cpu'
+ model = torch.nn.Transformer()
- model = torch.nn.Transformer().to(device)
optim = torch.optim.Adam(model.parameters())
dataset = load_dataset('my_dataset')
data = torch.utils.data.DataLoader(dataset, shuffle=True)
+ model, optim, data = accelerator.prepare(model, optim, data)
model.train()
for epoch in range(10):
for source, targets in data:
- source = source.to(device)
- targets = targets.to(device)
optimizer.zero_grad()
output = model(source)
loss = F.cross_entropy(output, targets)
+ accelerator.backward(loss)
- loss.backward()
optimizer.step()
accelerate config
accelerate launch my_script.py --args_to_my_script
Accelerate 的运作原理
accelerator = Accelerator()
model, optim, data = accelerator.prepare(model, optim, data)
模型
优化器
数据加载器
accelerator.backward(loss)
CPU
单 GPU
单一节点多 GPU
多节点多 GPU
TPU
带有本地 AMP 的 FP16(路线图上的顶点)
推荐阅读
2021-04-21
2021-04-20
2021-04-19
# CV技术社群邀请函 #
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~
评论