资讯详情

【PyTorch】09深度体验之图像分类

9 PyTorch深度体验

图像分类(Image Classification)

【PyTorch】8.1 图像分类

9.1 模型如何完成图像分类?

9.2 ResNet18模型实例

用于图像分类的 ResNet 网络结构。参考文献:Kaiming He 等,Deep Residual Learning for Image Recognition(CVPR 2016)。

import os
import time

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

# Run inference on the GPU when one is available, otherwise on the CPU.
# (Assign torch.device("cpu") here to force CPU inference.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- config ----
vis = True   # set to False to skip collecting/plotting the predictions
vis_row = 4  # predictions are plotted in vis_row x vis_row grids

# Per-channel mean / std handed to transforms.Normalize below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before it is fed to the model:
# resize, center-crop to 224x224, convert to a tensor, then normalize.
inference_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
)

# Class names, indexed by the class id the model predicts.
classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """Convert an image into the tensor form the model reads.

    :param img_rgb: PIL Image
    :param transform: torchvision transform to apply; required
    :return: tensor, the result of ``transform(img_rgb)``
    :raises ValueError: if ``transform`` is None
    """
    if transform is not None:
        return transform(img_rgb)
    raise ValueError("找不到transform!必须有transform对img进行处理")
        def get_img_name(img_dir, format="jpg"):
            """Return the names of the files in ``img_dir`` whose name ends with ``format``.

            Matching is case-insensitive (``a.JPG`` matches ``format="jpg"``), only
            regular files are returned (a directory called ``x.jpg`` is skipped), and
            the result is sorted so images are always processed in the same order.

            :param img_dir: str, directory to scan (not recursive)
            :param format: str, filename suffix to keep, e.g. "jpg". The name shadows
                the builtin ``format`` but is kept for backward compatibility.
            :return: list of file names (names only, not full paths)
            :raises ValueError: if no matching file is found
            """
            suffix = format.lower()
            img_names = sorted(
                name
                for name in os.listdir(img_dir)
                if name.lower().endswith(suffix)
                and os.path.isfile(os.path.join(img_dir, name))
            )
            if not img_names:
                raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
            return img_names
        def get_model(m_path, vis_model=False, num_classes=2):
            """Build a ResNet-18 with a ``num_classes``-way head and load trained weights.

            :param m_path: str, path to a checkpoint saved with ``torch.save``; it must be
                a dict holding the network weights under the ``'model_state_dict'`` key
            :param vis_model: bool, if True print the layer/parameter summary produced
                by ``torchsummary.summary`` for a (3, 224, 224) input
            :param num_classes: int, output size of the replacement ``fc`` layer; it must
                match the checkpoint (default 2, the ants/bees task)
            :return: the ResNet-18 module with the checkpoint weights loaded, on the CPU
            """
            resnet18 = models.resnet18()
            # Swap the default classification layer for one sized to our task so the
            # checkpoint's fc weights have a matching shape.
            num_ftrs = resnet18.fc.in_features
            resnet18.fc = nn.Linear(num_ftrs, num_classes)

            # map_location="cpu": a checkpoint saved from a GPU would otherwise fail to
            # load on a CPU-only machine. The caller moves the model to its device.
            # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
            checkpoint = torch.load(m_path, map_location="cpu")
            resnet18.load_state_dict(checkpoint['model_state_dict'])

            if vis_model:
                from torchsummary import summary
                # Print the model structure and per-layer parameter counts.
                summary(resnet18, input_size=(3, 224, 224), device="cpu")
            return resnet18
        if __name__ == "__main__":
            # Images to classify and the trained checkpoint; both paths are relative
            # to the current working directory.
            img_dir = os.path.join("..", "data_set", "hymenoptera_data/val/bees")
            model_path = "./model_checkpoint/checkpoint_14_epoch.pkl"
            time_total = 0
            img_list, img_pred = list(), list()

            # 1. data
            img_names = get_img_name(img_dir)
            num_img = len(img_names)

            # 2. model
            resnet18 = get_model(model_path, True)
            resnet18.to(device)  # move the model to the selected device (GPU or CPU)
            resnet18.eval()      # switch the model to evaluation (inference) mode

            # CUDA kernels are launched asynchronously: without a synchronize the clock
            # would be read before the forward pass has actually finished.
            use_cuda = device.type == "cuda"

            with torch.no_grad():  # inference only: no gradients -> less memory, faster
                for idx, img_name in enumerate(img_names):
                    path_img = os.path.join(img_dir, img_name)

                    # step 1/4 : path --> img
                    # The with-block closes the file; convert() returns a separate image.
                    with Image.open(path_img) as img:
                        img_rgb = img.convert('RGB')

                    # step 2/4 : img --> tensor (the model's input format)
                    img_tensor = img_transform(img_rgb, inference_transform)
                    img_tensor.unsqueeze_(0)  # add a leading batch dim -> 4D model input
                    img_tensor = img_tensor.to(device)

                    # step 3/4 : tensor --> vector
                    if use_cuda:
                        torch.cuda.synchronize()  # don't bill the pending copy to the model
                    time_tic = time.perf_counter()
                    outputs = resnet18(img_tensor)
                    if use_cuda:
                        torch.cuda.synchronize()  # wait until the forward pass completes
                    time_toc = time.perf_counter()

                    # step 4/4 : visualization
                    _, pred_int = torch.max(outputs, 1)
                    pred_str = classes[int(pred_int)]

                    if vis:
                        img_list.append(img_rgb)
                        img_pred.append(pred_str)
                        # Plot once a vis_row x vis_row grid is full, or after the last image.
                        if (idx + 1) % (vis_row * vis_row) == 0 or num_img == idx + 1:
                            for i, (shown, pred) in enumerate(zip(img_list, img_pred)):
                                plt.subplot(vis_row, vis_row, i + 1).imshow(shown)
                                plt.title("predict:{}".format(pred))
                            plt.show()
                            plt.close()
                            img_list, img_pred = list(), list()

                    time_s = time_toc - time_tic
                    time_total += time_s
                    print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

            print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".format(
                device, time_total, time_total / num_img))
            if torch.cuda.is_available():
                print("GPU name:{}".format(torch.cuda.get_device_name()))
       

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0

标签: 二极管db220b

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台