PyTorch's serialization and deserialization functions:
torch.save
torch.load
There are two ways to save a model: save the entire model object, or save only its parameters (the state_dict), as sketched below.
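A minimal sketch of the two approaches (the nn.Linear placeholder model and the file names are illustrative assumptions, not from the original article):

import torch
import torch.nn as nn

net = nn.Linear(10, 2)  # placeholder model; substitute your own network

# Method 1: save / load the entire model object (pickles the whole module)
torch.save(net, "./model_whole.pkl")
net_loaded = torch.load("./model_whole.pkl")

# Method 2: save / load only the parameters (state_dict)
torch.save(net.state_dict(), "./model_state_dict.pkl")
state_dict = torch.load("./model_state_dict.pkl")
net.load_state_dict(state_dict)

The checkpoint example that follows uses the second, parameter-only approach.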
Saving parameters (checkpoint):
checkpoint_interval = 5

# The model implementation code is omitted here; this checkpoint-saving block
# belongs in the training section, at the end of each epoch.
if (epoch + 1) % checkpoint_interval == 0:
    checkpoint = {
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch
    }
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
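For orientation, a rough sketch of where that block sits in a typical training loop (MAX_EPOCH, train_loader, criterion, optimizer, and scheduler are placeholder names assumed here, not shown in the original code):

for epoch in range(MAX_EPOCH):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()
    # ... the checkpoint-saving block shown above goes here ...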
Loading a checkpoint and resuming:
Keep the original implementation unchanged and, before the training loop, add the following resume-from-interruption code (e.g. after a power failure):

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
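A sketch of how the resumed training loop might then use start_epoch (again, MAX_EPOCH, train_loader, and criterion are assumed placeholders). If the checkpoint was saved on a GPU but is loaded on a CPU-only machine, torch.load also accepts a map_location argument.

# resume from the epoch after the one stored in the checkpoint
for epoch in range(start_epoch + 1, MAX_EPOCH):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()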
Reposted from: http://ihobb.baihongyu.com/