mnist.py的相关笔记

所用到的扩展模块有

  • urllib.request
  • os.path
  • gzip (下文不解释了,解压压缩包用的,直接用gzip.open()就行
  • pickle
  • os
  • numpy

os相关代码

os.path.dirname(path)

获取文件路径

os.path.absapth()

获取绝对路径

os.path.exists(path)

看函数名知用法。。。

  • 可能用到的特殊属性

    1
    2
    __fire__ #双下划线 表示当前目录的相对路径
    #所以用到转换绝对路径的函数

pickle库

一个可以帮助你以二进制流的格式将知道对象保存到文件里,以便下次加载

保存格式为pkl

  • 好处:
    • 对于大数据文件,我不用每次都占用大量内存去处理(个人理解)
    • 把数据从内存中保存到固定磁盘空间里
  • 注意事项:

    • 加载文件中一定要有你原来对象的定义,否则连它都不知道自己保存的是个啥东西
    • 文件以二进制文件进行读写,别文本文件
  • 常用函数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    #来自pickle的说明
    '''
    Functions:

    dump(object, file)
    dumps(object) -> string
    load(file) -> object
    loads(string) -> object
    '''
    • pickle.dump(obj,file_obj,protocol=None)、

      • mist.py中的protocol是-1

      • protocol表示所用到的压缩协议(下面来自官方文档)

        The optional protocol argument, an integer, tells the pickler to use the given protocol; supported protocols are 0 to HIGHEST_PROTOCOL. If not specified, the default is DEFAULT_PROTOCOL. If a negative number is specified, HIGHEST_PROTOCOL is selected.

    • pickle.load(file_obj)

numpy库

首先一定要做的事就是

1
import numpy as np
  • 函数

    • np.frombuffer(obj,dtype,offset)
  • 用处:将data以流的形式读入转化成ndarray对象

    > 注意是以流的形式
    >
    > 在python里对字符串来说,unicode格式的字符串为流
    >
    > 所以如果传字符串进去要先加个b
    >
    > 如: b'hello' 就是字符流
    
    • dtype为返回的numpy格式

    • offset表示从哪个位置开始读,有时用来跳过魔数

mnist的完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# coding: utf-8
try:
import urllib.request
except ImportError:
raise ImportError('You should use Python 3.x')
import os.path
import gzip #解压压缩包用的,如gz
import pickle #本地保存python大数据,以二进制流保存指定对象到文件里,以便下次加载 保存格式为pkl,有保护性,
import os
import numpy as np


url_base = 'http://yann.lecun.com/exdb/mnist/' #数据集网址
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
} #要下载的文件

dataset_dir = os.path.dirname(os.path.abspath(__file__)) #获取当前目录的绝对路径
save_file = dataset_dir + "/mnist.pkl" #方便保存mnist数据集,先创建pkl格式的数据集

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784


def _download(file_name):
file_path = dataset_dir + "/" + file_name

if os.path.exists(file_path): #文件存在就不要再下了,浪费空间,浪费时间
return

print("Downloading " + file_name + " ... ")
urllib.request.urlretrieve(url_base + file_name, file_path) #下载
print("Done")

def download_mnist(): #批量下载全部,通过字典
for v in key_file.values():
_download(v)

def _load_label(file_name): #将标签文件数组化
file_path = dataset_dir + "/" + file_name

print("Converting " + file_name + " to NumPy Array ...")
with gzip.open(file_path, 'rb') as f:
labels = np.frombuffer(f.read(), np.uint8, offset=8) #offset跳过魔数
print("Done")

return labels

def _load_img(file_name): #将图像文件数组化
file_path = dataset_dir + "/" + file_name

print("Converting " + file_name + " to NumPy Array ...")
with gzip.open(file_path, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16) #offset跳过魔数
data = data.reshape(-1, img_size)
print("Done")

return data

def _convert_numpy():
dataset = {}
dataset['train_img'] = _load_img(key_file['train_img'])
dataset['train_label'] = _load_label(key_file['train_label'])
dataset['test_img'] = _load_img(key_file['test_img'])
dataset['test_label'] = _load_label(key_file['test_label'])

return dataset

def init_mnist(): #初始程序,启动下载
download_mnist()
dataset = _convert_numpy()
print("Creating pickle file ...")
with open(save_file, 'wb') as f:
pickle.dump(dataset, f, -1)
print("Done!")

def _change_one_hot_label(X):
T = np.zeros((X.size, 10)) #因为是10个结果
for idx, row in enumerate(T):
row[X[idx]] = 1

return T


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""读入MNIST数据集

Parameters
----------
normalize : 将图像的像素值正规化为0.0~1.0
one_hot_label :
one_hot_label为True的情况下,标签作为one-hot数组返回
one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
flatten : 是否将图像展开为一维数组

Returns
-------
(训练图像, 训练标签), (测试图像, 测试标签)
"""
if not os.path.exists(save_file):
init_mnist()

with open(save_file, 'rb') as f:
dataset = pickle.load(f)

if normalize:
for key in ('train_img', 'test_img'):
#因为是灰度图片,所以它的像素值为灰度值,像素值范围为0-255
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255.0

if one_hot_label:
dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

if not flatten:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) #这里可能会报错,请把这里整成一行


if __name__ == '__main__':
init_mnist()