Datasets and Dataloaders
- Dataset (torch.utils.data.Dataset) 存储了样本及其对应的标签。
- DataLoader (torch.utils.data.DataLoader) 方便访问 Dataset。
Dataset 的类型
- 图片
- 文本
- 音频
等等。
现成的 Dataset 有哪些
例如, FashionMNIST。
>>> from torchvision import datasets
>>> dir(datasets)
['CIFAR10', 'CIFAR100', 'CLEVRClassification', 'Caltech101', 'Caltech256', 'CelebA', 'Cityscapes', 'CocoCaptions', 'CocoDetection', 'Country211', 'DTD', 'DatasetFolder', 'EMNIST', 'EuroSAT', 'FER2013', 'FGVCAircraft', 'FakeData', 'FashionMNIST', 'Flickr30k', 'Flickr8k', 'Flowers102', 'FlyingChairs', 'FlyingThings3D', 'Food101', 'GTSRB', 'HD1K', 'HMDB51', 'INaturalist', 'ImageFolder', 'ImageNet', 'KMNIST', 'Kinetics', 'Kinetics400', 'Kitti', 'KittiFlow', 'LFWPairs', 'LFWPeople', 'LSUN', 'LSUNClass', 'MNIST', 'Omniglot', 'OxfordIIITPet', 'PCAM', 'PhotoTour', 'Places365', 'QMNIST', 'RenderedSST2', 'SBDataset', 'SBU', 'SEMEION', 'STL10', 'SUN397', 'SVHN', 'Sintel', 'StanfordCars', 'UCF101', 'USPS', 'VOCDetection', 'VOCSegmentation', 'VisionDataset', 'WIDERFace', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_optical_flow', 'caltech', 'celeba', 'cifar', 'cityscapes', 'clevr', 'coco', 'country211', 'dtd', 'eurosat', 'fakedata', 'fer2013', 'fgvc_aircraft', 'flickr', 'flowers102', 'folder', 'food101', 'gtsrb', 'hmdb51', 'imagenet', 'inaturalist', 'kinetics', 'kitti', 'lfw', 'lsun', 'mnist', 'omniglot', 'oxford_iiit_pet', 'pcam', 'phototour', 'places365', 'rendered_sst2', 'sbd', 'sbu', 'semeion', 'stanford_cars', 'stl10', 'sun397', 'svhn', 'ucf101', 'usps', 'utils', 'video_utils', 'vision', 'voc', 'widerface']
>>>
可以看到除了 FashionMNIST,还有很多其他的数据集。
但是,我有个疑惑。为何 FashionMNIST 不是在 torch.utils.data.Dataset 中,而是在 torchvision.datasets 中。
MNIST 是什么?
FashionMNIST 的 MNIST 后缀代表什么呢?
https://en.wikipedia.org/wiki/MNIST_database
The MNIST database (Modified National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems.
MNIST database 即,改进后的美国国家标准与技术研究院数据库。
- MNIST:手写数字图片的数据集合。包含 6万张训练图片及1万张测试图片。大小为统一的 28x28 像素图片。
- EMNIST: 增加了大小写字母,并包含了手写数字。
- FashionMNIST: 时尚物品数据集合。包含 10 个分类的物品,例如 t shirt,鞋子,西装之类。同样是 28x28 像素的,灰度图片(灰度值0~255)。
从 torchvision.datasets 可以看到,既有 MNIST,也有 EMNIST,FashionMNIST。
数据集的下载
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data", train=True, download=True, transform=ToTensor()
)
- root 参数,即数据集所在的目录。
- 在 download 为 True 时,最自动下载数据集文件
- train 代表,这是训练数据
是从 AWS S3 上下载的压缩包文件。
>python main.py
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data\FashionMNIST\raw\train-images-idx3-ubyte.gz
Extracting data\FashionMNIST\raw\train-images-idx3-ubyte.gz to data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data\FashionMNIST\raw\train-labels-idx1-ubyte.gz
Extracting data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz
Extracting data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw
文件大小:
> tree -h
.
├── [4.0K] data
│ └── [4.0K] FashionMNIST
│ └── [4.0K] raw
│ ├── [7.5M] t10k-images-idx3-ubyte
│ ├── [4.2M] t10k-images-idx3-ubyte.gz
│ ├── [9.8K] t10k-labels-idx1-ubyte
│ ├── [5.0K] t10k-labels-idx1-ubyte.gz
│ ├── [ 45M] train-images-idx3-ubyte
│ ├── [ 25M] train-images-idx3-ubyte.gz
│ ├── [ 59K] train-labels-idx1-ubyte
│ └── [ 29K] train-labels-idx1-ubyte.gz
└── [ 346] main.py
3 directories, 9 files
完整代码
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data", train=True, download=True, transform=ToTensor()
)
print(len(training_data)) # 60000
test_data = datasets.FashionMNIST(
root="data", train=False, download=True, transform=ToTensor()
)
print(len(test_data)) # 10000
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
运行结果:
数据集在有本地文件的情况下,加载非常快,1秒都不用。打开那些本地数据集文件,可以看到并不是预想的单个图片文件,而且序列化到一起的二进制文件。
Label 与 Feature
- Label: 输出。即上面 demo 里的物品名称。也是我们用 PyTorch 预测的结果输出。
- Feature: 输入。即作为预测模型的输入,这里是图片的像素 pattern (patterns in the images pixels)。我不知道 pattern 翻译为什么比较好。
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
这里的参数:
- transform:用于修改 feature
- target_transform: 用于修改 label
修改、转换的目的是,使数据适用于训练。
DataLoader 的作用
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
提升效率,批量从数据集合中加载数据,例如,这里一次加载 64 个,同时设置了随机数据读取。这样每次读出来的就是多个图片和 label。
参考
- https://docs.microsoft.com/zh-cn/learn/modules/intro-machine-learning-pytorch/3-data
微信关注我哦 👍
我是来自山东烟台的一名开发者,有感兴趣的话题,或者软件开发需求,欢迎加微信 zhongwei 聊聊, 查看更多联系方式