手書き数字(MNIST)をTensorFlowの単純なニューラルネットワークで判別

TensorFlowのディープラーニングの練習として、手書き数字(MNIST)をTensorFlowの単純なニューラルネットワークで判別させました。

Google Colaboratoryで実行しました。

TensorFlowのバージョンは2.3.0です。

!pip list | grep tensorflow
tensorflow                    2.3.0          
tensorflow-addons             0.8.3          
tensorflow-datasets           2.1.0          
tensorflow-estimator          2.3.0          
tensorflow-gcs-config         2.3.0          
tensorflow-hub                0.9.0          
tensorflow-metadata           0.24.0         
tensorflow-privacy            0.2.2          
tensorflow-probability        0.11.0         

以下はGoogle Colaboratoryで実行したコードです。

まずはインポート。

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist

出力画像の大きさを指定。デフォルトでは小さいため。

plt.rcParams['figure.figsize'] = (16.0, 7.0)

MNISTのデータをロード。

(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

教師データの画像の一部を表示してみる。

for i in range(32):
  plt.subplot(5, 16, i + 1)
  plt.imshow(x_train[i], cmap='gray')

f:id:suzuki-navi:20200923204325p:plain

単純なニューラルネットワークで学習。

# 入力と出力サイズ
in_size = 28 * 28
out_size = 10

# 入力データを 28*28 の2次元配列から 784 の1次元配列に変換
x_train_reshape = x_train.reshape(-1, in_size).astype('float32') / 255
x_test_reshape = x_test.reshape(-1, in_size).astype('float32') / 255

# one-hot encoding
y_train_onehot = tf.keras.backend.one_hot(y_train, out_size)
y_test_onehot = tf.keras.backend.one_hot(y_test, out_size)

# モデル構造を定義
def createModel():
  hidden_size = 64
  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Dense(hidden_size, activation='relu', input_shape=(in_size,)))
  model.add(tf.keras.layers.Dense(out_size, activation='softmax'))
  return model

model = createModel()

# モデルを構築
model.compile(
    loss = "categorical_crossentropy",
    optimizer = "adam",
    metrics=["accuracy"])

# 学習を実行
result = model.fit(x_train_reshape, y_train_onehot,
    batch_size=50,
    epochs=10,
    verbose=1,
    validation_data=(x_test_reshape, y_test_onehot))

# 学習の様子をグラフへ描画
def plotLearning():
  # ロスの推移をプロット
  plt.plot(result.history['loss'])
  plt.plot(result.history['val_loss'])
  plt.title('Loss')
  plt.legend(['train', 'test'], loc='upper left')
  plt.show()

  # 正解率の推移をプロット
  plt.plot(result.history['accuracy'])
  plt.plot(result.history['val_accuracy'])
  plt.title('Accuracy')
  plt.legend(['train', 'test'], loc='upper left')
  plt.show()

plotLearning()
Epoch 1/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.3328 - accuracy: 0.9073 - val_loss: 0.1855 - val_accuracy: 0.9454
Epoch 2/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.1622 - accuracy: 0.9529 - val_loss: 0.1375 - val_accuracy: 0.9593
Epoch 3/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.1192 - accuracy: 0.9649 - val_loss: 0.1159 - val_accuracy: 0.9668
Epoch 4/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0936 - accuracy: 0.9723 - val_loss: 0.0971 - val_accuracy: 0.9718
Epoch 5/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0761 - accuracy: 0.9773 - val_loss: 0.0950 - val_accuracy: 0.9711
Epoch 6/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0647 - accuracy: 0.9806 - val_loss: 0.0878 - val_accuracy: 0.9729
Epoch 7/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0552 - accuracy: 0.9834 - val_loss: 0.0947 - val_accuracy: 0.9714
Epoch 8/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0484 - accuracy: 0.9858 - val_loss: 0.0822 - val_accuracy: 0.9759
Epoch 9/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0423 - accuracy: 0.9874 - val_loss: 0.0825 - val_accuracy: 0.9747
Epoch 10/10
1200/1200 [==============================] - 2s 2ms/step - loss: 0.0362 - accuracy: 0.9891 - val_loss: 0.0837 - val_accuracy: 0.9738

f:id:suzuki-navi:20200923204346p:plain

f:id:suzuki-navi:20200923204400p:plain

テストデータを判別させてみる。(赤字が判別結果。たまに誤判定があるのがわかる)

pred = model.predict(x_test_reshape[0:64])
for i in range(64):
  plt.subplot(5, 16, i + 1)
  plt.imshow(x_test[i], cmap="gray")
  plt.text(0, -2, str(np.argmax(pred[i])), size="xx-large", color="#FF0000")

f:id:suzuki-navi:20200923204421p:plain

ニューラルネットワークの構造を画像で表示。

tf.keras.utils.plot_model(model, show_shapes=True)

f:id:suzuki-navi:20200923204440p:plain