|
本帖最后由 黑崎一护 于 2018-8-27 14:04 编辑
本文转载于:https://blog.csdn.net/simple_the_best/article/details/75267863
1 简介
MNIST 数据集是一个经典的手写体数据集.
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
- Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
复制代码 MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
不妨新建一个文件夹 – mnist, 将数据集下载到 mnist 以后, 解压即可:
2 读取MNIST数据集
图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法.
- import os
- import struct
- import numpy as np
- def load_mnist(path, kind='train'):
- """Load MNIST data from `path`"""
- labels_path = os.path.join(path,'%s-labels-idx1-ubyte'% kind)
- images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind)
- with open(labels_path, 'rb') as lbpath:
- magic, n = struct.unpack('>II',lbpath.read(8))
- labels = np.fromfile(lbpath,dtype=np.uint8)
- with open(images_path, 'rb') as imgpath:
- magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))
- images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)
- return images, labels
复制代码 load_mnist 函数返回两个数组, 第一个是一个 n x m 维的 NumPy array(images), 这里的 n 是样本数(行数), m 是特征数(列数). 训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示. 在这里, 我们将 28 x 28 的像素展开为一个一维的行向量, 这些行向量就是图片数组里的行(每行 784 个值, 或者说每行就是代表了一张图片). load_mnist 函数返回的第二个数组(labels) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).第一次见的话, 可能会觉得我们读取图片的方式有点奇怪:
- magic, n = struct.unpack('>II', lbpath.read(8))
- labels = np.fromfile(lbpath, dtype=np.uint8)
复制代码 为了理解这两行代码, 我们先来看一下 MNIST 网站上对数据集的介绍:
- TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
- [offset] [type] [value] [description]
- 0000 32 bit integer 0x00000801(2049) magic number (MSB first)
- 0004 32 bit integer 60000 number of items
- 0008 unsigned byte ?? label
- 0009 unsigned byte ?? label
- ........
- xxxx unsigned byte ?? label
- The labels values are 0 to 9.
复制代码 通过使用上面两行代码, 我们首先读入 magic number, 它是一个文件协议的描述, 也是在我们调用 fromfile 方法将字节读入 NumPy array 之前在文件缓冲中的 item 数(n). 作为参数值传入 struct.unpack 的 >II 有两个部分:
>: 这是指大端(用来定义字节是如何存储的); 如果你还不知道什么是大端和小端, Endianness 是一个非常好的解释. (关于大小端, 更多内容可见<<深入理解计算机系统 – 2.1 节信息存储>>)
I: 这是指一个无符号整数.
通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本.
3 MNIST数据集可视化
为了了解 MNIST 中的图片看起来到底是个啥, 让我们来对它们进行可视化处理. 从 feature matrix 中将 784-像素值 的向量 reshape 为之前的 28*28 的形状, 然后通过 matplotlib 的 imshow 函数进行绘制:- import matplotlib.pyplot as plt
- fig, ax = plt.subplots(
- nrows=2,
- ncols=5,
- sharex=True,
- sharey=True, )
- ax = ax.flatten()
- for i in range(10):
- img = X_train[y_train == i][0].reshape(28, 28)
- ax[i].imshow(img, cmap='Greys', interpolation='nearest')
- ax[0].set_xticks([])
- ax[0].set_yticks([])
- plt.tight_layout()
- plt.show()
复制代码 我们现在应该可以看到一个 2*5 的图片, 里面分别是 0-9 单个数字的图片.
此外, 我们还可以绘制某一数字的多个样本图片, 来看一下这些手写样本到底有多不同:
- fig, ax = plt.subplots(nrows=5,ncols=5,sharex=True,sharey=True, )
- ax = ax.flatten()
- for i in range(25):
- img = X_train[y_train == 7].reshape(28, 28)
- ax.imshow(img, cmap='Greys', interpolation='nearest')
- ax[0].set_xticks([])
- ax[0].set_yticks([])
- plt.tight_layout()
- plt.show()
复制代码 执行上面的代码后, 我们应该看到数字 7 的 25 个不同形态:
4 MNIST数据集保存为CSV格式
另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集. 但是, 有一点要说明, CSV 的文件格式将会占用更多的磁盘空间, 如下所示:
- train_img.csv: 109.5 MB
- train_labels.csv: 120 KB
- test_img.csv: 18.3 MB
- test_labels: 20 KB
复制代码 如果我们打算保存这些 CSV 文件, 在将 MNIST 数据集加载入 NumPy array 以后, 我们应该执行下列代码:
- np.savetxt('train_img.csv', X_train,fmt='%i', delimiter=',')
- np.savetxt('train_labels.csv', y_train,fmt='%i', delimiter=',')
- np.savetxt('test_img.csv', X_test,fmt='%i', delimiter=',')
- np.savetxt('test_labels.csv', y_test,fmt='%i', delimiter=',')
复制代码 一旦将数据集保存为 CSV 文件, 我们也可以用 NumPy 的 genfromtxt 函数重新将它们加载入程序中:- X_train = np.genfromtxt('train_img.csv',dtype=int, delimiter=',')
- y_train = np.genfromtxt('train_labels.csv',dtype=int, delimiter=',')
- X_test = np.genfromtxt('test_img.csv',dtype=int, delimiter=',')
- y_test = np.genfromtxt('test_labels.csv',dtype=int, delimiter=',')
复制代码 不过, 从 CSV 文件中加载 MNIST 数据将会显著发给更长的时间, 因此如果可能的话, 还是建议你维持数据集原有的字节格式.
参考:
- Book , Python Machine Learning.
|
|