pytorch保存与加载模型来测试或继续训练

蔚落 2022-12-27 09:22 351阅读 0赞

目录

  • 摘要
  • state_dict
  • 恢复训练实例
  • 加载部分预训练模型
  • 保存 & 加载模型 来inference
    • 保存/加载state_dict (推荐)
      • 保存:
      • 加载:
    • 保存/加载整个模型
      • 保存:
      • 加载:
    • 保存 & 加载一个通用Checkpoint来做测试或恢复训练
      • 保存:
      • 加载:
    • 加载不同模型的参数warmstarting
      • 保存:
      • 加载:
  • 跨设备保存/加载模型(CPU与GPU)
    • 模型保存在GPU上,加载到CPU
      • 保存
      • 加载:
    • 模型保存在GPU上,加载到GPU
      • 保存:
      • 加载:

摘要

pytorch中与保存和加载模型有关函数有三个:
1.torch.save:将序列化的对象保存到磁盘。此函数使用Python的pickle实用程序进行序列化。使用此功能可以保存各种对象的模型,张量和字典。

  1. torch.load:使用pickle的unpickle工具将pickle的对象文件反序列化到内存中。即加载save保存的东西。
  2. torch.nn.Module.load_state_dict:使用反序列化的state_dict加载模型的参数字典。注意,这意味着它的传入的参数应该是一个state_dict类型,也就torch.load加载出来的。

state_dict

stat_dict是一个字典,该字典包含model每一层的tensor类型的可学习参数。只有包含可学习参数的网络层才能将其参数映射到state_dict字典中,此外,stat_dict也包含优化器的state和超参数。

官网给的一个示例:

  1. # Define model
  2. class TheModelClass(nn.Module):
  3. def __init__(self):
  4. super(TheModelClass, self).__init__()
  5. self.conv1 = nn.Conv2d(3, 6, 5)
  6. self.pool = nn.MaxPool2d(2, 2)
  7. self.conv2 = nn.Conv2d(6, 16, 5)
  8. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  9. self.fc2 = nn.Linear(120, 84)
  10. self.fc3 = nn.Linear(84, 10)
  11. def forward(self, x):
  12. x = self.pool(F.relu(self.conv1(x)))
  13. x = self.pool(F.relu(self.conv2(x)))
  14. x = x.view(-1, 16 * 5 * 5)
  15. x = F.relu(self.fc1(x))
  16. x = F.relu(self.fc2(x))
  17. x = self.fc3(x)
  18. return x
  19. # Initialize model
  20. model = TheModelClass()
  21. # Initialize optimizer
  22. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  23. # Print model's state_dict
  24. print("Model's state_dict:")
  25. for param_tensor in model.state_dict():
  26. print(param_tensor, "\t", model.state_dict()[param_tensor].size())
  27. # Print optimizer's state_dict
  28. print("Optimizer's state_dict:")
  29. for var_name in optimizer.state_dict():
  30. print(var_name, "\t", optimizer.state_dict()[var_name])

output:

  1. Model's state_dict:
  2. conv1.weight torch.Size([6, 3, 5, 5])
  3. conv1.bias torch.Size([6])
  4. conv2.weight torch.Size([16, 6, 5, 5])
  5. conv2.bias torch.Size([16])
  6. fc1.weight torch.Size([120, 400])
  7. fc1.bias torch.Size([120])
  8. fc2.weight torch.Size([84, 120])
  9. fc2.bias torch.Size([84])
  10. fc3.weight torch.Size([10, 84])
  11. fc3.bias torch.Size([10])
  12. Optimizer's state_dict:
  13. state { }
  14. param_groups [{ 'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

恢复训练实例

保存模型和加载模型的函数如下

  1. def save_checkpoint_state(dir,epoch,model,optimizer):
  2. #保存模型
  3. checkpoint = {
  4. 'epoch': epoch,
  5. 'model_state_dict': model.state_dict(),
  6. 'optimizer_state_dict': optimizer.state_dict(),
  7. }
  8. if not os.path.isdir(dir):
  9. os.mkdir(dir)
  10. torch.save(checkpoint, os.path.join(dir,'checkpoint-epoch%d.tar'%(epoch)))
  11. def get_checkpoint_state(dir,ckp_name,device,model,optimizer):
  12. # 恢复上次的训练状态
  13. print("Resume from checkpoint...")
  14. checkpoint = torch.load(os.path.join(dir,ckp_name),map_location=device)
  15. model.load_state_dict(checkpoint['model_state_dict'])
  16. epoch=checkpoint['epoch']
  17. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  18. #scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
  19. print('sucessfully recover from the last state')
  20. return model,epoch,optimizer

如果加入了lr_scheduler,那么lr_scheduler的state_dict也要加进来。

使用时:

  1. # 引用包省略
  2. #保持模型函数
  3. def save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss):
  4. checkpoint = {
  5. "epoch": epoch,
  6. "model_state_dict": model.state_dict(),
  7. "optimizer_state_dict": optimizer.state_dict(),
  8. "scheduler_state_dict": scheduler.state_dict()
  9. }
  10. torch.save(checkpoint, "checkpoint-epoch%d-loss%d.tar" % (epoch, running_loss))
  11. # 加载模型函数
  12. def load_checkpoint_state(path, device, model, optimizer, scheduler):
  13. checkpoint = torch.load(path, map_location=device)
  14. model.load_state_dict(checkpoint["model_state_dict"])
  15. epoch = checkpoint["epoch"]
  16. optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
  17. scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
  18. return model, epoch, optimizer, scheduler
  19. # 是否恢复训练
  20. resume = False # True
  21. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  22. def train():
  23. trans = transforms.Compose([
  24. transforms.ToPILImage(),
  25. transforms.RandomResizedCrop(512),
  26. transforms.RandomHorizontalFlip(),
  27. transforms.RandomVerticalFlip(),
  28. transforms.RandomRotation(90),
  29. transforms.ToTensor(),
  30. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  31. ])
  32. # get training dataset
  33. leafDiseaseCLS = CustomDataSet(images_path, is_to_ls, trans)
  34. data_loader = DataLoader(leafDiseaseCLS,
  35. batch_size=16,
  36. num_workers=0,
  37. shuffle=True,
  38. pin_memory=False)
  39. # get model
  40. model = EfficientNet.from_pretrained("efficientnet-b3")
  41. # extract the parameter of fully connected layer
  42. fc_features = model._fc.in_features
  43. # modify the number of classes
  44. model._fc = nn.Linear(fc_features, 5)
  45. model.to(device)
  46. # optimizer
  47. optimizer = optim.SGD(model.parameters(),
  48. lr=0.001,
  49. momentum=0.9,
  50. weight_decay=5e-4)
  51. scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[6, 10], gamma=1/3.)
  52. # loss
  53. #loss_func = nn.CrossEntropyLoss()
  54. loss_func = FocalCosineLoss()
  55. start_epoch = -1
  56. if resume:
  57. model, start_epoch, optimizer,scheduler = load_checkpoint_state("../path/to/checkpoint.tar",
  58. device,
  59. model,
  60. optimizer,
  61. scheduler)
  62. model.train()
  63. epochs = 3
  64. for epoch in range(start_epoch + 1, epochs):
  65. running_loss = 0.0
  66. print("Epoch {}/{}".format(epoch, epochs))
  67. for step, train_data in tqdm(enumerate(data_loader)):
  68. x_train, y_train = train_data
  69. x_train = Variable(x_train.to(device))
  70. y_train = Variable(y_train.to(device))
  71. # forward
  72. prediction = model(x_train)
  73. optimizer.zero_grad()
  74. loss = loss_func(prediction, y_train)
  75. running_loss += loss.item()
  76. # backward
  77. loss.backward()
  78. optimizer.step()
  79. scheduler.step()
  80. # saving model
  81. torch.save(model.state_dict(), str(int(running_loss)) + "_" + str(epoch) + ".pth")
  82. save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss)
  83. print("Loss:{}".format(running_loss))
  84. if __name__ == "__main__":
  85. train()

加载部分预训练模型

大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

  1. pretrained_dict = torch.load("model_data/yolo_weights.pth", map_location=device)
  2. model_dict = model.state_dict()
  3. # 将 pretrained_dict 里不属于 model_dict 的键剔除掉
  4. pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict}
  5. #pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
  6. # 更新现有的 model_dict
  7. model_dict.update(pretrained_dict)
  8. # 加载我们真正需要的 state_dict
  9. model.load_state_dict(model_dict)

保存 & 加载模型 来inference

保存/加载state_dict (推荐)

保存:

推荐仅仅保存模型的state_dict,保存的时候文件类型可以是.pt或.pth

  1. torch.save(model.state_dict(), PATH)

加载:

在保存模型进行推理时,只需保存已训练模型的学习参数。

  1. model = TheModelClass(*args, **kwargs)
  2. model.load_state_dict(torch.load(PATH))
  3. model.eval()

注意:在测试前必须使用model.eval()把dropout和batch normalization设为测试模式。

保存/加载整个模型

保存:

  1. torch.save(model, PATH)

加载:

  1. # Model class must be defined somewhere
  2. model = torch.load(PATH)
  3. model.eval()

保存 & 加载一个通用Checkpoint来做测试或恢复训练

保存:

  1. torch.save({
  2. 'epoch': epoch,
  3. 'model_state_dict': model.state_dict(),
  4. 'optimizer_state_dict': optimizer.state_dict(),
  5. 'loss': loss,
  6. ...
  7. }, PATH)

加载:

  1. model = TheModelClass(*args, **kwargs)
  2. optimizer = TheOptimizerClass(*args, **kwargs)
  3. checkpoint = torch.load(PATH)
  4. model.load_state_dict(checkpoint['model_state_dict'])
  5. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  6. epoch = checkpoint['epoch']
  7. loss = checkpoint['loss']
  8. model.eval()
  9. # - or -
  10. model.train()

保存用于检查或继续训练的checkpoint时,不仅要保存模型的state_dict,还要保存优化器的state_dict,因为它包含随着模型训练而更新的缓冲区和参数。其他项目需要保存的还有中断训练时的epoch,最新记录的训练损失,外部torch.nn.Embedding层等。
一般使用字典方式保存这些不同部分,然后使用torch.save()序列化字典。PyTorch约定是使用.tar文件扩展名保存这样的checkpoint,并且文件缀名用.tar。这种方式保存的话比单独保存模型文件大2-3倍。
要加载项目,请首先初始化模型和优化器,然后使用torch.load()加载本地的字典(checkpoint)。
如果做inference,必须调用model.eval()来将dropout和batch normalization层设置为评估模式。不这样做将产生不一致的推断结果。如果希望恢复训练,调用model.train()以确保这些层处于训练模式。

加载不同模型的参数warmstarting

保存:

  1. torch.save(modelA.state_dict(), PATH)

加载:

  1. modelB = TheModelBClass(*args, **kwargs)
  2. modelB.load_state_dict(torch.load(PATH), strict=False)

在转移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见方案。利用经过训练的参数,即使只有少数几个可用的参数,也将有助于热启动训练过程,并希望与从头开始训练相比,可以更快地收敛模型。
无论是从缺少部分key的state_dict加载,还是要使用比要加载的模型有更多的key的state_dict加载都行,只需要在load_state_dict()函数中将strict参数设置为False,以忽略不匹配项键。
如果要将参数从一层加载到另一层,但是某些key不匹配,只需更改要加载的state_dict中参数键的名称,以匹配要加载到的模型中的键。

跨设备保存/加载模型(CPU与GPU)

模型保存在GPU上,加载到CPU

保存

  1. torch.save(model.state_dict(), PATH)

加载:

  1. device = torch.device('cpu')
  2. model = TheModelClass(*args, **kwargs)
  3. model.load_state_dict(torch.load(PATH, map_location=device))

模型保存在GPU上,加载到GPU

保存:

  1. torch.save(model.state_dict(), PATH)

加载:

  1. device = torch.device("cuda")
  2. model = TheModelClass(*args, **kwargs)
  3. model.load_state_dict(torch.load(PATH))
  4. model.to(device)
  5. # Make sure to call input = input.to(device) on any input tensors that you feed to the model

一定要使用.to(torch.device(‘cuda’))将所有输入模型的数据转到GPU上。请注意,调用my_tensor.to(device)会在GPU上返回my_tensor的新副本。它不会覆盖my_tensor。因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device(‘cuda’))。

当然还有CPU上训练GPU来加载的,但这种情况较少,就不放操作了.

内容来自pytorch官网。
要有看官网文档的心,打破畏难情绪,然后回发现看doc真的还挺简单。

发表评论

表情:
评论列表 (有 0 条评论,351人围观)

还没有评论,来说两句吧...

相关阅读

    相关 PyTorch模型保存

    PyTorch模型保存与加载 在利用PyTorch构建深度学习模型时,模型的保存和加载是非常重要的一步。这不仅可以保证我们的模型得以长期保存和重复使用,还可以方便我们在不同的