クロの制作日記

tf.condで引数ありの関数を使いたい場合

Tensorflow1.x系でif文のような条件分岐を実装したい場合はtf.condを利用します(Tensorflow2.x系でも使用可能)。
www.tensorflow.org

tf.condは以下のような引数を渡してあげる必要があります。

tf.cond(
    pred     # 条件式(直接True or Falseを指定可能)
    true_fn  # 条件式がTrueの場合に実行する関数
    false_fn # 条件式がFalseの場合に実行する関数
    name     # 名前(なくても可)
}

tf.condの使用例はこんな感じです。

x = tf.Variable(2)

def add():
    return x+1

def sub():
    return x-1

ans = tf.cond(x > tf.Variable(1), add, sub)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(ans))

注意点

ただし、tf.addに渡す関数は引数なしでないとエラーを吐かれます。

x = tf.Variable(2)

def add(y):
    return x+y

def sub(y):
    return x-y

ans = tf.cond(x > tf.Variable(1), add(tf.Variable(1)), sub(tf.Variable(1)))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(ans))
TypeError: true_fn must be callable.




解決策

このエラーを解消するにはpythonのlambdaを利用します。

x = tf.Variable(2)

def add(y):
    return x+y

def sub(y):
    return x-y

ans = tf.cond(x > tf.Variable(1), lambda : add(tf.Variable(1)), lambda : sub(tf.Variable(1)))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(ans))

lambdaの詳細が知りたい人はこちらのサイトを参考にしてください。
note.nkmk.me