ぽよメモ

ファッション情報学徒の備忘録.

Optunaによる枝刈りとAsynchronous Successive Halving Algorithm

はじめに

PFNから発表されたハイパーパラメータ最適化ツールOptunaの記事が多数見受けられるようになってきました.Optunaは特に探索中に試行の枝刈りを行うことで,効率の良い探索を行うことができることが目玉の一つです.ここでは特に,Chainerと組み合わせる際の枝刈りの方法と,Optunaの採用する枝刈りのアルゴリズムについてまとめておきます.また,僕自身そこまで詳しいわけでは無いため,厳密な枝刈りによる効率化の度合い,その是非等についてはここでは議論しません*1. Optunaがパラメータの選択に採用しているアルゴリズムに関する情報は以下の記事が大変詳しく書かれています.

qiita.com

環境

  • Chainer v5.3.0
  • CuPy v5.3.0
  • Optuna v0.9.0

実行環境はUbuntu 18.04 LTSで,Docker,Docker-compose,nvidia-docker2を使用しました.また,NVIDIA Driverのバージョンは415.27です.

$ docker --version
Docker version 18.09.2, build 6247962
$ docker-compose --version
docker-compose version 1.23.2, build 1110ad01

ハードウェアは以下の通りです.

  • Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz 10コア20スレッド × 2
  • RAM 64GB
  • GPU GTX 1080 ti

使用するコード

特にどんなコードでも大差ないため,MNISTを使って例を示します.まず単純に最適化する場合のコードを以下に示します.

import argparse
import functools
import chainer
import numpy as np
import optuna
from chainer import links as L
from chainer import functions as F
from chainer import training
from chainer.training import extensions

# From: https://github.com/chainer/chainer/blob/v5/examples/mnist/train_mnist.py
# Copyright (c) 2015 Preferred Infrastructure, Inc.
# Copyright (c) 2015 Preferred Networks, Inc.
# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def forward(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

# 目的関数を設定する
def objective(trial, device, train_data, test_data):
    # trialからパラメータを取得
    n_unit = trial.suggest_int("n_unit", 8, 128)
    batch_size = trial.suggest_int("batch_size", 2, 128)
    n_out = 10
    epoch = 20

    # モデルを定義
    model = L.Classifier(MLP(n_unit, n_out))

    if device >= 0:
        chainer.backends.cuda.get_device_from_id(device).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train_data, batch_size)
    test_iter = chainer.iterators.SerialIterator(test_data, batch_size,
                                                 repeat=False, shuffle=False)
    updater = training.updaters.StandardUpdater(
                    train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out='output')

    # validationをするextensionを追加
    trainer.extend(extensions.Evaluator(test_iter, model, device=device))

    # 学習を実行
    trainer.run()

    # accuracyを評価指標として用いる
    return 1 - trainer.observation['validation/main/accuracy']

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('trials', type=int, help='Number of trials')
    parser.add_argument('-g', '--gpu', type=int, default=-1, help='GPU ID')
    args = parser.parse_args()

    np.random.seed(0)

    # MNISTデータを読み込む
    train, test = chainer.datasets.get_mnist()
    # 目的関数にパラメータを渡す
    obj = functools.partial(objective, device=args.gpu, train_data=train, test_data=test)
    # Studyを作成
    study = optuna.study.create_study(
        storage='sqlite:///optimize.db', study_name='prune_test', load_if_exists=True
    )
    # 最適化を実行
    study.optimize(obj, n_trials=args.trials)

    # Summaryを出力
    print("[Trial summary]")
    df = study.trials_dataframe()
    state = optuna.structs.TrialState
    print("Copmleted:", len(df[df['state'] == state.COMPLETE]))
    print("Pruned:", len(df[df['state'] == state.PRUNED]))
    print("Failed:", len(df[df['state'] == state.FAIL]))

    # 最良のケース
    print("[Best Params]")
    best = study.best_trial
    print("Accuracy:", 1 - best.value)
    print("Batch size:", best.params['batch_size'])
    print("N unit:", best.params['n_unit'])

試しに100回ほど最適化してみた結果,約4時間程度かかりました*2

$ python main.py 100 -g 0
[I 2019-03-24 09:44:51,751] A new study created with name: prune_test
[I 2019-03-24 09:48:36,239] Finished trial#0 resulted in value: 0.02411597967147827. Current best value is 0.02411597967147827 with parameters: {'n_unit': 57, 'batch_size': 23}.
# 省略
[I 2019-03-24 15:01:56,820] Finished trial#99 resulted in value: 0.023074626922607422. Current best value is 0.019391894340515137 with parameters: {'n_unit': 93, 'batch_size': 41}.
[Trial summary]
Copmleted: 100
Pruned: 0
Failed: 0
[Best Params]
Accuracy: 0.9806081056594849
Batch size: 41
N unit: 93

また,実際に使用したDockerfileと以下の実験を行ったスクリプトファイルはそれぞれこちらこちらにあります.

枝刈りとは

Optunaによる試行の枝刈りとは,

深層学習や勾配ブースティングなど、反復アルゴリズムが学習に用いられる場合、学習曲線から、最終的な結果がどのぐらいうまくいきそうかを大まかに予測することができます。この予測を用いて、良い結果を残すことが見込まれない試行は、最後まで行うことなく早期に終了させてしまうことができます。これが、Optuna のもつ枝刈りの機能になります。

となっています*3.Optunaではv0.9.0現在,2種類の枝刈り方法が存在します.以降,試行をtrial,その試行の中での学習のステップの単位をepochとします.

MedianPruner

アルゴリズム

最小化したい値(validationのlossやaccuracyなど)を定期的(1 epochごとなど)に報告し,その値を過去のtrialにおける同じepochにおける値と比較して,それらの中央値より悪ければ試行を止めるPrunerです.

例えば,以下のような試行が行われているとします(これらの報告されている値はvalidation lossだと考えてください).

. 1 epoch 2 epoch 3 epoch 4 epoch
1 trial 100 80 60 40
2 trial 120 100 90 80
3 trial 110 75 65 10

4 trialでは,1 epoch目で95,2 epoch目で90を取ったとすると,このとき,過去の3回の試行における各タイムステップの中央値は以下のようになり,

. 1 epoch 2 epoch 3 epoch 4 epoch
median 110 80 60 40

4 trialの2 epoch目の値90は過去の試行における2 epoch目の中央値80よりも大きいため,4 trialは破棄されます.

使い方

MedianPrunerをインスタンス化し,studyに渡します.

    # 途中省略
    # Prunerを作成
    pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=0)
    # Studyを作成してPrunerを指定
    study = optuna.study.create_study(
        storage='sqlite:///optimize.db', study_name='prune_test', pruner=pruner, load_if_exists=True
    )

Chainer用のインテグレーションを追加します.これはTrainerのExtensionで提供されています.これは僕がハマったことなのですが,実はpruner=NoneとしてStudyに渡すと,デフォルトでMedianPrunerが使われるため,以下のExtensionを追加するだけでPruneされてしまいます.Pruneされたくない人は注意してください((単に中間の値を全て保存しておきたいだけならtrial.report(値, step=epoch数)で出来ます.)).

    # 省略
    trainer = training.Trainer(updater, (epoch, 'epoch'), out='output')

    # validationをするextensionを追加
    trainer.extend(extensions.Evaluator(test_iter, model, device=device))

    # Optunaとのインテグレーションのためのextensionを追加
    # trialオブジェクト,監視するメトリクス,監視する頻度を指定
    integrator = optuna.integration.ChainerPruningExtension(
        trial, 'validation/main/accuracy', (1, 'epoch')
    )
    trainer.extend(integrator)

実行すると時折枝刈りされる試行が出てきます.

$ python main.py 100 -g 0
[I 2019-03-24 12:30:48,730] A new study created with name: prune_test
[I 2019-03-24 12:32:23,866] Finished trial#0 resulted in value: 0.033683180809020996. Current best value is 0.033683180809020996 with parameters: {'n_unit': 26, 'batch_size': 60}.
# 省略
 [I 2019-03-24 13:16:01,994] Setting status of trial#99 as TrialState.PRUNED. Trial was pruned at epoch 1.
[Trial summary]
Copmleted: 10
Pruned: 90
Failed: 0
[Best Params]
Accuracy: 0.978394627571106
Batch size: 108
N unit: 72

45分程度かかったようです.Accuracyは枝刈り無しよりも悪いですが,100回しか行っていないためもっと試行回数を増やせば良くなりそうです.

n_startup_trials

枝刈りを開始するまでの必要trial数を指定します.n_startup_trials=0なら試行の数にかかわらず,過去の試行の中央値よりも悪ければ破棄されます.例えば,最初に上げた例の2 trialは本来1 epoch目で中断されてしまいます.n_startup_trials=5とすると,5 trialまでは必ず実行され,6 trialから枝刈りが行われるようになります. デフォルトはn_startup_trials=5です.

n_warmup_steps

学習開始直後の学習曲線の傾きの角度が大きいものと,そうでないものがあるとき,小さいものの方が最終的な性能は良くなる場合もあるにもかかわらず,枝刈りが行われてしまうという可能性があります.これを回避するため,ある試行の中で,必ず実行するステップ数を決めることができます.例えばn_warmup_steps=5とすると,5 epoch目までは必ず全ての試行において実行され,6 epoch目から枝刈りが行われるようになります. この値を大きく取ればそれだけ様々なパラメータの可能性を見ることができますが,逆に枝刈りの効率は下がってしまいます.

SucccessiveHalvingPruner

v0.6.0から追加された,異なるアルゴリズムを採用したPrunerです.

アルゴリズム

単純なSuccessive Halvingアルゴリズム(SHA)ではなく,Asynchronous Successive Halvingアルゴリズム(ASHA)という改善された(?)アルゴリズムを採用していることがドキュメントに述べられています*4.ASHAについては以下の論文で提案されています.

arxiv.org

SHAそのものの解説はこちらの記事が詳しかったです.

adtech.cyberagent.io

SHAは学習の総時間を決め,途中まで学習をすすめてその中からよりよい組  1/\eta 個を抽出,学習を続けてまた上位  1/\eta 個を抽出……というのを繰り返すことで,良いパラメータでの学習に時間をかけるということのようです.ASHA自体の論文について詳しく読み込めているわけではないこと,こういったアルゴリズムに対して詳しいわけでもないことなどから,以下で述べることは全く正確性に欠ける可能性があることをご了承ください.

用語の定義

  •  n : ハイパーパラメータの組み合わせの数.
  •  R : 一つの試行におけるmaximum resource.
  •  r : 一つの試行におけるminimum resource.
  •  \eta : reduction factor.2以上の数値.
  •  s : minimum early stopping rate.
  • rung
    • 上位の組み合わせを抽出するための区切り?
    • ある  i について  n_i 個の試行を行うことを一つのrungとしている?
  • brackets
    • ある  n 個のハイパーパラメータの組についての最適化
  • 昇格
    • あるrungの上位の組み合わせを次のrungへと移行する(学習を継続する)こと

まず,論文中にSHAのアルゴリズムは以下の様に示されています.

f:id:pudding_info:20190324134012p:plain

また,  n = 9 R = 9 r = 1 s = 0 のとき以下の左図のようになり,異なる  s の組み合わせについて示したものが以下の右表になるようです.

f:id:pudding_info:20190324134743p:plain
Figure 1: Promotion scheme for SHA

 s = 0 のとき,最初のrungでは9個の試行が行われ,上位1/3だけが次のrungへ,そして最終的に1つが選出されていることが分かります.rungが進むごとに一つの試行に割り当てられるresource  r_i が増え,学習が進んでいることがわかります.

まず,SHAを単純に並列化(これを論文中では"synchronous" SHA,同期的SHAと呼んでいる)する上での問題は論文中に,

  1. あるrungは次のrungに進むために, n_i 個の試行が全て完了しないといけないため,stragglerやdropped job*5に弱い
  2. 試行するジョブが全て無くなったときには新しいbracketを追加するが,上位  1/\eta 個を選ぶのは各bracketについて独立であるため,bracketを並列しても,上位  1/\eta 個を選ぶパフォーマンスは向上しない
    • 自信無いです

というようなことが述べられているように見えます.これを解決するために,筆者らはASHAを提唱しており,アルゴリズムの概略は以下の様になっています.

f:id:pudding_info:20190324145915p:plain
Asynchronous Successive Halving Algorithm

同期と非同期の違いは以下の図のように表されるようです.

f:id:pudding_info:20190324150153p:plain
Figure 2: Comparison of promotion schemes for SHA and ASHA.

同期SHAでは,rung 1に進むためにrung 0の試行が全て完了するまで待つのに対して,ASHAでは,全ての完了を待たずに先にrung 1のジョブが実行されます.つまり,これまでに実行された  m 個の試行について,常に  1/\eta という比率を保つように次のrungの学習を行う,というようなことのようです(これも自信無い).もし昇格するジョブが無かったとき,単にbaseのrung(rung 0)に新しいjobを追加します.これはつまりパラメータの組の全体数  n を決めないということです.ASHAで必要なパラメータは,SHAのパラメータから  n を除いた全てです.

使い方

単にMedianPrunerをSuccessiveHalvingPrunerに置き換えれば良いです.記述は省略しますがChainerPruningExtensionも必要です.

    # 途中省略
    # Prunerを作成
    pruner = optuna.pruners.SuccessiveHalvingPruner(
        min_resource=1,
        reduction_factor=4,
        min_early_stopping_rate=0
    )
    # Studyを作成してPrunerを指定
    study = optuna.study.create_study(
        storage='sqlite:///optimize.db', study_name='prune_test', pruner=pruner, load_if_exists=True
    )

引数は3種類あり,

  • min_resource:論文中の  r
  • reduction_factor:論文中の \eta
  • min_early_stopping_rate:論文中の s

にそれぞれ相当します.なお,最大リソース数  R はパラメータとして渡しません.これは各trialの中で決まる値(Trainerに渡すEpoch数など)に当たるためだそうです.また,これらの値から

  • 最低限実行されるepoch数: e_{min}
  • Pruneされるタイミング: e_{prune}

などが以下の式で計算出来ます*6


\begin{aligned}
e_{min}  &= r \times \eta^{\left( s \right)} \\
e_{prune} &= r \times \eta^{\left( s + rung\right)}  \\
\end{aligned}


例えば,デフォルトの値(  r=1 \eta=4 s=0 )のとき,


\begin{aligned}
e_{min}  &= 1 \times 4^{\left( 0 \right)} \\
              &= 1 \\
e_{prune} &= r \times \eta^{\left(s + rung\right)}  \\
                 &= 1 \times 4^{\left(0 + rung \right)}
\end{aligned}


ここで, e_{prune} は例えばrung 0で他よりも良い成績であった場合,次にPruneされるのは rung=1として  e_{prune} = 4 と計算することが出来ます.つまり,1,4,16,64…epoch目でそれぞれPruneされるかどうかが決まります.実際に,100回実行してみた結果が以下です.

$ docker-compose up asha
[I 2019-03-24 12:30:48,446] A new study created with name: prune_test
[I 2019-03-24 12:32:16,421] Finished trial#0 resulted in value: 0.02267676591873169. Current best value is 0.02267676591873169 with parameters: {'batch_size': 65, 'n_unit': 119}.
# 省略
[I 2019-03-24 12:58:49,325] Setting status of trial#99 as TrialState.PRUNED. Trial was pruned at epoch 1.
[Trial summary]
Copmleted: 3
Pruned: 97
Failed: 0
[Best Params]
Accuracy: 0.9773232340812683
Batch size: 65
N unit: 119

かかった時間は30分程度でした.MedianPrunerの時と同じく,枝刈り無しで行ったときより結果が悪いのは気になりますが,こちらも同じく時間が大きく短縮されているのでもっと試行回数を増やして良さそうです.

まとめ

MedianPrunerの挙動はドキュメントを読んでだいたい理解したのですが,SuccessiveHalvingPrunerは「Successive Halving Algorithmの非同期版」という書かれ方がされており,全く分からなかったので論文を流し読みしてまとめてみました.やはり枝刈りを行うと圧倒的に時間が短縮されることも実際に確かめることが出来ました.積極的に活用していきたいですね.
内容に間違っている箇所があれば,コメントで優しく指摘して頂けると助かります.

*1:実際,各セクションで実際に最適化を実行していますが,それぞれ同時に動かしたりしています.1080tiなら大丈夫だろwと気軽にやっているので実行時間の正確さは期待しないでください.

*2:遅くないか…?

*3:research.preferred.jp

*4:https://optuna.readthedocs.io/en/stable/reference/pruners.html#optuna.pruners.SuccessiveHalvingPruner

*5:厳密な意味はわからないのですが,stragglerは他と比べて長い試行,dropped jobは実行中の失敗(メモリ不足とか,ノードが落ちたとか?)でしょうか

*6:これらの記号は僕が勝手に決めたもので特に意味はありません