しんさんのブログ

科学や技術のこと読書のことなど

TensorFlow2.0への移行メモ

TensorFlow2.0への移行メモ

いままでnativeのKerasと組み合わせて使っていたのですが、2.0からはtf.kerasを使うように変更しました。
また、TensorFlow2.0からはeager executionがdefaultになったようですのでそれにまつわることも含めて自分用のメモを書きます。

eager executionは便利なのですが、速度的にはgraph modeの方が早いので、できれば両対応できる方がいいなと思い調べると、
@tf.functionデコレータをつけるとgraphモード(TF1.x系デフォルト)になり、つけないとEagerモード(TF2.0デフォルト)で実行されることがわかりました。
今、どちらのモードで動いているのかを確認するには、
print(tf.executing_eagerly()) # true or false
で確認できます。
今は、@tf.functionを関数の定義の前につけています。
ただし困ったことに@tf.functionを付けてgraph modeにするとerrorが出てしまいます。
エラーの箇所は、model.fit関数です。
modelはkerasのfunctional APIのlayer modelで作成しているのですが、Sequentialモデルでレイヤーを重ねていても同じ問題は起きると思います。
エラーの理由はどうやら学習データの受け渡しのデータ形式がnumpyのndarrayの形式であることに由来しているようです。
そこでデータをtf.data.Datasetで変換して渡すようにしましたがまだerrorが残ります。
model.fitがgraphモードではうまく動かないのかもしれないです。
そもそもTF2.0でeager executionがデフォルトになったことでKerasとtensorflowのlow levl APIを組み合わせるおすすめの方法はないものかと思い始めました。
すべてKerasで書いて,eager exec modeで動かす限りはTF1.xとコードの書き方は変わらないのですが、速度の問題もあり一部graph modeで動かそうとするとKerasのAPIだけでは対応できないようです。
TF2.0でのコーディングに関しては以下のブログにまとまっています。
qiita.com
qiita.com
https://colab.research.google.com/drive/1UCJt8EYjlzCs1H1d1X0iDGYJsHKwu-NO#scrollTo=FjLI719fPfJi
結局、eager execとgraph modeを両対応したいと思えば、network のlayerはkerasを使い、学習はkerasのfitは使わずにtensorflowの低レイヤーで書くというのがよさそうです。

パターン1) modelの定義も学習もすべてkerasを使用して書く。
従来のkerasを使用した書き方。コードを書くのは簡単だが、eager executionでしか実行できないので、
学習に時間がかかる。
パターン2)modelの定義はkeras.model.layerをそのまま使用。学習はtfで書く。
メリットはデバッグ時は@tf.functionをコメントアウトしてeager execution modeでdefine by runの形で実行して
デバッグしやすくして置き、学習時は@tf.functionを有効にして速度重視でdefine and runのgraph mode で実行する。
パターン2の書き方だと上記の両対応が可能.
モデル定義はKeras APIをそのまま使用しますが、データセットをnumpyからtf.dataを使うように変更。
traing をkerasを使わずに書く時はnumpyのままでは入力できない。
Data Augmentationが複雑になってくる場合もtf.dataの形式の方が都合がよい。
tf.image以下が充実しているので、前処理が簡単にかける。
パターン3)modelの定義をkeras.model.layerを継承した独自のクラスで書く。学習はtfで書く。
このパターンが書きやすさとデバッグのしやすさ、パフォーマンスのバランスが一番よさそうですので今後はTF2.0ではこの書き方に
統一していこうと思います。

tensorflowのTFRecordに関しては以下の記事を参照
www.tdi.co.jp
www.tdi.co.jp

ちなみに、TensorFlow のtf.Tensorと NumPy の ndarray 間の変換は以下を参照
テンソルと演算  |  TensorFlow Core

TF2.0のコードの書き方とPytorchとの比較に関しては以下のブログを参照

www.hellocybernetics.tech