クロの制作日記

tf.saved_model.load()でのTypeError: '_UserObject' object is not callable loadの対処法

以下のように学習を行ったモデルを読み込み、実行をしようとすると「TypeError: '_UserObject' object is not callable load」というエラーがでました。

学習と保存
class ClassifierModel(tf.keras.Model):

    # モデルのレイヤー定義
    def __init__(self, **kwargs):
        super(ClassifierModel, self).__init__(**kwargs)

        # 1層目
        self.fc1 = tf.keras.layers.Dense(HIDDEN_DIM, activation='relu')

        # 2層目
        self.fc2 = tf.keras.layers.Dense(HIDDEN_DIM, activation='relu')

        # 3層目
        self.fc3= tf.keras.layers.Dense(CLASS_NUM, activation='relu')
    
    # モデルの実行
    def call(self, x_t, training):
        
        # 1層目
        x = self.fc1(x_t)

        # 2層目
        x = self.fc2(x)

        # 3層目
        x = self.fc3(x)

        return x
model = ClassifierModel()

for i in range(100):
    #--------------
    モデルを学習
    #--------------

tf.saved_model.save(model, 'model/')
読み込みと実行
loaded_model = tf.saved_model.load('model/')

loaded_model(x, tf.cast(False, tf.bool))
エラー
*** TypeError: '_UserObject' object is not callable load




対処法

色々と対処法をネットの海を探しましたので、順を追って説明していきます。

tf.function

公式サイトに
www.tensorflow.org

以下のような文面がありました。

When you save a tf.Module, any tf.Variable attributes, tf.function-decorated methods, and tf.Modules found via recursive traversal are saved. (See the Checkpoint tutorial for more about this recursive traversal.) However, any Python attributes, functions, and data are lost. This means that when a tf.function is saved, no Python code is saved.

細かいニュアンスは理解できていないかもしれませんが、ようするに「tf.functionでデコレートしておかないとモデルを保存できませんよ」ということだと思います。

なので、モデルを以下のように書き換えました。

class ClassifierModel(tf.keras.Model):

    # モデルのレイヤー定義
    @tf.function
    def __init__(self, **kwargs):
        super(ClassifierModel, self).__init__(**kwargs)

        # 1層目
        self.fc1 = tf.keras.layers.Dense(HIDDEN_DIM, activation='relu')

        # 2層目
        self.fc2 = tf.keras.layers.Dense(HIDDEN_DIM, activation='relu')

        # 3層目
        self.fc3= tf.keras.layers.Dense(CLASS_NUM, activation='relu')
    
    # モデルの実行
    @tf.function
    def call(self, x_t, training):
        
        # 1層目
        x = self.fc1(x_t)

        # 2層目
        x = self.fc2(x)

        # 3層目
        x = self.fc3(x)

        return x
model = ClassifierModel()

for i in range(100):
    #--------------
    モデルを学習
    #--------------

tf.saved_model.save(model, 'model/')

すると、エラー内容が変わりました。

*** ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (2 total):
    * Tensor("x_t:0", shape=(256, 24), dtype=float64)
    * Tensor("training:0", shape=(), dtype=bool)
  Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (2 total):
    * TensorSpec(shape=(256, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 2:
  Positional arguments (2 total):
    * TensorSpec(shape=(8, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 3:
  Positional arguments (2 total):
    * TensorSpec(shape=(252, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 4:
  Positional arguments (2 total):
    * TensorSpec(shape=(208, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

「Positionnal arguments(2 tota):」以降は私の実行したコードによるものなので無視してください。こんな感じに、よくわからんエラーが出ました。下にあるOptionの通りに入力データをいじっても変わらず...。

「call」=> 「__call__」

エラー内容で検索してみたところ、以下のissueが見つかりました。
github.com

スレッドをずらーと読んでいると以下のように「call」を「__call__」のようにすれば?とあったので

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.d = tf.keras.layers.Dense(2)

    # "dunder"-methods (__...__) typically are called implicitly by Python
    def __call__(self, x, training=True, mask=None):
        return self.d(x)
...
model(tf.random.normal((1, 3))) # no explicit .call here

以下のように改良してみるとエラーが消え、想定していた出力値になりました!!

class ClassifierModel(tf.keras.Model):

    # モデルのレイヤー定義
    @tf.function
    def __init__(self, **kwargs):
        super(ClassifierModel, self).__init__(**kwargs)

        # 1層目
        self.fc1 = tf.keras.layers.Dense(HIDDEN_DIM, activation='relu')

        # 2層目
        self.fc2 = tf.keras.layers.Dense(HIDDEN_DIM, activation='relu')

        # 3層目
        self.fc3= tf.keras.layers.Dense(CLASS_NUM, activation='relu')
    
    # モデルの実行
    @tf.function
    def __call__(self, x_t, training):
        
        # 1層目
        x = self.fc1(x_t)

        # 2層目
        x = self.fc2(x)

        # 3層目
        x = self.fc3(x)

        return x
model = ClassifierModel()

for i in range(100):
    #--------------
    モデルを学習
    #--------------

tf.saved_model.save(model, 'model/')

入力データ数が固定される...

エラーは消えて万々歳だったのですが、どうやら制限もあるようで...。

tf.functionを加えた際に出てきたこのエラー文を覚えていますでしょうか。

Option 1:
  Positional arguments (2 total):
    * TensorSpec(shape=(256, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 2:
  Positional arguments (2 total):
    * TensorSpec(shape=(8, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 3:
  Positional arguments (2 total):
    * TensorSpec(shape=(252, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 4:
  Positional arguments (2 total):
    * TensorSpec(shape=(208, 24), dtype=tf.float32, name='x_t')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}



この公式サイトの
www.tensorflow.org]

以下の文面から判断するに、

Because no Python code is saved, calling a tf.function with a new input signature will fail:

どうやら、tf.functionはグラフの保存しかできないので、学習時(今回はミニバッチ学習を行った)に入力したデータのshape以外では実行が上手くいかないようです。

これの対処法がないかなーと思って探した結果、以下のissueで「最近のバージョンで直したからアップグレードしてみ」的なことが書いていました。
github.com

ただ、cudaのバージョンの問題でアップグレードができないので確認ができない...。

現段階では、問題がないのでとりあえず放置。確認できる人はやってみてください。