pytorch模型的保存与加载

只保存网络的参数

Save

1
torch.save(model.state_dict(), PATH)

Load

1
model.load_state_dict(torch.load(PATH))

注意:

  • 模型的加载需要使用torch.load加载路径下的文件,逆序列化dict

保存全部的信息

Save

1
2
3
4
5
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
...
}, PATH)

Load

1
2
3
checkpoint = torch.load(PATH)
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])

Reference

https://blog.csdn.net/dss_dssssd/article/details/89409183