资讯详情

pytorch数据集相关模块

torchvision 使用中数据集

数据集和transfroms内容结合在一起

使用标准数据集

torchvision.datasets 官方指导文件:https://pytorch.org/vision/stable/datasets.html

torchvision中的数据集

关于torchvision中的datasets,选择标准数据的类型有很多,但需要说明每个数据集中参数

例子:

CIFAR10数据集

参数

  • ( string ) – 根目录的数据集, cifar-10-batches-py如果下载设置为 True,该目录存在或将保存在该目录中。
  • ( bool , optional ) – 如果为真,则从训练集创建数据集,否则从测试集创建。
  • ( callable , optional ) – 它接收函数/转换 PIL 图像并返回转换版本。transforms.RandomCrop
  • ( callable , optional ) – 函数/转换接收目标并转换。
  • ( bool , optional ) – 如果为 true,则从 Internet 将数据集下载并放入根目录中。若数据集已下载,则不再下载。
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) 

以CIFAR以数据集为例

import torchvision  # torchvision中CIFAR10数据集下载 train_set = torchvision.datasets.CIFAR10(root="./DataSet",train=True, download=True) test_set = torchvision.datasets.CIFAR10(root="./DataSet",train=False, download=True)  # 所有输出数据集classes print(test_set.classes)  # 将test_set提取第一个数据,img保存图片,target保存相应的标签 img, target = test_set[0] img.show() print(test_set[0]) print(img)
print(target)

输出:
# 已经下载过的数据不会重复下载
Files already downloaded and verified
Files already downloaded and verified
#test_set 中的所有classes 
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# test_set[0]
(<PIL.Image.Image image mode=RGB size=32x32 at 0x201A62E97B8>, 3)
# img
<PIL.Image.Image image mode=RGB size=32x32 at 0x201A62E95F8>
# target
3

简单介绍一下CIFAR10数据集:

CIFAR-10 数据集由 10 个类别的 60000 个 32x32 彩色图像组成,每个类别包含 6000 个图像。有 50000 个训练图像和 10000 个测试图像。

数据集分为五个训练批次和一个测试批次,每个批次有 10000 张图像。测试批次恰好包含来自每个类别的 1000 个随机选择的图像。训练批次包含随机顺序的剩余图像,但一些训练批次可能包含来自一个类的图像多于另一个。在它们之间,训练批次恰好包含来自每个类别的 5000 张图像。

Set与Transforms联动

将Transforms,torchvision的数据预处理功能同dataset进行结合,将数据集中数据进行数据类型转换,并使用torchvision的tensorboard小工具进行数据类型转换后的记录


import torchvision

# 数据集中数据为PIL.Image类型,使用ToTensor进行类型转换
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# 下载数据集,在其中参数规定数据的处理方式
train_set = torchvision.datasets.CIFAR10(root="./DataSet", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./DataSet", train=False, transform=dataset_transform, download=True)

print(type(test_set[0]))

# 可以看到数据类型为tensor,使用tensorboard进行记录
writer = SummaryWriter("Set_logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image(tag="Top10", img_tensor=img, global_step=i)
writer.close()

在terminal中tensorboard --logdir=Set_logs进入tensorboard界面进行查看

DataLoader的使用

与Dataset的不同,dataset是现成的已经完备的数据集,dataloader是数据加载器,dataloader所作的是从dataset中取数据,如何去、取多少,由dataloader中参数进行控制

官网中的信息:

数据加载器。结合数据集和采样器,并提供给定数据集的可迭代对象。

支持具有单进程或多进程加载、自定义加载顺序和可选的自动批处理(整理)和内存固定的地图样式和可迭代样式数据集。DataLoader

参数

  • ( Dataset ) – 从中加载数据的数据集。
  • ( int , optional ) – 每批要加载多少样本(默认值:1)。
  • ( bool , optional ) – 设置为True在每个 epoch 重新洗牌数据(默认值:)False
  • ( Sampler or Iterable , optional ) – 定义从数据集中抽取样本的策略。可以是任何已 实施的Iterable__len__如果指定,则shuffle不得指定。
  • Sampler或**Iterable *,*可选)——类似sampler,但一次返回一批索引。batch_size与、shufflesampler和互斥 drop_last
  • ( int , optional ) – 用于数据加载的子进程数。0表示数据将在主进程中加载。(默认0:)
  • ( callable , optional ) – 合并样本列表以形成小批量张量。从地图样式数据集中使用批量加载时使用。
  • ( bool , optional ) – 如果True是,数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您collate_fn返回的批次是自定义类型,请参见下面的示例。
  • ( bool , optional ) –True如果数据集大小不能被批次大小整除,则设置为丢弃最后一个不完整的批次。如果False数据集的大小不能被批大小整除,那么最后一批将更小。(默认False:)
  • ( numeric , optional ) – 如果为正,则从工人那里收集批次的超时值。应始终为非负数。(默认0:)
  • ( callable , optional ) - 如果不是None,这将在每个工作子进程上调用,并在播种之后和数据加载之前以工作人员 ID(一个 int in )作为输入。(默认:)[0, num_workers - 1]``None
  • torch.Generator*,*可选)- 如果不是None,则 RandomSampler 将使用此 RNG 生成随机索引和多处理以生成 工作人员的base_seed。(默认None:)
  • ( int , optional , keyword-only arg ) – 每个工作人员预先加载的样本数。2意味着将在所有工作人员中预取总共 2 * num_workers 个样本。(默认2:)
  • ( bool , optional ) – 如果True是 ,数据加载器将不会在数据集被使用一次后关闭工作进程。这允许保持工作人员数据集实例处于活动状态。(默认False:)

DataLoader()中参数较多,但其中大部分就有默认值。

作用

DataLoader()中仅有参数 dataset不具备默认值,是为了方便使用者使用非官方自制的数据集。

使用示例

batch_size=4参数的具体意义

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-th8ZhBFA-1652926054632)(D:\STUDY\神经网络和深度学习\images\batch_size的具体含义.png)]

dataset[0]返回(img,target),batch_size=4将每4个dataset同类型数据打包 及 imgs[img0,img1,img2,img3],target[target0,target1,target2,target3]

""" 使用dataloader加载dataset中CIFAR10数据集,并将取打包重新排列,变化过程放入tensorboard中 """
import torchvision

# 使用dataset中CIFAR10数据集的测试集做为演示
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10(root="./DataSet", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# 测试数据集中第一张图片样本及其归类
img, target = test_data[0]
print(img.shape)    #torch.Size([3, 32, 32])
print(target)       #3

# 参数batch_size=4 的具体含义就是将test_data中每4个元素同类型的数据进行打包
""" imgs, targets = test_loader[0] print(imgs.shape) print(targets) TypeError: 'DataLoader' object does not support indexing """

writer = SummaryWriter("dataloader_logs")
step = 0
for data in test_loader:
    imgs,targets = data
    #print(imgs.shape) # torch.Size([4, 3, 32, 32])
    #print(targets) # tensor([1, 6, 4, 1])
    writer.add_images("test_data",imgs,step)
    step += 1

writer.close()


以上就是dataloader相关的简单操作,详细具体的使用还要参考源码及官方文档

ze([4, 3, 32, 32]) #print(targets) # tensor([1, 6, 4, 1]) writer.add_images(“test_data”,imgs,step) step += 1

writer.close()


以上就是dataloader相关的简单操作,详细具体的使用还要参考源码及官方文档

在实际训练过程中,也会使用`for data in dataloader:`这样的形式来加载数据,每个data中的imgs会被输送到神经网络中。

标签: pcb传感器进口201a75

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台