クロの制作日記

tf.functoinとtf.gatherを併用すると勾配が計算できない場合の対処法

最初に

Tenosorflow2.0においてtf.functionでデコレートした関数の内部でtf.gatherを使用し、勾配を計算すると以下のようなエラーが出ました。

AssertionError: Expected all args to be Tensors or Variables; but got CompositeTensor

エラーの原因

問題の根源はtf.gatherを実行するとTensorオブジェクトを 「tf.IndexedSlices」に変換してしまうことにあるらしい。以下の回答によると、標準のTensorオブジェクトなら勾配計算できるけど、tf.IndexedSlicesは対応していないっぽい。

stackoverflow.com


解決策1:Tensorflow2.1にUpdate

Tensorflow2.1では解決されている問題らしいのでTensorflow2.1にアップデートしましょう。
github.com

解決策2:よくわからないテクニックを駆使

GPUとかの環境依存でTensorflow2.0までしか使えないよって人はこちらです。
github.com

何かよくわからない(回答者もなぜ上手くいくかわからないらしい)方法ですが以下のようにtf.gatherをラップした関数を作成するとエラーが消えました。

@tf.function
def gather(x, ind):
    return tf.gather(x + 0, ind)

最近色々なエラーの解決策を調べましたが、この方法が一番摩訶不思議です。回答者の人はどうやってこの解決策にたどり着いたんだろう。




代償(WARNING)

よくわからん方法には代償が付き物です。私が動かした場合はエラーが消えましたが以下のようなWARNINGが大量に出たので使うときは注意してください。

WARNING:tensorflow:5 out of the last 39 calls to <function ooauc.gather at 
0x7fe14cae0510> triggered tf.function retracing. Tracing is expensive and the 
excessive number of tracings is likely due to passing python objects instead of tensors. 
Also, tf.function has experimental_relax_shapes=True option that relaxes argument 
shapes that can avoid unnecessary retracing. Please refer to 
https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args 
and https://www.tensorflow.org/api_docs/python/tf/function for more details.

tf.functionでデコレートした関数の中でtf.functionでデコレート関数を定義しているのがいけないっぽい?リトレーシングされちゃっているよってことらしいので処理速度が遅くなっているかもしれませんが、緊急性はないのでとりあえず放置しときます。

エラーのおよびWARNINGの解決策が他に知っている方は教えてくれたらうれしいです。