pytorch保存模型非常简单,主要有两种方法:
- 只保存参数;(官方推荐)
- 保存整个模型 (结构 参数)。 由于保存整个模型需要大量的存储,官方建议只保存参数,然后在构建模型的基础上加载。本文介绍了两种方法,但只详细说明了第一种方法。
-
只保存参数
1.保存
一般来说,参数可以通过一个句子来保存:
-
torch.save(model.state_dict(), path)
其中model定义模型,如 model=vgg16( ), path是保存参数的路径,如 path='./model.pth' , path='./model.tar', path='./model.pkl', 保存参数的文件必须有后缀扩展名。
特别是,如果你想保存某个训练中使用的优化器,epochs这些信息可以组合成字典,然后保存字典:
-
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} torch.save(state, path)
2.加载
对于上述第一种情况,加载模型只需一句话:
model.load_state_dict(torch.load(path))
对于以字典形式保存的第二种方法,加载方法如下:
checkpoint = torch.load(path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) epoch = checkpoint(['epoch'])
需要注意的是,只保存参数的方法应在加载时提前定义与原模型一致的模型,并在模型的实例对象中(假设称为model)在使用上述加载语句之前,已经定义了与原模型相同的加载语句Net, 并实例化 model=Net( ) 。
另外,如果每一个epoch或每n个epoch应保存一次参数,并可设置不同的参数path,如 path='./model' str(epoch) '.pth这样,就不一样了epoch参数可以保存在不同的文件中,保存识别率最高的模型参数也可以选择,只需在保存模型句子之前添加一个if判断句子。
以下是保存最新参数的具体例子:
#-*- coding:utf-8 -*- 本文件用于举例说明pytorch保存和加载文件的方法 __author__ = 'puxitong from UESTC' import torch as torch import torchvision as tv import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torchvision.transforms as transforms from torchvision.transforms import ToPILImage import torch.backends.cudnn as cudnn import datetime import argparse # 参数声明 batch_size = 32 epochs = 10 WORKERS = 0 # dataloder线程数 test_flag = True #测试标志,True加载保存的模型进行测试 ROOT = '/home/pxt/pytorch/cifar' # MNIST数据集保存路径 log_dir = '/home/pxt/pytorch/logs/cifar_model.pth' # 模型保存路径 # 加载MNIST数据集 transform = tv.transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform) test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform) train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS) test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS) # 构造模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 64, 3, padding=1) self.conv2 = nn.Conv2d(64, 128, 3, padding=1) self.conv3 = nn.Conv2d(128, 256, 3, padding=1) self.conv4 = nn.Conv2d(256, 256, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(256 * 8 * 8, 1024) self.fc2 = nn.Linear(1024, 256) self.fc3 = nn.Linear(256, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = self.pool(F.relu(self.conv2(x))) x = F.relu(self.conv3(x)) x = self.pool(F.relu(self.conv4(x))) x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3]) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = Net().cuda() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 模型训练 def train(model, train_loader, epoch): model.train() train_loss = 0 for i, data in enumerate(train_loader, 0): x, y = data x = x.cuda() y = y.cuda() optimizer.zero_grad() y_hat = model(x) loss = criterion(y_hat, y) loss.backward() optimizer.step() train_loss = loss loss_mean = train_loss / (i 1) print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item())) # 模型测试 def test(model, test_loader): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for i, data in enumerate(test_loader, 0): x, y = data x = x.cuda() y = y.cuda() optimizer.zero_grad() y_hat = model(x) test_loss = criterion(y_hat, y).item() pred = y_hat.max(1, keepdim=True)[1] correct = pred.eq(y.view_as(pred)).sum().item() test_loss /= (i 1) print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_data), 100. * correct / len(test_data))) def main(): # 如果test_flag=True,加载保存模型 if test_flag: # 直接验证加载保存模型,不执行本模块后续步骤 checkpoint = torch.load(log_dir) model.loadstate_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epochs = checkpoint['epoch']
test(model, test_load)
return
for epoch in range(0, epochs):
train(model, train_load, epoch)
test(model, test_load)
# 保存模型
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, log_dir)
if __name__ == '__main__':
main()
3.在加载的模型基础上继续训练
在训练模型的时候可能会因为一些问题导致程序中断,或者常常需要观察训练情况的变化来更改学习率等参数,这时候就需要加载中断前保存的模型,并在此基础上继续训练,这时候只需要对上例中的 main() 函数做相应的修改即可,修改后的 main() 函数如下:
def main():
# 如果test_flag=True,则加载已保存的模型
if test_flag:
# 加载保存的模型直接进行测试机验证,不进行此模块以后的步骤
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
test(model, test_load)
return
# 如果有保存的模型,则加载模型,并在其基础上继续训练
if os.path.exists(log_dir):
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
print('加载 epoch {} 成功!'.format(start_epoch))
else:
start_epoch = 0
print('无保存模型,将从头开始训练!')
for epoch in range(start_epoch+1, epochs):
train(model, train_load, epoch)
test(model, test_load)
# 保存模型
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, log_dir)
二、保存整个模型 1.保存
torch.save(model, path)
2.加载
model = torch.load(path)
用法可参照上例。
这篇博客是一个快速上手指南,想深入了解PyTorch保存和加载模型中的相关函数和方法,请移步这篇博客:
PyTorch模型保存深入理解 - 简书 (jianshu.com)