keras手写数字识别--入门
程序
由于mnist数据集直接使用
(x_train, y_train), (x_test, y_test) = mnist.load_data()
这种加载方式,有时候由于网络原因,很难加载成功。为此,可以直接通过其下载地址把数据集手动下载下来,然后使用numpy加载数据即可。
# -*- coding: utf-8 -*-
import keras
# from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop
import matplotlib.pyplot as plt
import numpy as np
# Training hyper-parameters.
batch_size = 128
num_classes = 10
epochs = 20

# Downloading MNIST through keras.datasets.mnist.load_data() often fails
# because of network problems, so the archive is downloaded by hand and read
# with NumPy instead. The archive stores the data already split into
# train and test sets.
path='F:/program_work/python_work/KerasTest/data/mnist.npz'
# The context manager closes the archive even when a key lookup raises.
with np.load(path) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

# Flatten every image to a 784-element float32 vector and divide by 255
# (pixel values are presumably in 0..255 -- confirm against the data file).
# -1 lets NumPy infer the sample count instead of hard-coding 60000/10000.
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert the integer class labels (0-9, ten classes) to the binary class
# matrices (one-hot rows) that the categorical cross-entropy loss expects.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# Fully connected model: two 512-unit ReLU layers, each followed by 20%
# dropout, and a softmax output layer with one unit per class.
model = Sequential([
    Dense(512, activation='relu', input_shape=(784,)),
    Dropout(0.2),
    Dense(512, activation='relu'),
    Dropout(0.2),
    Dense(num_classes, activation='softmax'),
])
model.summary()

# Cross-entropy is used as the loss function.
model.compile(
    loss='categorical_crossentropy',
    optimizer=RMSprop(),
    metrics=['accuracy'],
)

# Fit the model; the test set doubles as validation data for each epoch.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(x_test, y_test),
)

# evaluate() returns the loss first, then the metrics passed to compile().
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print('Total loss on Test Set:', test_loss)
print('Accuracy of Testing Set:', test_accuracy)
# Prediction.
# Take the arg-max of the softmax output to get one class index per sample
# (this replaces Sequential.predict_classes, which newer Keras no longer has).
result = np.argmax(model.predict(x_test), axis=1)
# y_test was one-hot encoded earlier in the script; convert it back to class
# indices so it has the same shape as `result` and can be compared with it.
true_labels = np.argmax(y_test, axis=1)
correct_indices = np.nonzero(result == true_labels)[0]
incorrect_indices = np.nonzero(result != true_labels)[0]

# First figure: up to nine correctly classified test images in a 3x3 grid.
plt.figure()
for i, correct in enumerate(correct_indices[:9]):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_test[correct].reshape(28, 28), cmap='gray', interpolation='none')
    plt.title("Predicted {}, Class {}".format(result[correct], true_labels[correct]))

# Second figure: up to nine misclassified test images in a 3x3 grid.
plt.figure()
for i, incorrect in enumerate(incorrect_indices[:9]):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_test[incorrect].reshape(28, 28), cmap='gray', interpolation='none')
    plt.title("Predicted {}, Class {}".format(result[incorrect], true_labels[incorrect]))

plt.show()
上面程序中,我们可以查看一些测试集中预测正确和预测错误的例子,如下图所示:
![这里写图片描述][Image 1]
训练结果为:
![这里写图片描述][Image 1]
关于全连接的理解,可以参考李宏毅的ppt。
![这里写图片描述][Image 1]
损失函数通常使用的有以下两种。
![这里写图片描述][Image 1]
对应的程序为:
# Option 1: mean squared error as the loss.
model.compile(
    loss='mse',
    optimizer=RMSprop(),
    metrics=['accuracy'],
)

# Option 2: categorical cross-entropy as the loss.
model.compile(
    loss='categorical_crossentropy',
    optimizer=RMSprop(),
    metrics=['accuracy'],
)
同时,模型的激活函数也有其他的,如ReLU,sigmoid等。对应的程序调整为:
# Option 1: a num_classes-unit layer with ReLU activation.
model.add(Dense(units=num_classes, activation='relu'))
# Option 2: a num_classes-unit layer with sigmoid activation.
model.add(Dense(units=num_classes, activation='sigmoid'))
优化方式也可以调整为其他的,如Adam()或者SGD()等,对应的程序可以调整为:
# The main script only imports RMSprop, so the other optimizers have to be
# imported before they can be used.
from keras.optimizers import Adam, SGD

# Option 1: Adam with its default settings.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy'],
)

# Option 2: stochastic gradient descent with a learning rate of 0.1.
# NOTE(review): recent Keras releases name this argument `learning_rate`
# instead of `lr` -- confirm against the installed Keras version.
model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(lr=0.1),
    metrics=['accuracy'],
)
[Image 1]:
还没有评论,来说两句吧...