查看: 2887|回复: 1

torch.load的坑

[复制链接]

665

主题

1234

帖子

6683

积分

xdtech

Rank: 5Rank: 5

积分
6683
发表于 2020-5-13 08:34:14 | 显示全部楼层 |阅读模式
使用torch.load的时候遇见两个坑
第一个是遇到如下错误:

ModuleNotFoundError: No module named 'models'
1
官方说明如下:the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
也就是说,当使用这个函数的时候,pytorch序列化的是参数以及model class的路径。所以使用这个函数之前,必须保证定义model的文件目录结构相同。一般来说,model.py只要和要运行的py文件在同个目录就不会报这个错。

第二个是如果使用GPU,那么load的时候会默认将模型放置到同样编号的GPU上去。例如训练的时候使用的是cuda:0,那么使用torch.load之后,模型会默认放置到0号GPU上去,这时候,即使再使用to(“cuda:1”),模型仍然会占用0号GPU的显存。

如果要完全迁移到1号GPU,则应该使用:

torch.load("....pth", map_location={'cuda:0':"cuda:1"})
1
这样模型便不会再占用0号GPU的显存


回复

使用道具 举报

665

主题

1234

帖子

6683

积分

xdtech

Rank: 5Rank: 5

积分
6683
 楼主| 发表于 2020-5-13 08:34:20 | 显示全部楼层
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表