查看: 1754|回复: 0

在PyTorch中保存训练模型的最佳方法

[复制链接]

27

主题

37

帖子

116

积分

论坛管理

Rank: 4

积分
116
发表于 2018-10-15 15:26:20 | 显示全部楼层 |阅读模式
序列化和恢复模型有两种主要方法。 第一个(推荐)保存并仅加载模型参数:
  1. torch.save(the_model.state_dict(), PATH)
复制代码
然后:
  1. the_model = TheModelClass(*args, **kwargs)
  2. the_model.load_state_dict(torch.load(PATH))
复制代码
第二个保存并加载整个模型:
  1. torch.save(the_model, PATH)
复制代码
然后:
  1. the_model = torch.load(PATH)
复制代码
如果您需要继续训练您要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器,时期,分数等的状态。您可以这样做:
  1. state = {
  2.     'epoch': epoch,
  3.     'state_dict': model.state_dict(),
  4.     'optimizer': optimizer.state_dict(),
  5.     ...
  6. }
  7. torch.save(state, filepath)
复制代码


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表