Pytorch之模型加载/保存

た 入场券 2022-11-27 00:51 434阅读 0赞

pytorch保存模型有两种方法:

  1. 保存整个模型 (结构+参数)
  2. 只保存参数(官方推荐)

两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。
两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。

保存整个模型

这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。

  1. # 保存
  2. model = Mymodel()
  3. torch.save(model, path)
  4. # 加载
  5. model = torch.load(path)

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

保存参数

重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:

  • 模型参数
  • 优化器参数
  • loss
  • epoch
  • args

把这些信息用字典包装起来,然后保存即可。

这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:

  1. # 获得保存信息
  2. save_data = {
  3. 'model_state_dict': model.state_dict(),
  4. 'optimizer_state_dict': optimizer.state_dict(),
  5. 'loss': loss,
  6. 'epoch': epoch,
  7. 'args': args
  8. ...
  9. }
  10. # 保存
  11. torch.save(save_data , path)
  12. load_data = torch.load(path)
  13. model = Mymodel()
  14. optimizer = Myoptimizer()
  15. # 加载参数
  16. model.load_state_dict(load_data ['model_state_dict'])
  17. optimizer.load_state_dict(load_data ['optimizer_state_dict'])
  18. ...

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

发表评论

表情:
评论列表 (有 0 条评论,434人围观)

还没有评论,来说两句吧...

相关阅读

    相关 PyTorch模型保存

    PyTorch模型保存与加载 在利用PyTorch构建深度学习模型时,模型的保存和加载是非常重要的一步。这不仅可以保证我们的模型得以长期保存和重复使用,还可以方便我们在不同的