しんさんのブログ

科学や技術のこと読書のことなど

TensorflowのTutorial "Deep MNIST for Experts"を試した時のメモ

wshinya.hatenablog.com

前回は"MNIST For ML Beginners"の解説を読みプログラムを実行してみました。
今回は、以下のリンクから"Deep MNIST for Experts"の解説を読みプログラムを実行してみます。
Deep MNIST for Experts  |  TensorFlow

前回のBeginners編と今回のExperts編の一番大きな違いは、Beginners編では画像のローカル構造を無視してしまい、1次元のベクトルに展開してしまいましたが、Exprets編では2次元の画像の2次元のまま扱う手法としてCNNを用いているところです。
CNNを用いることで数字の形の特徴をとらえることができ、認識率が大きく向上することが期待されます。

  • 今回のネットワーク構成

畳み込み層
プーリング層
畳み込み層
プーリング層
全結合層
Softmax層

となっています。

2回の畳み込み層により画像のローカルな構造とよりグローバルな構造をとらえることができます。
そのあとは、多値分類問題として、全結合層+Softmax regressionという前回と同じような構造になっています。

まずはmnist_deep.py を実行してみました。
するとWARNINGがずらずらと表示されいますが、
"xxx is deprecated and will be removed in a future version."
という将来のバージョンでこのコマンドは使えなくなりますよという警告ですので、当面は無視しておいても問題なさそうです。
そして実際に学習が進み最終的には、
step 19900, training accuracy 1
test accuracy 0.9913
今度は、testデータで99%の精度を超えてきます。簡単なCNNですがかなりの精度がでますね。

1層目の畳み込み層の
W_conv1 = weight_variable([5, 5, 1, 32])
は、5x5x1(channel数、いまはgray scaleなので1)の畳み込みフィルターが32種類あるということを示しています。
2層目では、
W_conv2 = weight_variable([5, 5, 32, 64])
となっており、5x5x32のフィルターが64種類あることを意味しています。
ここでの32や64はネットワークの途中で学習して特徴量の数に対応しています。
また各層ごとに2x2でMax poolingしていますので,画像サイズは
28 x 28 -> 14 x 14 -> 7x7 と順次スケールされていきます。

2層目のmax poolingが終わった時点では7x7の画像が64枚出力されています。
これをreshapeして1次元ベクトルに変換しているのが
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
です。fully connected layerを構成するためのweight, W_fc1とバイアスb_fc1
をこれに作用させることで、全結合層計算ができます。
最後に活性化関数としてreluを適用してのが以下の行です.
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

このネットワークでは過学習を防止するためにdropout、つまり適当にネットワークの重みWをゼロにするという操作を行います。
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
初めの行はdropoutの際に接続が保持される確率を格納するためのplaceholderで学習時には0.5で推論時にはdropoutしないように1.0にセットしています。
最後に10種類の数値に対応するように単層のsoftmax回帰に相当する層を追加しています。

最適化は
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
とAdamを使用しています。

GradientDescentとAdamで結果がどれくらいちがうのかちょっと実験してみました。
学習率はどちらも0.0001
ADAM:
step 100, training accuracy 0.88
test accuracy 0.9195
SGD:
step 100, training accuracy 0.26
test accuracy 0.2062
まあ、予想通りですが、SGDに比べるとADAMの方がかなり収束が早いです。