グラフ機械学習と強化学習について

主にグラフ機械学習や強化学習手法を記載します。

RAY RLlib

Rayは分散処理を計算するためのAPIです。その中でも特にRLlibは強化学習に特化したライブラリになっています。

シミュレーション環境さえ用意できれば、強化学習はいかに並列計算を行うかが大事になってきます。 Open MPIが有名かと思いますが、Rayを使えばノード間分散処理といった面倒な実装のところも簡単に実装することができます。

docs.ray.io

使い方

非常に簡単です。例えばCartPoleのタスクをPPOを32並列、1GPUで実行するには以下のようにconfigを設定します。

from ray.rllib.agents.ppo import PPOTrainer

# Configure the algorithm.
config = {
    "env": "CartPole-v1",
    "num_workers": 32,
    "num_gpus": 1,
    "framework": "torch", #  or "tf"
    "model": {
        "fcnet_hiddens": [64, 64],
        "fcnet_activation": "relu",
    },
    "evaluation_num_workers": 1,
    "evaluation_config": {
        "render_env": False,
    }
}

# Create our RLlib Trainer.
trainer = PPOTrainer(config=config)

for _ in range(100):
    print(trainer.train())

# Evaluate the trained Trainer (and render each timestep to the shell's output).
trainer.evaluate()

また、DQNの優先度付経験再生のサンプリングを分散に行えるようにしたAPEX(Distributed Prioritized Experience Replay)を用いる場合も

from ray.rllib.agents.dqn import ApexTrainer
trainer = ApexTrainer(config=config)

と変更するだけです。アルゴリズムは代表的なものが多数実装されています。

あとは計算資源を用意してあげるだけです。ノード間並列もray.init()にてIPを設定してあげればできます。 AWS上のEKSなどでAutoScalingもできそうですが、こちらもいずれ試していきたいと思います。

計算結果はTensorboardに記録されます。だいたい2分くらいで最大の報酬500に到達しています。

f:id:udnp:20220205185454p:plain

独自の環境、カスタムモデル、損失関数など拡張性ができるようになっています。まずはRayに実装されている強化学習手法を試してダメならモデルや損失関数を変更するというのがよさそうです。

Trainer config

強化学習の細かい設定などはTrainer configから参照できます。

https://docs.ray.io/en/latest/rllib-training.html#common-parameters