PyTorch mixed precision 概述(混合精度)
点击上方“程序员大白”,选择“星标”公众号
重磅干货,第一时间送达
01
import torchvision
import torch
import torch.cuda.amp
import gc
import time
# Timing utilities
# Wall-clock timestamp written by start_timer(); stays None until the first call.
start_time = None


def start_timer():
    """Reset GPU memory statistics and mark the start of a timed region.

    Runs the Python garbage collector, empties the CUDA caching allocator,
    resets the peak-memory counters, and stores the current wall-clock time
    in the module-level ``start_time``.
    """
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated and only forwards to
    # reset_peak_memory_stats(); call the supported API directly.
    torch.cuda.reset_peak_memory_stats()
    # Synchronize first so previously queued GPU work is not counted in the
    # measured interval.
    torch.cuda.synchronize()
    start_time = time.time()
def end_timer_and_print(local_msg):
    """Print the elapsed time and peak tensor memory for the timed region.

    Args:
        local_msg: Label printed on its own line, preceded by a blank line,
            before the two statistics lines.
    """
    # Wait for outstanding GPU work so the end timestamp includes it.
    torch.cuda.synchronize()
    elapsed = time.time() - start_time
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(elapsed))
    peak_bytes = torch.cuda.max_memory_allocated()
    print("Max memory used by tensors = {} bytes".format(peak_bytes))
# Workload dimensions for the benchmark.
num_batches, batch_size, epochs = 50, 70, 3

# Randomly generated training data, created directly on the CUDA device:
# float tensors of shape (batch_size, 3, 224, 224) and integer labels drawn
# from [0, 1000).
data = [
    torch.randn(batch_size, 3, 224, 224, device="cuda")
    for _ in range(num_batches)
]
targets = [
    torch.randint(0, 1000, size=(batch_size,), device="cuda")
    for _ in range(num_batches)
]

# Model to train.
net = torchvision.models.resnext50_32x4d().cuda()

# Loss function.
loss_fn = torch.nn.CrossEntropyLoss().cuda()

# Optimizer.
opt = torch.optim.SGD(net.parameters(), lr=0.001)

# Whether to train with mixed precision.
use_amp = True

# Construct the scaler once, at the beginning of the convergence run, using
# default args. If your network fails to converge with default GradScaler
# args, please file an issue.
# The same GradScaler instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run
# should use a dedicated fresh GradScaler instance. GradScaler instances are
# lightweight.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
# Timed training run: `epochs` passes over the pre-generated batches.
start_timer()
for epoch in range(epochs):
    # The batch is bound to `inputs` (not `input`) so the builtin input()
    # is not rebound at module scope.
    for inputs, target in zip(data, targets):
        # Only the forward pass and the loss computation run under autocast;
        # backward() and the optimizer step stay outside the context.
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(inputs)
            loss = loss_fn(output, target)
        # Scale the loss, then call backward() on the scaled loss to create
        # scaled gradients.
        scaler.scale(loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's
        # assigned params. If these gradients do not contain infs or NaNs,
        # optimizer.step() is then called; otherwise optimizer.step() is
        # skipped.
        scaler.step(opt)
        # Update the scale for the next iteration.
        scaler.update()
        # set_to_none=True here can modestly improve performance.
        opt.zero_grad(set_to_none=True)
end_timer_and_print("Mixed precision:")
02
混合精度测试
推荐阅读
关于程序员大白
程序员大白是一群哈工大、东北大学、西湖大学和上海交通大学的硕士博士运营维护的号,大家乐于分享高质量文章,喜欢总结知识,欢迎关注【程序员大白】,大家一起学习进步!
评论