クロの制作日記

tf.functionの注意点

最初に

Tensorflow2.x系ではtf.functionという機能が追加されてています。(tensorflow1.x系のeagerモードでも使えるかもしれませんが、確認していないです...)
www.tensorflow.org

Tensorflow2.x系はdefine by runで実行されるので、tensorflow1.xのように計算グラフを定義しないため、実行速度が遅くなるというデメリットがあります。

それを是正するのがtf.functionで、関数の上の行に「@tf.function」と記述することでその関数の計算グラフを作成してくれます。

じゃあ全部の関数にtf.functionを付ければめっちゃ早くなるんじゃない?となるかもしれませんが、そういうわけでもないようです。




以下が公式の注意点と推奨事項です。

  • オブジェクトの変更やリストへの追加のような Python の副作用に依存しないこと
  • tf.functions は NumPy の演算や Python の組み込み演算よりも、TensorFlow の演算に適していること
  • 迷ったときは、for x in y というイディオムを使うこと

これだけだと具体的にどうすりゃええんか分からないかもしれないので、上に貼ったサイト「tf.functionで性能アップ」に書いてあったことを抜粋しておきます。

引数はTensor

関数の引数を「training=True」とか「num_layers=10」のようにpythonのデータを渡してしまうとグラフを再トレースしてしまい、非効率になるそうです。

ただし、引数をTensorにするかTensorにキャストするとその問題を回避してくれるそうです。

tf.function内でのpythonコード

以下のようなコードを実行すると、f(x)の引数が変わらない限り、pythonの関数であるprint()が呼ばれないそうです。

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
実行結果
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

毎回関数を実行してほしいときは、上のようにtensorflowの関数を使うか、以下のように、tf.py_functionを使います。

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

ただ、tf.py_function()は性能が高くないとか分散環境では上手く動作しないとかの問題もあるようですので、やはりtensorflowの関数を使った方がよさげかな。

イテレータ

pythoniter関数などを使って繰り返し処理を行うと、イテレータ全体がトレースされ、巨大な計算グラフを生成してしまう可能性があるそうです。

そうすると処理速度が遅くなってしまうので、tf.data.Datasetでデータを扱い、それをfor文で回すのが安全な方法らしいです。Tensorflow側ではtf,data.Datasetとfor文の組み合わせの場合、ループを安全に変換する機能があるとのこと。

変数の宣言は関数外

tf.functionの中で変数を宣言すると、関数を呼び出すたびに変数を再利用するのでなく新規作成してしまうので、基本的には変数宣言は関数外で行いましょう。

エラー
@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
OK
# しかし、曖昧さの無いコードは大丈夫

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0

他にも工夫すれば何とかなるらしいので、上の公式サイトをチェック

if文やwhile文、for文

これらの制御文は、条件式がTensorであればtf.condやtf.while_loopに変換してくれるそうです。ただ、tf.data.Datasetとfor文の組み合わせの場合は、tf.data.Dataset.reduceに変換されるみたい。

細かい注意点が結構あるので公式サイトをチェックしましょう。




結論

tf.functionを用いた関数は以下のことを最低でも気を付けるべきという感じみたいです。

  1. 引数はTensor
  2. 関数内ではTensorflowの関数を用いる
  3. 学習・教師データなどはDatasetで管理
  4. pythonの構文はif文、while文、for文は条件付きだが、使っても良い

tf.functionを用いて速度を上げたい場合は、define by runになったからとはいえども、tensorflowの関数やtensorを使って処理を行おうねっていうのが結論かなと思います。