PyTorch: Pretrained Weights, Frozen Training, and Resuming from Checkpoints
Machine Learning and Generative Adversarial Networks (机器学习与生成对抗网络)
2022-01-05 04:05
Source: Zhihu, 吵鸡凶鸭OvO (will be removed on request in case of infringement)
01
If I have seen further, it is by standing on the shoulders of giants.
02 Loading pretrained weights
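import numpy as np
import torch

# model, model_path and device are assumed to be defined elsewhere
# Step 1: read the current model's parameters
model_dict = model.state_dict()
# Step 2: load the pretrained weights, keeping only entries whose name and shape match
pretrained_dict = torch.load(model_path, map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
# Step 3: update the current parameters with the pretrained ones
model_dict.update(pretrained_dict)
# Step 4: load the merged parameters into the model
model.load_state_dict(model_dict)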
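# The same filtering as an explicit loop: keys missing from the current model are skipped
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
temp = {}
for k, v in pretrained_dict.items():
    try:
        if np.shape(model_dict[k]) == np.shape(v):
            temp[k] = v
    except KeyError:
        pass
model_dict.update(temp)
# Load the merged parameters into the model
model.load_state_dict(model_dict)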
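The filtering above can be tried end to end on toy models. The following is a minimal, self-contained sketch rather than the author's code: `make_model` and `load_matching_weights` are hypothetical helper names, the networks are toy `nn.Sequential` stacks, and the "pretrained" weights come from an in-memory `state_dict` instead of a file read with `torch.load`.

```python
import torch
import torch.nn as nn


def make_model(num_classes):
    # Toy network (an assumption for illustration): shared body plus a classification head.
    return nn.Sequential(
        nn.Linear(8, 16),
        nn.ReLU(),
        nn.Linear(16, num_classes),
    )


def load_matching_weights(model, pretrained_dict):
    """Copy only the entries whose name and shape match the current model.

    Returns (loaded_keys, skipped_keys) so mismatches can be inspected.
    """
    model_dict = model.state_dict()
    loaded, skipped = [], []
    for k, v in pretrained_dict.items():
        if k in model_dict and model_dict[k].shape == v.shape:
            model_dict[k] = v
            loaded.append(k)
        else:
            skipped.append(k)
    model.load_state_dict(model_dict)
    return loaded, skipped


torch.manual_seed(0)
pretrained = make_model(num_classes=10)  # stands in for the pretrained network
model = make_model(num_classes=3)        # same body, different head size

loaded, skipped = load_matching_weights(model, pretrained.state_dict())
print("loaded:", loaded)
print("skipped:", skipped)
```

Printing the skipped keys is a cheap sanity check. If nearly every key is skipped, the key names probably differ between the checkpoint and the model, for example because of a `module.` prefix left by `nn.DataParallel`.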
03 Frozen training
# Frozen-stage settings: learning_rate and batch_size can be set somewhat larger
Init_Epoch = 0
Freeze_Epoch = 50
Freeze_batch_size = 8
Freeze_lr = 1e-3
# Unfrozen-stage settings: learning_rate and batch_size are set somewhat smaller
UnFreeze_Epoch = 100
Unfreeze_batch_size = 4
Unfreeze_lr = 1e-4
# A flag that controls whether frozen training is used at all
Freeze_Train = True
# Freeze part of the network (the backbone) and train
batch_size = Freeze_batch_size
lr = Freeze_lr
start_epoch = Init_Epoch
end_epoch = Freeze_Epoch
if Freeze_Train:
    for param in model.backbone.parameters():
        param.requires_grad = False
# Unfreeze the backbone and continue training
batch_size = Unfreeze_batch_size
lr = Unfreeze_lr
start_epoch = Freeze_Epoch
end_epoch = UnFreeze_Epoch
if Freeze_Train:
    for param in model.backbone.parameters():
        param.requires_grad = True
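The two-stage schedule can be sketched on a toy model to confirm that a frozen backbone really stays unchanged. This is a minimal sketch under stated assumptions: `TinyNet`, `set_backbone_trainable`, and `run_stage` are hypothetical names, the data is random, and the optimizer is rebuilt at the start of each stage so that it only holds parameters that currently require gradients. The snippets above do not show how the optimizer is built, so that part is an assumption.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Toy model (assumption): exposes a `backbone` attribute, plus a small head.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


def set_backbone_trainable(model, trainable):
    for param in model.backbone.parameters():
        param.requires_grad = trainable


def run_stage(model, lr, batch_size, steps):
    # Rebuild the optimizer so it only tracks parameters that require gradients.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr)
    for _ in range(steps):
        x = torch.randn(batch_size, 8)
        y = torch.randint(0, 2, (batch_size,))
        loss = nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


torch.manual_seed(0)
model = TinyNet()
backbone_before = model.backbone.weight.detach().clone()
head_before = model.head.weight.detach().clone()

# Stage 1: backbone frozen, only the head is updated.
set_backbone_trainable(model, False)
run_stage(model, lr=1e-3, batch_size=8, steps=5)
backbone_after_freeze = model.backbone.weight.detach().clone()

# Stage 2: backbone unfrozen, the whole network is updated with a smaller lr.
set_backbone_trainable(model, True)
run_stage(model, lr=1e-4, batch_size=4, steps=5)
backbone_after_unfreeze = model.backbone.weight.detach().clone()
```

Rebuilding the optimizer per stage is one simple design choice. Keeping a single optimizer over all parameters also works with the plain SGD used here, since parameters without gradients are skipped during `step()`.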
04 Saving weights and resuming from a checkpoint
torch.save(model.state_dict(), "<path to save to>")
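Saving only `model.state_dict()` keeps the weights but not the training progress. To resume an interrupted run, a common pattern is to store the optimizer state and the epoch counter alongside the weights. The following is a minimal sketch, not the original author's code: `save_checkpoint`, `load_checkpoint`, and the dictionary keys are hypothetical names, and the model is a toy `nn.Linear` written to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    # Bundle everything needed to continue training into one file.
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, device="cpu"):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # Training resumes from the epoch after the one that was saved.
    return checkpoint["epoch"] + 1


torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# One dummy training step so the optimizer has momentum buffers to save.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
save_checkpoint(path, model, optimizer, epoch=49)

# Simulate a restart: fresh model and optimizer, then restore from disk.
resumed_model = nn.Linear(4, 2)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=1e-3, momentum=0.9)
start_epoch = load_checkpoint(path, resumed_model, resumed_optimizer)
print("resume from epoch", start_epoch)
```

The returned `start_epoch` can then take the place of `Init_Epoch` (or `Freeze_Epoch`) as the starting point of the training loop, depending on which stage was interrupted.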