资讯详情

PyTorch数据Pipeline标准化代码模板

点击上方“”,选择加""或“

重磅干货,第一时间送达

PyTorch作为一个流行的深度学习框架,它的热度大大超过了TensorFlow感觉。根据之前的统计,目前TensorFlow虽然它仍然占据着工业界,但它仍然占据着工业界PyTorch在视觉和NLP领域顶级会议已经统一。

作者将关注本文PyTorch读取自定义数据pipeline模板和相关trciks以及如何优化数据读取的pipeline等。我们从PyTorch数据对象类Dataset开始。Dataset在PyTorch模块位于中间utils.data下。

from torch.utils.data import Dataset

本文将围绕Dataset对象分别从原始模板,torchvision的transforms模块、使用pandas辅助阅读,torch内置数据划分功能和DataLoader展开阐述。

PyTorch该官员为我们提供了一个标准化的自定义数据读取代码模块,作为一个读取框架,我们称之为原始模板。其代码结构如下:

from torch.utils.data import Dataset class CustomDataset(Dataset):     def __init__(self, ...):         # stuff              def __getitem__(self, index):         # stuff         return (img, label)              def __len__(self):         # return examples size         return count

根据这个标准化的代码模板,我们只需要根据自己的数据读取任务__init__()、__getitem__()和__len__()可以在三种方法中添加读取逻辑。作为PyTorch范式下的数据读取和后续data loader,三种方法缺一不可。其中:

  • __init__()初始用于初始据读取逻辑的函数,如读取包含标签和图片地址的csv文件、定义transform组合等。

  • __getitem__()函数用于返回数据和标签。目的是为了后续工作dataloader所调用。

  • __len__()函数用于返回样本数量。

现在我们在这个框架中填写几行代码来形成一个简单的数字案例。从1到100创建数字示例:

from torch.utils.data import Dataset class CustomDataset(Dataset):     def __init__(self):         self.samples = list(range(1, 101))     def __len__(self):         return len(self.samples)     def __getitem__(self, idx):         return self.samples[idx]          if __name__ == '__main__':     dataset = CustomDataset()     print(len(dataset))     print(dataset[50])     print(dataset[1:100])

65697ecea19ebc4eb29f77b731e570b1.png

然后我们来看看如何从内存中读取数据,以及如何嵌入读取过程torchvision中的transforms功能。torchvision是独立的torch辅助库用于数据、模型和一些图像增强操作。主要包括datasets默认数据集模块models经典模型模块,transforms图像增强模块和utils模块等。在使用torch读取数据时,通常会匹配transforms模块处理和增强数据。

添加了tranforms读取模块可以改写为:

from torch.utils.data import Dataset from torchvision import transforms as T   class CustomDataset(Dataset):     def __init__(self, ...):         # stuff         ...         # compose the transforms methods         self.transform = T.Compose([T.CenterCrop(100),                                 T.ToTensor()])              def __getitem__(self, index):         # stuff         ...         data = # Some data read from a file or image         # execute the transform         data = self.transform(data)           return (img, label)              def __len__(self):         # return examples size         return count          if __name__ == '__main__':     # Call the dataset     custom_dataset = CustomDataset(...)

可以看出,我们使用了它Compose该方法将各种数据处理方法聚合在一起,定义数据转换方法。它通常被用作初始化方法__init__()函数下。以猫狗图像数据为例。

定义数据读取方法如下:

class DogCat(Dataset):         def __init__(self, root, transforms=None, train=True, val=False):         """         get images and execute transforms.         """         self.val = val         imgs = [os.path.join(root, img) for img in os.listdir(root)]         # train: Cats_Dogs/trainset/cat.1.jpg         # val: Cats_Dogs/valset/cat.10004.jpg         imgs = sorted(imgs, key=lambda x: x.split('.')[-2])         self.imgs = imgs                  if transforms is None:             # normalize                   normalize = T.Normalize(mean = [0.485, 0.456, 0.406],                                      std = [0.229, 0.224, 0.225])             # trainset and valset have different data transform              # trainset need data augmentation but valset don't.             # valset               if self.val:                 self.transforms = T.Compose([                     T.Resize(224),                     T.CenterCrop(224),                     T.ToTensor(),                     normalize                 ])             # trainset             else:                 self.transforms = T.Compose([                     T.Resize(256),           T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
                       
    def __getitem__(self, index):
        """
        return data and label
        """
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label
  
    def __len__(self):
        """
        return images size.
        """
        return len(self.imgs)


if __name__ == "__main__":
    train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)
    print(len(train_dataset))
    print(train_dataset[0])

     因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:

     很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:

     此时在数据读取的pipeline中我们需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            transform: pytorch transforms for transforms and tensor conversion
        "
""
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)


    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)


    def __len__(self):
        return self.data_len


if __name__ == "__main__":
    # Call dataset
    dataset =  CustomDatasetFromCSV('./labels.csv')

     以mnist_label.csv文件为示例:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import pandas as pd


class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file            
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = T.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)


    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass


        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)


    def __len__(self):
        return self.data_len


if __name__ == "__main__":
    transform = T.Compose([T.ToTensor()])
    dataset = CustomDatasetFromCSV('./mnist_labels.csv')
    print(len(dataset))
    print(dataset[5])

   运行示例如下:

     一般来说,为了模型训练的稳定,我们需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

     以kaggle的花朵数据为例:

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_split


transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
 ])


dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)


trainset, valset = random_split(dataset, 
                [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])


trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img, label)


valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img.shape, label)

     这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下:

     dataset方法写好之后,我们还需要使用DataLoader将其逐个喂给模型。上一节的数据划分我们已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:

from torch.utils.data import DataLoader
from torchvision import transforms as T


if __name__ == "__main__":
    # Define transforms
    transformations = T.Compose([T.ToTensor()])
    # Define custom dataset
    dataset = CustomDatasetFromCSV('./labels.csv')
    # Define data loader
    data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
    for images, labels in data_loader:
        # Feed the data to the model

     以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。

 
     

开始面向外开放啦👇👇👇

 
     

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

标签: 自动化口罩机传感器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台