最初に
Tenosorflow2.0においてtf.functionでデコレートした関数の内部でtf.gatherを使用し、勾配を計算すると以下のようなエラーが出ました。
AssertionError: Expected all args to be Tensors or Variables; but got CompositeTensor
解決策1:Tensorflow2.1にUpdate
Tensorflow2.1では解決されている問題らしいのでTensorflow2.1にアップデートしましょう。
github.com
解決策2:よくわからないテクニックを駆使
GPUとかの環境依存でTensorflow2.0までしか使えないよって人はこちらです。
github.com
何かよくわからない(回答者もなぜ上手くいくかわからないらしい)方法ですが以下のようにtf.gatherをラップした関数を作成するとエラーが消えました。
@tf.function def gather(x, ind): return tf.gather(x + 0, ind)
最近色々なエラーの解決策を調べましたが、この方法が一番摩訶不思議です。回答者の人はどうやってこの解決策にたどり着いたんだろう。
代償(WARNING)
よくわからん方法には代償が付き物です。私が動かした場合はエラーが消えましたが以下のようなWARNINGが大量に出たので使うときは注意してください。
WARNING:tensorflow:5 out of the last 39 calls to <function ooauc.gather at 0x7fe14cae0510> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
tf.functionでデコレートした関数の中でtf.functionでデコレート関数を定義しているのがいけないっぽい?リトレーシングされちゃっているよってことらしいので処理速度が遅くなっているかもしれませんが、緊急性はないのでとりあえず放置しときます。
エラーのおよびWARNINGの解決策が他に知っている方は教えてくれたらうれしいです。