クロの制作日記

Tensorflowのwhile_loopで配列に値を結合する方法

最初に

tf.while_loopで以下のように配列に値を結合していくとエラーが吐かれます。

def add(x):
    return x+1

def condition(_i, x):
    return _i < 4

def update(_i, x):
    
    x = tf.concat([x, [_i]],axis=0)

    return _i+1, x

ind = tf.constant(0)
v1 = tf.Variable([])
print(v1)
init_val = [ind, v1]

_, loop = tf.while_loop(cond=condition, body=update, loop_vars=init_val)

loop = loop.stack()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(loop))
ValueError: Input tensor 'Variable_2/read:0' enters the loop with shape (0,), but has shape (1,) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape

どうやら、tf.while_loopの初期状態(loop_vars)のデータ型やshapeがループ中に変更していることが原因のようです。




解決方法

以下の質問に解決方法が載っていました。
stackoverflow.com

shape_invariants

while_loopの引数であるshape_invariabntsにshapeを指定してあげると良いようです。

init_valsは2個めの要素が配列になっているので、そのshapeをNoneにすると配列に値を結合しても問題ないっぽい。

_, loop = tf.while_loop(cond=condition, body=update, loop_vars=init_val, shape_invariants=[ind.get_shape(), tf.TensorShape([None])])

TensorArray

初期状態の配列をTensorArrayで定義すると良いようです。

TensorArrayはstack()を実行するまでは配列に値が結合されていない(結合する値が登録されているだけ?)らしく、while_loopのループ中はshapeが変わらないので、エラーが吐かれないっぽい。

def add(x):
    return x+1

def condition(_i, x):
    return _i < 4

def update(_i, x):
    
    x = x.write(_i, add(_i))    

    return _i+1, x

ind = tf.constant(0)
v1 = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
init_val = [ind, v1]
_, loop = tf.while_loop(cond=condition, body=update, loop_vars=init_val)

loop = loop.stack()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(loop))