クロの制作日記

Tensorflow2(subclassing API)でモデル・パラメータをsave(保存)・restore(復元)する方法

最初に

ニューラルネットワーク(NN)において、学習を反復して行うことでNNの精度を向上させますが、あまり学習を反復させすぎると過学習が起こる可能性があります。

そこで、過学習が起こる前、または最も精度が高い(何らかの評価指標において良い結果が出ている)場合のNNのパラメータを取得するためには、NNのパラメータを反復回数ごとにNNのパラメータを保存しておく必要があります。

今回は、Tensorflow2.0.0のsubclassing APIを用いて作成したNNのパラメータを保存する方法を紹介します。
  

Sequntial APIでパラメータ・モデルの保存方法

Sequantial APIでのパラメータ・モデルの保存はSubclassing APIでの方法と少し違いますので、以下のサイトを参考にしてください。
qiita.com

ちなみに、こんな感じにモデルの定義をするのがSequantial APIです。

Sequantial API
sequential_model = keras.Sequential(
    [keras.Input(shape=(784,), name='digits'),
     keras.layers.Dense(64, activation='relu', name='dense_1'), 
     keras.layers.Dense(64, activation='relu', name='dense_2'),
     keras.layers.Dense(10, name='predictions')])




Subclassing APIでのパラメータ・モデルの保存方法

Subclassing APIでのモデルの定義はこんな感じです。

Subclassing API
class ClassifierModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(ClassifierModel, self).__init__(**kwargs)
        
        self.fc1 = tf.keras.layers.Dense(100, activation='relu')
        self.fc2 = tf.keras.layers.Dense(100, activation='relu')
        self.fc3= tf.keras.layers.Dense(10 activation='softmax')
    
    def call(self, x_t):
        x = self.fc1(x_t)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

  

パラメータの保存

Tensorflowの機能であるcheckpointを使うことで、学習途中のパラメータを保存することができます。

詳しい内容は英語ですが、こちらを参考にしてください。
www.tensorflow.org
www.tensorflow.org
www.tensorflow.org

checkpointは以下のように定義します。

tf.train.Checkpoint
    ckpt = tf.train.Checkpoint(
        # 反復回数
        step=tf.Variable(0), 
        # 最適化関数
        optimizer=opt,
        # 定義したモデル 
        net=classifier_model
        )

optimizerとモデルの情報を渡す必要がありますので、それらの定義をしてからcheckpointを定義するようにしましょう。

次にCheckpointを保存するためにChekpointManagerを定義します。

ChekpointManager
manager = tf.train.CheckpointManager(
        # 先ほど定義したcheckpoint
        checkpoint=ckpt,
        # パラメータを保存するディレクトリの設定
        directory='model/' 
        # パラメータを何個保存するか
        max_to_keep=None)

max_to_keepにパラメータを保存する個数を指定してあげる、例えば50に設定すると、直近50個のパラメータを保存し、古いものは自動的に消去してくれます。

これでCheckpointを利用する準備が整いました。後は以下のように、学習を反復させるコードの中でパラメータを保存します。

checkpointでのパラメータの保存
for ite in range(100):
    #--------------------------
    # 学習を行うコードをここに
    #--------------------------
    
    # パラメータを保存
    manager.save(checkpoint_number=ckpt.step)
    # 反復回数を記録
    ckpt.step.assign_add(1)

  

パラメータをrestore

以下のように保存したパラメータのパスを渡してあげるとパラメータを反映させることができます。

restore
ckpt.restore('model/ckpt-0')

「ckpt-」の後ろの数字を変えてあげると任意の反復回数のパラメータを反映できます。

因みに、保存したときのモデルと同じ構造のモデルでないと反映させることができないので注意してください。




モデルの保存

パラメータの保存とロードはできたので、次はモデルの保存とロードをしていきましょう。使用する関数は「tf.saved_model.save」と「tf.saved_model.load」です

www.tensorflow.org
www.tensorflow.org



モデルの保存・ロードをするためには少しモデルの定義の仕方を変えないといけません。詳細は以下の記事を参照にしてください。
www.kuroshum.com

簡潔に言うと、「tf.function」の追加と「call」を「__call__」に変更します。

Subclassing API
class ClassifierModel(tf.keras.Model):
    @tf.function
    def __init__(self, **kwargs):
        super(ClassifierModel, self).__init__(**kwargs)
        
        self.fc1 = tf.keras.layers.Dense(100, activation='relu')
        self.fc2 = tf.keras.layers.Dense(100, activation='relu')
        self.fc3= tf.keras.layers.Dense(10 activation='softmax')
    
    @tf.function    
    def __call__(self, x_t):
        x = self.fc1(x_t)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

あとはモデルの保存のコードを書くだけです。「tf.saved_model.save」で指定のディレクトリに保存します。

モデルの保存
for ite in range(100):
    #--------------------------
    # 学習を行うコードをここに
    #--------------------------
    
    # パラメータを保存
    manager.save(checkpoint_number=ckpt.step)
    # 反復回数を記録
    ckpt.step.assign_add(1)

# パラメータをrestore
ckpt.restore('model/ckpt-0')

# restoreしたパラメータを保持しているモデルを保存
tf.saved_model.save(classifier_model, 'model/')

  

モデルをロード

checkpointの場合と同じように、モデルを保存したディレクトリを指定してあげるとモデルをロードしてくれます。

loaded_model = tf.saved_model.load('model/')

あとは「loaded_model(x_t)」みたいに入力データを渡してあげると実行できます。

最後に

Tensorflow1.x系とはモデルやパラメータの保存の仕方が変わっていた & subclassing APIでの書き方が違ったので、とりあえずまとめておきました。

何か間違えている箇所があればコメントしていただければ嬉しいです。