はじめに
前回こんな記事を書きました.
本当は今回の記事もまとめて1つで公開する予定だったのですが長くなりすぎたので分割しました.
環境
環境は全て前回の記事と同様です.
- Chainer v5.3.0
- CuPy v5.3.0
- Optuna v0.9.0
枝刈りと過学習
当初,Optunaのプレスリリースにあった「学習曲線から、最終的な結果がどのぐらいうまくいきそうかを大まかに予測する」という一文から,「過学習を起こしそうだったら早めに切る」という意味だと誤解していました.実際にはその試行内ではなく,過去の試行との比較を行うため,これは全く意味が異なってしまいます.これに関連したIssueは以下です.
必要な部分だけ抽出すると,
- Optunaの枝刈りは過学習を検知するものではない
- 過学習を気にするなら,例えばChainerなら
chainer.training.triggers.EarlyStoppingTrigger
を使うように
ということです.
EarlyStopping
これは,モデルの学習の収束を判定するための方法です.何らかの指標,例えばvalidation lossを監視し,train lossは減少し続けるのに対して,validation lossが改善されなくなった場合,学習を打ち切ります.ChainerではTriggerとして実装されています.
# 1 epochごとにvalidationのaccuracyを監視し,3回以上改善しなければstopする early_trigger = training.triggers.EarlyStoppingTrigger( check_trigger=(1, "epoch"), monitor="validation/main/accuracy", patients=3, mode="max", max_trigger=(epoch, "epoch") ) # `(epoch, "epoch")`の代わりに上記のTriggerを渡す trainer = training.Trainer(updater, early_trigger, out='output')
これはあくまで終了するだけなので,学習終了後にそのパラメータを読み出したりはしてくれません.そういうことがしたければ以下の記事が参考になります.
タイミング
適当な値を出しますが,こんな感じのlossの推移があり,
epoch | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
loss | 100 | 80 | 60 | 40 | 20 | 30 | 25 | 28 |
EarlyStoppingTriggerのパラメータとして以下の物を渡したとします.
patients=3
mode="min"
max_trigger=(10, "epoch")
patients
は「最小値からいくつ連続して値が改善しなかった場合に学習を止めるか」というパラメータです.例の場合,最小値は5 epoch目の20です.以降,30 → 25 → 28と値が上下していますが,一貫して最小値20を上回っているため,8 epochで中断されます.
mode
には,最小,最大どちらの方向で値を監視するかを設定します.lossならば最小の方向になり,accuracyなら最大の方向に監視することになると思います.デフォルトは"auto"ですが,明示した方がトラブルはないんじゃないかなと思います.
max_trigger
は値が改善され続けたときにいくつまで学習するかを設定します.
OptunaとEarlyStopping
Optunaは,あくまでも目的関数が返す最後の値を最適なパラメータの選出に使用します.枝刈りを行っても行わなくても,学習過程で記録した最小値が利用されるわけではないため注意が必要です.これを簡単なサンプルで示してみます.
import optuna def objective(trial): sample_losses = [ [200, 90, 52, 31, 15, 7, 17, 28, 45, 56], # A [143, 82, 56, 40, 26, 18, 24, 23, 26, 28] # B ] losses = sample_losses[trial.number] # 途中経過を報告する for i, loss in enumerate(losses): trial.report(loss, step=i) # 最後の値を返す return losses[-1] if __name__ == "__main__": study = optuna.study.create_study("sqlite:///test.db") study.optimize(objective, 2) # 全ての試行のvalueをprint print("[Trials]") for t in study.trials: # Trialの番号,その時の値,値の推移 print(t.number, t.value, t.intermediate_values) # Optunaが選んだbestなtrial best = study.best_trial print("[Best]") print("Number:", best.number) print("Value:", best.value)
実行されるTrialの順に応じて異なる数列をOptunaへ報告する目的関数を設定し,これを最適化させてみます.この sample_losses
をプロットすると以下の様になります.
見て分かるように,実際には試行Aの方が6 epochで(epochではないですが便宜上の単位として使います)最も低い値を記録しますが,Optunaは試行Bをbest trialとして選出します.
[I 2019-03-24 23:45:29,729] A new study created with name: no-name-22ecd572-e23d-4ce4-8370-26a12267b372 [I 2019-03-24 23:45:29,830] Finished trial#0 resulted in value: 56.0. Current best value is 56.0 with parameters: {}. [I 2019-03-24 23:45:29,918] Finished trial#1 resulted in value: 28.0. Current best value is 28.0 with parameters: {}. [Trials] 0 56.0 {0: 200.0, 1: 90.0, 2: 52.0, 3: 31.0, 4: 15.0, 5: 7.0, 6: 17.0, 7: 28.0, 8: 45.0, 9: 56.0} 1 28.0 {0: 143.0, 1: 82.0, 2: 56.0, 3: 40.0, 4: 26.0, 5: 18.0, 6: 24.0, 7: 23.0, 8: 26.0, 9: 28.0} [Best] Number: 1 Value: 28.0
これは嬉しくありません.正しく最も良い値で判断して欲しいところです.そこで,最終的に返す値を変えます.途中経過の値は枝刈りに使われるのみ*1なので,無視できます.
import optuna def objective(trial): sample_losses = [ [200, 90, 52, 31, 15, 7, 17, 28, 45, 56], # A [143, 82, 56, 40, 26, 18, 24, 23, 26, 28] # B ] losses = sample_losses[trial.number] # 途中経過を報告する for i, loss in enumerate(losses): trial.report(loss, step=i) # 最小値を返す losses.sort() return losses[0] if __name__ == "__main__": # 省略
実行してみます.
[I 2019-03-25 00:02:33,012] A new study created with name: no-name-36544734-db8e-4478-83d5-314d3d999c7b [I 2019-03-25 00:02:33,122] Finished trial#0 resulted in value: 7.0. Current best value is 7.0 with parameters: {}. [I 2019-03-25 00:02:33,223] Finished trial#1 resulted in value: 18.0. Current best value is 7.0 with parameters: {}. [Trials] 0 7.0 {0: 200.0, 1: 90.0, 2: 52.0, 3: 31.0, 4: 15.0, 5: 7.0, 6: 17.0, 7: 28.0, 8: 45.0, 9: 56.0} 1 18.0 {0: 143.0, 1: 82.0, 2: 56.0, 3: 40.0, 4: 26.0, 5: 18.0, 6: 24.0, 7: 23.0, 8: 26.0, 9: 28.0} [Best] Number: 0 Value: 7.0
このように,結局は目的関数の返す値によって決定されることが分かりました.よってOptunaで最適化する際には,その試行の中の最良の値を返す必要があるように思われます*2.
枝刈りを行う際の更なる注意点として,Optunaは枝刈りした試行についてPRUNED
というステータスで記録しますが,best trialの選出には PRUNED
のものは含まれません*3.これは,過去の同じステップと比べて値が悪化しているのだから当然と考えられます.しかし,前述のように,学習の過程でベストな値を取っても,その後改善せずむしろ過学習により劣化した場合に枝刈りされる可能性は依然としてあり,その場合は本来最適であったパラメータが見逃されることになります.
これはEarlyStoppingを用いても抑制はできるでしょうが完全に防ぐことは出来ないと考えています.その仕組み上,最良の値からいくつかぶん学習を進める必要があるため,その長さ(patients
)の分だけ枝刈りの可能性が残されてしまうためです.あまり長い patients
を設定することは避けるべきかと思います.
実際にやってみた
前回の記事で使用した実験コードに更に手を加える形で実装しました.
全体は以下にあります.
optuna-sample/main.py at 1b7cfccea08b4a2255ff685d931f746ce0de2007 · pddg/optuna-sample · GitHub
# 省略 early_trigger = training.triggers.EarlyStoppingTrigger( check_trigger=(1, "epoch"), monitor="validation/main/accuracy", patients=3, mode="max", max_trigger=(epoch, "epoch") ) trainer = training.Trainer(updater, early_trigger, out='output') # 実行中のログを取る log_reporter = extensions.LogReport() trainer.extend(log_reporter) # 省略 # 学習を実行 trainer.run() # Accuracyが最大のものを探す observed_log = log_reporter.log observed_log.sort(key=lambda x: x['validation/main/accuracy']) best_epoch = observed_log[-1] # 何epoch目がベストだったかを記録しておく trial.set_user_attr('epoch', best_epoch['epoch']) # accuracyを評価指標として用いる return 1 - best_epoch['validation/main/accuracy']
上記のコードの途中でTrialオブジェクトに対してuser_attr
として最良であった場合のEpoch数を記録していますが,これは後から以下の様にして取り出すことが出来ます.
print("[Best Params]") best = study.best_trial print("Epoch:", best.user_attrs.get('epoch'))
これを用いて,枝刈り無し,MedianPrunerによる枝刈り有り,SuccessiveHalvingPrunerによる枝刈り有りの3種類でそれぞれ100回の最適化を行いました.EarlyStopping無しの結果については前回の記事をご覧ください.また,前回から繰り返し書いていますがこれは厳密な時間測定ではなく,なんとなく感覚を掴んでいるだけですので,悪しからず.
枝刈り無し
[I 2019-03-24 15:57:00,465] A new study created with name: prune_test [I 2019-03-24 15:57:42,481] Finished trial#0 resulted in value: 0.03238105773925781. Current best value is 0.03238105773925781 with parameters: {'n_unit': 36, 'batch_size': 105}. # 省略 [I 2019-03-24 20:16:07,683] Finished trial#99 resulted in value: 0.022976338863372803. Current best value is 0.018449485301971436 with parameters: {'n_unit': 95, 'batch_size': 37}. [Trial summary] Copmleted: 100 Pruned: 0 Failed: 0 [Best Params] Epoch: 9 Accuracy: 0.9815505146980286 Batch size: 37 N unit: 95
4時間強程度の時間がかかりました.思ったより全然時間を短縮できませんでしたね.今回は最大で20 epoch学習をおこなっているのですが,これが思ったより多すぎなかったということなのでしょうか. とはいえ,Best Paramsを見て頂ければ分かるように,9 epoch目でベストの値をたたき出していることが分かります.
MedianPrunerによる枝刈り
[I 2019-03-24 15:56:47,192] A new study created with name: prune_test [I 2019-03-24 15:57:31,076] Finished trial#0 resulted in value: 0.02388054132461548. Current best value is 0.02388054132461548 with parameters: {'batch_size': 67, 'n_unit': 116}. # 省略 [I 2019-03-24 16:29:43,264] Setting status of trial#99 as TrialState.PRUNED. Trial was pruned at epoch 1. [Trial summary] Copmleted: 14 Pruned: 86 Failed: 0 [Best Params] Epoch: 8 Accuracy: 0.9790022373199463 Batch size: 69 N unit: 125
約30分程度で済んでいます.EarlyStopping無しで行った時よりも多少時間が短縮できているようですが,たまたまかも知れません.こちらもBest Paramsは8 Epoch目と比較的早い段階で収束していることがわかります.
SuccessiveHalvingPrunerによる枝刈り
[I 2019-03-24 15:56:58,310] A new study created with name: prune_test [I 2019-03-24 15:58:00,723] Finished trial#0 resulted in value: 0.023097515106201172. Current best value is 0.023097515106201172 with parameters: {'batch_size': 61, 'n_unit': 70}. # 省略 [I 2019-03-24 16:26:50,098] Setting status of trial#99 as TrialState.PRUNED. Trial was pruned at epoch 1. [Trial summary] Copmleted: 8 Pruned: 92 Failed: 0 [Best Params] Epoch: 9 Accuracy: 0.9769024848937988 Batch size: 61 N unit: 70
30分弱程度で完了しました.やはり枝刈りは強力ですね.こちらも9 Epoch目で学習が収束していることから,今回のサンプルネットワークを用いたMNISTの学習では8,9 epochあたりで十分収束するということでしょうか(もちろん触るパラメータ次第だとは思いますが).
まとめ
ChainerのEarlyStoppingTriggerは簡単に使えて強力ですので,無意味に長い学習を行って計算リソースや時間を無駄に消費したくない方は是非導入してみてはいかがでしょうか.
また枝刈り有り・無しの場合で,かかる時間と得られる最適化の妥当さのバランスがどうなっていくのか,上記の結果を見る限り同じ100回の最適化でもそれぞれ異なるパラメータに行き着いており,最終的にどこに収束していくのか,気になります.
あとこれはどなたかご存じの方がいらっしゃれば教えて頂きたいのですが,こういった最適化のようなタスク,および実際の学習において numpy.random.seed(0)
のようにseed値を固定すべきなのでしょうか.再現性を中途半端に考慮するより最初からランダムにし,複数回行って平均等を見るべきなのでしょうか.
深層学習は難しいですね.
*1:と解釈しているのですが,パラメータのサンプリングに使われたりするのでしょうか
*2:むしろOptunaはなぜ途中で値を報告させる機能を有しているにも関わらず,それらを考慮しないのでしょう.
*3:少なくともv0.9.0ではそうなっています. optuna/base.py at v0.9.0 · pfnet/optuna · GitHub