|
序列化和恢复模型有两种主要方法。 第一个(推荐)保存并仅加载模型参数:
- torch.save(the_model.state_dict(), PATH)
复制代码 然后:
- the_model = TheModelClass(*args, **kwargs)
- the_model.load_state_dict(torch.load(PATH))
复制代码 第二个保存并加载整个模型:
- torch.save(the_model, PATH)
复制代码 然后:
- the_model = torch.load(PATH)
复制代码 如果您需要继续训练您要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器,时期,分数等的状态。您可以这样做:
- state = {
- 'epoch': epoch,
- 'state_dict': model.state_dict(),
- 'optimizer': optimizer.state_dict(),
- ...
- }
- torch.save(state, filepath)
复制代码
|
|