查看: 1448|回复: 1

pytorch之保存与加载模型

[复制链接]

665

主题

1234

帖子

6568

积分

xdtech

Rank: 5Rank: 5

积分
6568
发表于 2020-5-13 08:34:56 | 显示全部楼层 |阅读模式
pytorch与保存、加载模型有关的常用函数3个:
  • torch.save(): 保存一个序列化的对象到磁盘,使用的是Python的pickle库来实现的。
  • torch.load(): 解序列化一个pickled对象并加载到内存当中。
  • torch.nn.Module.load_state_dict(): 加载一个解序列化的state_dict对象
1. state_dict在PyTorch中所有可学习的参数保存在model.parameters()中。state_dict是一个Python字典。保存了各层与其参数张量之间的映射。torch.optim对象也有一个state_dict,它包含了optimizer的state,以及一些超参数。
2. 保存&加载模型来inference(recommended)save
torch.save(model.state_dict(), PATH)
load
model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH))model.eval()  # 当用于inference时不要忘记添加
  • 保存的文件名后缀可以是.pt或.pth
  • 当用于inference时不要忘记添加model.eval()
3. 保存&加载整个模型(not recommended)save
torch.save(model, PATH)
load
# Model class must be defined somewheremodel = torch.load()model.eval()
4. 保存&加载带checkpoint的模型用于inference或resuming trainingsave
torch.save({  'epoch': epoch,  'model_state_dict': model.state_dict(),  'optimizer_state_dict': optimizer.state_dict(),  'loss': loss,  ...  }, PATH)
load
model = TheModelClass(*args, **kwargs)optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])epoch = checkpoint['epoch']loss = checkpoint['loss']model.eval()# ormodel.train()
5. 保存多个模型到一个文件中save
torch.save({  'modelA_state_dict': modelA.state_dict(),  'modelB_state_dict': modelB.state_dict(),  'optimizerA_state_dict': optimizerA.state_dict(),  'optimizerB_state_dict': optimizerB.state_dict(),  ...  }, PATH)
load
modelA = TheModelAClass(*args, **kwargs)modelB = TheModelAClass(*args, **kwargs)optimizerA = TheOptimizerAClass(*args, **kwargs)optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)modelA.load_state_dict(checkpoint['modelA_state_dict']modelB.load_state_dict(checkpoint['modelB_state_dict']optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']modelA.eval()modelB.eval()# ormodelA.train()modelB.train()
  • 此情况可能在GAN,Sequence-to-sequence,或ensemble models中使用
  • 保存checkpoint常用.tar文件扩展名
6. Warmstarting Model Using Parameters From A Different Modelsave
torch.save(modelA.state_dict(), PATH)
load
modelB = TheModelBClass(*args, **kwargs)modelB.load_state_dict(torch.load(PATH), strict=False)
  • 在迁移训练时,可能希望只加载部分模型参数,此时可置strict参数为False来忽略那些没有匹配到的keys
7. 保存&加载模型跨设备(1) Save on GPU, Load on CPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cpu")model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, map_location=device))
(2) Save on GPU, Load on GPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH))model.to(device)
(3) Save on CPU, Load on GPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, map_location="cuda:0"))model.to(device)
8. 保存torch.nn.DataParallel模型save
torch.save(model.module.state_dict(), PATH)
load
# Load to whatever device you want





回复

使用道具 举报

665

主题

1234

帖子

6568

积分

xdtech

Rank: 5Rank: 5

积分
6568
 楼主| 发表于 2020-5-13 08:35:06 | 显示全部楼层

作者:zhaoQiang012
链接:https://www.jianshu.com/p/60fc57e19615
来源:简书
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表