MNIST 测试

使用MNIST数据集来做测试。


In [ ]:
import numpy as np
from datetime import datetime
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from sklearn import preprocessing
from neural_network import Network

# Train the project's `Network` on MNIST in mini-batches and report the
# validation accuracy after every batch.

# Label binarizer fitted on the ten digit classes 0-9; used below to turn the
# training labels into indicator rows for `net.fit`.
lb = preprocessing.LabelBinarizer()
lb.fit(list(range(10)))

# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed in
# 0.22; on newer versions this has to be ported to fetch_openml('mnist_784'),
# whose targets are strings and need casting before the comparison below.
mnist = fetch_mldata('MNIST original', data_home='../datasets/MNIST')

# Min-max scale every pixel column before splitting.
min_max_scaler = preprocessing.MinMaxScaler()
data = min_max_scaler.fit_transform(mnist.data.astype('float64'))

# First split off the test set, then carve a validation set out of the
# remaining training data (both use train_test_split's default ratio).
X_train, X_test, y_train, y_test = train_test_split(data, mnist.target)
print('test:', X_test.shape)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train)
print('train:', X_train.shape)
print('validation:', X_val.shape)
print('label shape', y_val.shape)

# Sanity check: display one randomly chosen training image.
sample_idx = np.random.choice(X_train.shape[0])
img = X_train[sample_idx].reshape((28, 28))
plt.imshow(img, interpolation='nearest')
plt.show()

# Only the training labels are binarized; y_val keeps the raw class labels so
# it can be compared directly with the argmax of the predictions.
y_train = lb.transform(y_train)
net = Network([784, 300, 10])
batch_size = 2000
for epoc in range(100):
    time_start = datetime.now()
    for left in range(0, X_train.shape[0], batch_size):
        # Clamp the right edge so the last (possibly shorter) batch is kept.
        right = min(left + batch_size, X_train.shape[0])
        X = X_train[left:right, :]
        y = y_train[left:right, :]
        net.fit(X, y)
        # NOTE(review): the whole validation set is scored after every batch,
        # which is presumably a large share of the per-epoch time -- confirm
        # by profiling before moving this to once per epoch.
        y_pred = net.predict(X_val)
        y_pred = y_pred.argmax(1).reshape((-1,))
        eql = y_pred == y_val
        print('validation accuracy: ', np.sum(eql) / eql.shape[0])
    # total_seconds() keeps the fractional part; the `.seconds` attribute is
    # an integer field that truncates sub-second time and ignores whole days.
    elapsed = (datetime.now() - time_start).total_seconds()
    print('epoc: %d, time: %fs.' % (epoc, elapsed))


test: (17500, 784)
train: (39375, 784)
validation: (13125, 784)
label shape (13125,)
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:169: RuntimeWarning: overflow encountered in exp
  data = np.exp(data)
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:171: RuntimeWarning: invalid value encountered in true_divide
  p = data / s
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:173: RuntimeWarning: divide by zero encountered in log
  logp = np.log(p)
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:174: RuntimeWarning: invalid value encountered in multiply
  return -np.sum(np.multiply(self.truth, logp), 1)
平均loss:nan
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:38: RuntimeWarning: invalid value encountered in less
  cp[cp < 0] = 0
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:45: RuntimeWarning: invalid value encountered in less
  cp[cp < 0] = 0
D:\Code\myworks\机器学习算法实现\neural_network\neural_network.py:46: RuntimeWarning: invalid value encountered in greater
  cp[cp > 0] = 1
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan
validation accuracy:  0.0976761904762
平均loss:nan

存在的问题

  • 算法的效率并不高,一个epoch大概要16s
  • 计算softmax时,偶尔会出现溢出的问题(np.exp 的输入过大;可在取指数前先减去每行的最大值来避免)
  • 准确率并不高,最高才59%左右,和mnist官网上给的数据差距很大
  • batch模式训练时,loss值会来回震荡