Tensorflow1.x系でif文のような条件分岐を実装したい場合はtf.condを利用します(Tensorflow2.x系でも使用可能)。
www.tensorflow.org
tf.condは以下のような引数を渡してあげる必要があります。
tf.cond( pred # 条件式(直接True or Falseを指定可能) true_fn # 条件式がTrueの場合に実行する関数 false_fn # 条件式がFalseの場合に実行する関数 name # 名前(なくても可) }
tf.condの使用例はこんな感じです。
x = tf.Variable(2) def add(): return x+1 def sub(): return x-1 ans = tf.cond(x > tf.Variable(1), add, sub) sess = tf.Session() sess.run(tf.global_variables_initializer()) print(sess.run(ans))
注意点
ただし、tf.addに渡す関数は引数なしでないとエラーを吐かれます。
x = tf.Variable(2) def add(y): return x+y def sub(y): return x-y ans = tf.cond(x > tf.Variable(1), add(tf.Variable(1)), sub(tf.Variable(1))) sess = tf.Session() sess.run(tf.global_variables_initializer()) print(sess.run(ans))
TypeError: true_fn must be callable.
解決策
このエラーを解消するにはpythonのlambdaを利用します。
x = tf.Variable(2) def add(y): return x+y def sub(y): return x-y ans = tf.cond(x > tf.Variable(1), lambda : add(tf.Variable(1)), lambda : sub(tf.Variable(1))) sess = tf.Session() sess.run(tf.global_variables_initializer()) print(sess.run(ans))
lambdaの詳細が知りたい人はこちらのサイトを参考にしてください。
note.nkmk.me