9 PyTorch深度体验
图像分类(Image Classification)
【PyTorch】8.1 图像分类
9.1 模型如何完成图像分类?
9.2 ResNet18模型实例
图像分类 ResNet 网络结构。参考文献:He, Zhang, Ren, Sun, "Deep Residual Learning for Image Recognition", CVPR 2016.
"""Run a fine-tuned ResNet18 over a folder of images and report predictions.

The script loads a two-class (ants / bees) ResNet18 checkpoint, classifies
every ``.jpg`` file in ``img_dir``, optionally shows the predictions in a
``vis_row`` x ``vis_row`` matplotlib grid, and prints per-image and mean
forward-pass time.
"""
import os
import time

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# config
vis = True      # show prediction grids with matplotlib
vis_row = 4     # the grid is vis_row x vis_row images per figure

# Per-channel mean / std passed to transforms.Normalize below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Data preprocessing applied to every image before inference.
inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Class labels; the index in this list is the model's output index.
classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """Convert an image into the tensor form the model consumes.

    :param img_rgb: PIL Image
    :param transform: torchvision transform (callable) applied to the image
    :return: tensor produced by ``transform``
    :raises ValueError: if ``transform`` is None
    """
    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")
    return transform(img_rgb)


def get_img_name(img_dir, format="jpg"):
    """Return the names of the files in ``img_dir`` ending with ``format``.

    The parameter name ``format`` shadows the builtin; it is kept so that
    existing keyword callers continue to work.

    :param img_dir: str, directory to list
    :param format: str, file-name suffix to keep (case-sensitive)
    :return: list of matching file names (not full paths)
    :raises ValueError: if no file in the directory matches
    """
    img_names = [name for name in os.listdir(img_dir) if name.endswith(format)]
    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):
    """Build a ResNet18 with a ``len(classes)``-way head and load its weights.

    :param m_path: str, path to a checkpoint holding a ``'model_state_dict'`` entry
    :param vis_model: bool, if True print a layer summary via ``torchsummary``
    :return: the model, with weights loaded, located on the CPU
    """
    resnet18 = models.resnet18()
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, len(classes))

    # map_location="cpu" lets a checkpoint saved from a GPU be restored on a
    # machine without CUDA; the caller moves the model to `device` afterwards.
    checkpoint = torch.load(m_path, map_location="cpu")
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    if vis_model:
        # Imported lazily so torchsummary is only required when requested.
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


def _sync_device():
    """Wait for pending CUDA work so wall-clock timings are meaningful."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _show_predictions(images, labels):
    """Plot ``images`` in a vis_row x vis_row grid titled with ``labels``."""
    for i, (img, label) in enumerate(zip(images, labels)):
        plt.subplot(vis_row, vis_row, i + 1).imshow(img)
        plt.title("predict:{}".format(label))
    plt.show()
    plt.close()


if __name__ == "__main__":
    # Location of the images to classify and of the trained checkpoint.
    img_dir = os.path.join("..", "data_set", "hymenoptera_data/val/bees")
    model_path = "./model_checkpoint/checkpoint_14_epoch.pkl"

    time_total = 0
    img_list, img_pred = [], []

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)

    # 2. model
    resnet18 = get_model(model_path, True)
    resnet18.to(device)   # move the model to the selected device
    resnet18.eval()       # evaluation mode

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        for idx, img_name in enumerate(img_names):
            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img (close the file handle once decoded)
            with Image.open(path_img) as img_file:
                img_rgb = img_file.convert('RGB')

            # step 2/4 : img --> tensor, with a leading batch dimension (4D)
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor = img_tensor.unsqueeze(0).to(device)

            # step 3/4 : tensor --> vector
            # CUDA kernels run asynchronously, so synchronize around the
            # forward pass to time the actual computation.
            _sync_device()
            time_tic = time.time()
            outputs = resnet18(img_tensor)
            _sync_device()
            time_toc = time.time()

            # step 4/4 : visualization
            pred_int = outputs.argmax(dim=1).item()
            pred_str = classes[pred_int]

            if vis:
                img_list.append(img_rgb)
                img_pred.append(pred_str)

                # Flush a figure when the grid is full or on the last image.
                if (idx + 1) % (vis_row * vis_row) == 0 or num_img == idx + 1:
                    _show_predictions(img_list, img_pred)
                    img_list, img_pred = [], []

            time_s = time_toc - time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".format(
        device, time_total, time_total / num_img))
    if torch.cuda.is_available():
        print("GPU name:{}".format(torch.cuda.get_device_name()))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 128, 28, 28] 73,728
BatchNorm2d-20 [-1, 128, 28, 28] 256
ReLU-21 [-1, 128, 28, 28] 0
Conv2d-22 [-1, 128, 28, 28] 147,456
BatchNorm2d-23 [-1, 128, 28, 28] 256
Conv2d-24 [-1, 128, 28, 28] 8,192
BatchNorm2d-25 [-1, 128, 28, 28] 256
ReLU-26 [-1, 128, 28, 28] 0
BasicBlock-27 [-1, 128, 28, 28] 0
Conv2d-28 [-1, 128, 28, 28] 147,456
BatchNorm2d-29 [-1, 128, 28, 28] 256
ReLU-30 [-1, 128, 28, 28] 0
Conv2d-31 [-1, 128, 28, 28] 147,456
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 256, 14, 14] 294,912
BatchNorm2d-36 [-1, 256, 14, 14] 512
ReLU-37 [-1, 256, 14, 14] 0
Conv2d-38 [-1, 256, 14, 14] 589,824
BatchNorm2d-39 [-1, 256, 14, 14] 512
Conv2d-40 [-1, 256, 14, 14] 32,768
BatchNorm2d-41 [-1, 256, 14, 14] 512
ReLU-42 [-1, 256, 14, 14] 0
BasicBlock-43 [-1, 256, 14, 14] 0
Conv2d-44 [-1, 256, 14, 14] 589,824
BatchNorm2d-45 [-1, 256, 14, 14] 512
ReLU-46 [-1, 256, 14, 14] 0
Conv2d-47 [-1, 256, 14, 14] 589,824
BatchNorm2d-48 [-1, 256, 14, 14] 512
ReLU-49 [-1, 256, 14, 14] 0
BasicBlock-50 [-1, 256, 14, 14] 0
Conv2d-51 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-52 [-1, 512, 7, 7] 1,024
ReLU-53 [-1, 512, 7, 7] 0
Conv2d-54 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-55 [-1, 512, 7, 7] 1,024
Conv2d-56 [-1, 512, 7, 7] 131,072
BatchNorm2d-57 [-1, 512, 7, 7] 1,024
ReLU-58 [-1, 512, 7, 7] 0
BasicBlock-59 [-1, 512, 7, 7] 0
Conv2d-60 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-61 [-1, 512, 7, 7] 1,024
ReLU-62 [-1, 512, 7, 7] 0
Conv2d-63 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-64 [-1, 512, 7, 7] 1,024
ReLU-65 [-1, 512, 7,<