使用 PyTorch 进行分布式训练
AI算法与图像处理
共 6607字,需浏览 14分钟
·
2021-07-10 05:48
点击下方“AI算法与图像处理”,一起进步!
重磅干货,第一时间送达
size:进行训练的 GPU 设备的数量
rank:对GPU设备有一个序列的id号
# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
download=True,
train=True)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True)
# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
download=True,
train=True,
transform=transform)
# Create distributed sampler pinned to rank
sampler = DistributedSampler(train_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True, # May be True
seed=42)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=False, # Must be False!
num_workers=4,
sampler=sampler,
pin_memory=True)
def create_model():
model = nn.Sequential(
nn.Linear(28*28, 128), # MNIST images are 28x28 pixels
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 10, bias=False) # 10 classes to predict
)
return model
# Initialize the model
model = create_model()
# Initialize the model
model = create_model()
# Create CUDA device
device = torch.device(f'cuda:{rank}')
# Send model parameters to the device
model = model.to(device)
# Wrap the model in DDP wrapper
model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
for i in range(epochs):
for x, y in train_loader:
# do the training
...
for i in range(epochs):
train_loader.sampler.set_epoch(i)
for x, y in train_loader:
# do the training
...
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()
rank = args.local_rank
if rank == 0:
torch.save(model.module.state_dict(), 'model.pt')
python -m torch.distributed.launch --nproc_per_node=4
ddp_tutorial_multi_gpu.py
努力分享优质的计算机视觉相关内容,欢迎关注:
个人微信(如果没有备注不拉群!) 请注明:地区+学校/企业+研究方向+昵称
下载1:何恺明顶会分享
在「AI算法与图像处理」公众号后台回复:何恺明,即可下载。总共有6份PDF,涉及 ResNet、Mask RCNN等经典工作的总结分析
下载2:终身受益的编程指南:Google编程风格指南
在「AI算法与图像处理」公众号后台回复:c++,即可下载。历经十年考验,最权威的编程规范!
下载3 CVPR2021
在「AI算法与图像处理」公众号后台回复:CVPR,即可下载1467篇CVPR 2020论文 和 CVPR 2021 最新论文
点亮 ,告诉大家你也在看
评论