torch.load的坑
使用torch.load的时候遇见两个坑第一个是遇到如下错误:
ModuleNotFoundError: No module named 'models'
1
官方说明如下:the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
也就是说,当使用这个函数的时候,pytorch序列化的是参数以及model class的路径。所以使用这个函数之前,必须保证定义model的文件目录结构相同。一般来说,model.py只要和要运行的py文件在同个目录就不会报这个错。
第二个是如果使用GPU,那么load的时候会默认将模型放置到同样编号的GPU上去。例如训练的时候使用的是cuda:0,那么使用torch.load之后,模型会默认放置到0号GPU上去,这时候,即使再使用to(“cuda:1”),模型仍然会占用0号GPU的显存。
如果要完全迁移到1号GPU,则应该使用:
torch.load("....pth", map_location={'cuda:0':"cuda:1"})
1
这样模型便不会再占用0号GPU的显存
https://blog.csdn.net/j___t/java/article/details/99618915
页:
[1]