资讯详情

Pytorch模型保存与加载,并在加载的模型基础上继续训练

pytorch保存模型非常简单,主要有两种方法:

  1. 只保存参数;(官方推荐)
  2. 保存整个模型 (结构 参数)。 由于保存整个模型需要大量的存储,官方建议只保存参数,然后在构建模型的基础上加载。本文介绍了两种方法,但只详细说明了第一种方法。
  • 只保存参数

    1.保存

    一般来说,参数可以通过一个句子来保存:

  • torch.save(model.state_dict(), path)

    其中model定义模型,如 model=vgg16( ), path是保存参数的路径,如 path='./model.pth' , path='./model.tar', path='./model.pkl', 保存参数的文件必须有后缀扩展名。

    特别是,如果你想保存某个训练中使用的优化器,epochs这些信息可以组合成字典,然后保存字典:

  • state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} torch.save(state, path)

    2.加载

    对于上述第一种情况,加载模型只需一句话:

model.load_state_dict(torch.load(path))

对于以字典形式保存的第二种方法,加载方法如下:

checkpoint = torch.load(path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) epoch = checkpoint(['epoch'])

需要注意的是,只保存参数的方法应在加载时提前定义与原模型一致的模型,并在模型的实例对象中(假设称为model)在使用上述加载语句之前,已经定义了与原模型相同的加载语句Net, 并实例化 model=Net( ) 。

另外,如果每一个epoch或每n个epoch应保存一次参数,并可设置不同的参数path,如 path='./model' str(epoch) '.pth这样,就不一样了epoch参数可以保存在不同的文件中,保存识别率最高的模型参数也可以选择,只需在保存模型句子之前添加一个if判断句子。

以下是保存最新参数的具体例子:

#-*- coding:utf-8 -*-  本文件用于举例说明pytorch保存和加载文件的方法  __author__ = 'puxitong from UESTC'   import torch as torch import torchvision as tv import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torchvision.transforms as transforms from torchvision.transforms import ToPILImage import torch.backends.cudnn as cudnn import datetime import argparse  # 参数声明 batch_size = 32 epochs = 10 WORKERS = 0   # dataloder线程数 test_flag = True  #测试标志,True加载保存的模型进行测试  ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径 log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径  # 加载MNIST数据集 transform = tv.transforms.Compose([         transforms.ToTensor(),         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])  train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform) test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)  train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS) test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)   # 构造模型 class Net(nn.Module):     def __init__(self):         super(Net, self).__init__()         self.conv1 = nn.Conv2d(3, 64, 3, padding=1)         self.conv2 = nn.Conv2d(64, 128, 3, padding=1)         self.conv3 = nn.Conv2d(128, 256, 3, padding=1)         self.conv4 = nn.Conv2d(256, 256, 3, padding=1)         self.pool = nn.MaxPool2d(2, 2)         self.fc1 = nn.Linear(256 * 8 * 8, 1024)         self.fc2 = nn.Linear(1024, 256)         self.fc3 = nn.Linear(256, 10)               def forward(self, x):         x = F.relu(self.conv1(x))         x = self.pool(F.relu(self.conv2(x)))         x = F.relu(self.conv3(x))         x = self.pool(F.relu(self.conv4(x)))         x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])         x = F.relu(self.fc1(x))         x = F.relu(self.fc2(x))         x = self.fc3(x)         return x   model = Net().cuda() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01)   # 模型训练 def train(model, train_loader, epoch):     model.train()     train_loss = 0     for i, data in enumerate(train_loader, 0):         x, y = data         x = x.cuda()         y = y.cuda()         optimizer.zero_grad()         y_hat = model(x)         loss = criterion(y_hat, y)         loss.backward()         optimizer.step()         train_loss  = loss     loss_mean = train_loss / (i 1)     print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))  # 模型测试 def test(model, test_loader):     model.eval()     test_loss = 0     correct = 0     with torch.no_grad():         for i, data in enumerate(test_loader, 0):             x, y = data             x = x.cuda()             y = y.cuda()             optimizer.zero_grad()             y_hat = model(x)             test_loss  = criterion(y_hat, y).item()             pred = y_hat.max(1, keepdim=True)[1]             correct  = pred.eq(y.view_as(pred)).sum().item()         test_loss /= (i 1)         print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(             test_loss, correct, len(test_data), 100. * correct / len(test_data)))   def main():      # 如果test_flag=True,加载保存模型     if test_flag:         # 直接验证加载保存模型,不执行本模块后续步骤         checkpoint = torch.load(log_dir)         model.loadstate_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epochs = checkpoint['epoch']
        test(model, test_load)
        return

    for epoch in range(0, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)

if __name__ == '__main__':
    main()

3.在加载的模型基础上继续训练

在训练模型的时候可能会因为一些问题导致程序中断,或者常常需要观察训练情况的变化来更改学习率等参数,这时候就需要加载中断前保存的模型,并在此基础上继续训练,这时候只需要对上例中的 main() 函数做相应的修改即可,修改后的 main() 函数如下:

def main():

    # 如果test_flag=True,则加载已保存的模型
    if test_flag:
        # 加载保存的模型直接进行测试机验证,不进行此模块以后的步骤
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return

    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)

二、保存整个模型 1.保存

torch.save(model, path)

2.加载

model = torch.load(path)

用法可参照上例。

这篇博客是一个快速上手指南,想深入了解PyTorch保存和加载模型中的相关函数和方法,请移步这篇博客:

PyTorch模型保存深入理解 - 简书 (jianshu.com)

标签: 三极管pxt8550贴片sot

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台