1. 网络模型定义与模型参数保存
定义网络模型与基本参数,以及模型训练和模型保存。
使用 torch.save() 方法保存模型。
在 save_dict = {} 中可以保存 epoch、model、optimizer、scheduler、loss 等参数。
# --- Define the model, optimizer, LR scheduler, and loss ---
my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
# Cosine annealing from `lr` down to lr/100 over the whole run (T_max = n_epoch).
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()
if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()
for p in my_net.parameters():
    p.requires_grad = True

bestacc = 0.0
savepth = 'mySavepthPath'
for epoch in range(n_epoch):
    my_net.train()
    # ... training / validation loop body that produces `acc` ...
    if acc > bestacc:
        # Fix: track the best accuracy; otherwise the condition is always
        # `acc > 0.0` and the checkpoint is overwritten every epoch.
        bestacc = acc
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict(),
            # Persist the scheduler too, so a resumed run keeps the cosine
            # schedule phase instead of restarting the annealing curve.
            'scheduler': scheduler.state_dict(),
        }
        torch.save(save_dict, savepth + '.pth')
2. 加载模型继续训练
使用 torch.load() 加载模型,完整代码如下。
要注意的是,必须先定义模型和优化器 optimizer,并把模型放到 GPU 上,然后再加载 checkpoint。
否则执行 optimizer.step() 时会出现下面这个错误:
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
# --- Re-create model/optimizer/scheduler, move to GPU, THEN load the checkpoint ---
# (loading before .cuda() causes "Expected all tensors to be on the same device"
# when optimizer.step() runs)
my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()
if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

Resume = True
start_epoch = -1
if Resume:
    path_checkpoint = 'mySavepthPath.pth'
    checkpoint = torch.load(path_checkpoint, map_location=torch.device('cuda'))
    my_net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Restore the scheduler phase when available (older checkpoints may lack it),
    # so the cosine annealing curve continues instead of restarting.
    if 'scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['epoch']
    print("start_epoch:", start_epoch)
    print('-----------------------------')

for p in my_net.parameters():
    p.requires_grad = True
# NOTE(review): the best accuracy resets to 0.0 on resume, so the first epoch
# after resuming always overwrites the checkpoint — consider persisting it too.
bestacc = 0.0
savepth = 'mySavepthPath'
# Run exactly n_epoch further epochs after the checkpoint. Fixes an off-by-one:
# the previous bound `new_start + n_epoch` ran only n_epoch - 1 epochs on resume.
# Fresh start (start_epoch == -1) still yields range(0, n_epoch).
for epoch in range(start_epoch + 1, start_epoch + 1 + n_epoch):
    my_net.train()
    # ... training / validation loop body that produces `acc` ...
    if acc > bestacc:
        bestacc = acc  # only overwrite the checkpoint on a genuine improvement
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(save_dict, savepth + '.pth')