pytorch模型的保存与加载
只保存网络的参数
Save
1 | torch.save(model.state_dict(), PATH) |
Load
1 | model.load_state_dict(torch.load(PATH)) |
注意:
- 模型的加载需要使用
torch.load
加载路径下的文件,逆序列化dict
保存全部的信息
Save
1 | torch.save({ |
Load
1 | checkpoint = torch.load(PATH) |
Reference
https://blog.csdn.net/dss_dssssd/article/details/89409183