グラフ機械学習と強化学習について

主にグラフ機械学習や強化学習手法を記載します。

Fitted Q-iteration

久しぶりの記事です。

オフライン強化学習を真面目に使いこなしていきたい。

ということでオフライン強化学習の中では基本的な手法であるFitted Q-iterationについてみていきます。D. Ernstらによって2005年に提案されています。

手法理解を優先とするため厳密さに欠けるところがあると思いますが、ご容赦ください。 Neural fitted Q-iterationやDeep Q-networkの基礎となっている手法です。

Value Iteration

強化学習reinforcement learningは、エージェントが環境とのやり取りを行いながら、累積報酬が最大になるような方法を求める機械学習手法です。

はじめに強化学習の設定を定義します。時刻$t$における状態$s_t$, 行動$a_t$とすると、エージェントが行動$a_t$を行なったとき、環境は、次の状態$s_{t+1}$, 報酬$r_{t}=R(s_t, a_t)$及び終端条件を返し、かつマルコフ決定過程に従います。 エージェントの方策(policy)を$\pi(a_t | s_t)$と表現します。この方策は、機械学習モデルで表現されることが多いですが、どのようなモデルでも利用できます。

次に、現在の状態や行動に対する期待報酬の予測値を表す「価値」を導入します。ある状態での状態価値$V^{\pi}(s)$は次のように定義されます。

$$ V^{\pi}(s)=\mathbb{E}\left[\sum_{k=0}^{\infty} \gamma^{k} r_{t+k+1} | s_t = s \right] $$

ここで、$\gamma$は割引率で、将来の報酬ほど価値を減衰させるパラメータです。 方策$\pi$に従って行動した時の状態$s$における期待報酬を表しています。これは状態価値(state valueと呼ばれます。一方、状態と行動に対する価値は

$$ Q^{\pi}(s, a)=\mathbb{E}\left[\sum_{k=0}^{\infty} \gamma^{k} r_{t+k+1} | s_t = s, a_t=a \right] $$

と表すことができ、状態行動価値(state-action valueもしくはQ関数と呼ばれます。 最適な価値というのはベルマン方程式を再帰的に解くことで得られることが知られています。 最適な方策を$\pi^{\ast}$としたとき、以下の式で表現されます。

$$ V^{\pi^{\ast}}(s) = \max_a \left\{r + \gamma \sum_{s'}P(s'|s, \pi(s))V^{\pi^{\ast}}(s') \right\} $$

ここで、$P$は状態遷移確率で、添え字を省略し、次の状態を$s'$として書いています。 状態遷移確率は一般的に未知であり、全ての状態に対して行動をして求めるのは困難であることが多いため、ベルマン方程式を近似する方法が提案されています。 その方法の1つに、時間的差分学習(temporal difference learning, TD learning)が知られています。これは以下の式によって価値を更新します。

$$ V(s) \leftarrow V(s) + \alpha (r + \gamma V(s') - V(s)), $$

ここで$\alpha$は学習率であり、$r + \gamma V(s')$はTD targetと呼ばれます。 この場合は1次の時間発展に対するTD targetを考慮しています。

一方、Q-learningでは、次のようになります。

$$ Q(s, a) \leftarrow Q(s) + \alpha (r + \gamma \max_a Q(s', a) - Q(s, a)), $$

ある状態$s'$の時に、全ての行動に対するQを参照し、最大のQの時の行動$a$を選択することで価値を更新していく方法です。 方策オフ型(off-policy)の更新方法です。実際には、この価値は通常、ニューラルネットワークなど何らかの形でモデル化されるため、価値関数と呼びます。

最適なQ関数を求めるために、Q-learningを行っていけば良いのですが、Q関数を近似するモデルを考える場合、誤差関数を定義する必要があります。 多くの場合、Q targetとTD targetとの二乗誤差が最小になるように価値関数を構築しますが、これはmean square Bellman error (MSBE)と呼ばれます。 すなわち、以下の式で表されます。

$$ L = \mathbb{E} \left[ (Q_{target} - (r + \gamma \max_{a'} Q)^{2} \right] $$

この時、求めるべき回帰モデルが$Q=f(s, a)$となります。 実際に探索と活用を行いながら、この回帰モデルを学習することで、Q関数を求めることができるようになります。

例えば、DQNでは、Q関数はニューラルネットワークで表現されますが、学習を安定させるため、以下のような工夫を行なっています。

  • MSEではなくhuber lossの活用
  • 経験再生バッファ(Experience replay buffer)からのランダムサンプリング
  • Fixed Q-targets (target Qを算出するニューラルネットの重みを数ステップ前のものを利用)
  • 報酬のクリッピング

現在では、DQNをより安定に、さらに性能を良くするための工夫を施した論文が数多く提案されています。

Fitted Q-iteration

時代が遡ることになりますが、ここからがfitted Q-iteration (FQI)の解説です。実際には、上記のもので解説はほとんど終わっており、DQNでモデルを決定木やRandom forestのようなモデルにしたものがFQIです。ニューラルネットワークにしたものはNeural fitted Q-iteration (NFP)と呼ばれています。オフライン強化学習(バッチ強化学習)では、得られたデータのみから強化学習モデルを構築します。すなわち、経験再生バッファが更新されない状況です。

アルゴリズムはFigure 1のようになっています。xが状態、uが行動となっている点に注意です。

FQI

ここでは実装面からFQIを見ていきます。

オフラインデータの収集

本来は実験などによって得られたものを使いますが、ここでは、FQIの実装確認のため、Q-learningを使ってデータを収集します。 通常、Q-learningでは状態を離散化し、Q table形式で学習をしていきます。FQIではQ関数がモデル化されるため連続値で利用できます。

import gym
rom collections import Counter, defaultdict
from queue import Queue
from typing import Tuple
import numpy as np


class QTable:
    def __init__(self, env, num_digit=6, init_qtable="random"):
        """
        Observation bounds:
            cart_x = (-2.4, 2.4)
            cart_v = (-3.0, 3.0)
            pole_angle = (-0.2094, 0.2094) # in radians, which is approx. (-12°, 12°)
            pole_v = (-2.0, 2.0)
        """
        self.env = env
        self.num_digit = num_digit

        if init_qtable == "random":
            self.q_table = np.random.uniform(
                low=0,
                high=1,
                size=(num_digit ** env.observation_space.shape[0], env.action_space.n),
            )
        else:
            self.q_table = np.zeros(
                shape=(num_digit ** env.observation_space.shape[0], env.action_space.n)
            )

        self.bound = np.array([[-2.4, 2.4], [-3.0, 3.0], [-0.2, 0.2], [-2.0, 2.0]])
        self.bins_list = [self.create_bins(x, y) for x, y in self.bound]
        self.shape = self.q_table.shape

    def create_bins(self, low, high):
        """Utility function to create bins."""
        return np.linspace(low, high, self.num_digit + 1)[1:-1]

    def digitize(self, observation: np.ndarray) -> int:
        """Returns discrete state from observation.

        This method digitizes the continuous state into discrete state.
        """
        digit = [np.digitize(obs, lst) for obs, lst in zip(observation, self.bins_list)]
        # convert n-digit to 10-digit
        ids = sum([dig * (self.num_digit**i) for i, dig in enumerate(digit)])
        return ids

    def __getitem__(self, idx):
        return self.q_table[idx]

    def __setitem__(self, key, value):
        self.q_table[key] = value

    def __repr__(self):
        return f"{self.__class__.__name__}(env={self.env}, num_digits={self.num_digit}, q_table={self.q_table})"


class Action:
    def __init__(self, env):
        self.num_action = env.action_space.n

    def greedy(self, q_table: QTable, state: int) -> int:
        return int(np.argmax(q_table[state]))

    def epsilon_greedy(self, q_table: QTable, state: int, episode: int) -> int:
        epsilon = 0.5 * (1 / (episode + 1))
        if np.random.uniform(0, 1) > epsilon:
            action = self.greedy(q_table, state)
        else:
            action = np.random.choice(self.num_action)
        return action


class QLearning:
    def __init__(self, env, num_digit, alpha=0.5, gamma=0.99, init_qtable="random"):
        self.action = Action(env)
        self.q_table = QTable(env, num_digit=num_digit, init_qtable=init_qtable)
        self.alpha = alpha
        self.gamma = gamma

    def update(self, state, action, reward, next_state) -> None:
        """Off-policy update

        Temporal difference target:
            TD = reward_{t} + \gamma * max_a Q(s_{t+1}, a)

        Update rule:
            Q(s_t, a_t) := Q(s_t, a_t) + \alpha * ( TD - Q(s_t, a_t) )
        """
        state = self.q_table.digitize(state)
        next_state = self.q_table.digitize(next_state)
        cur_q = self.q_table[state, action]
        td_target = reward + self.gamma * max(self.q_table[next_state, :])
        self.q_table[state, action] = cur_q + self.alpha * (td_target - cur_q)

    def compute_action(self, observation, episode: int) -> int:
        state = self.q_table.digitize(observation)
        return self.action.epsilon_greedy(
            q_table=self.q_table, state=state, episode=episode
        )
  • QTableでは、Q tableを作成するために各状態をbinningし、そのインデックスからn進数から10進数に変換するようにしています。
  • Actionでは、探索と活用部分はepsilon greedyを利用しています。
  • QLearningでは、TD learningに基づきQ更新するクラスを構築しました。

次にCartPoleをプレイするためのクラスを構築します。

class Agent:
    def __init__(self, env, episode=500, horizon=200):
        self.env = env
        self.episode = episode
        self.horizon = horizon
        self.episode_buffer = defaultdict(list)

    def compute_reward(
        self, done: bool, step: int, complete_episodes: int
    ) -> Tuple[float, int]:
        """Custom reward function"""
        if done:
            if step < 195:
                reward = -1.0
                complete_episodes = 0
            else:
                reward = 1.0
                complete_episodes += 1
        else:
            reward = 0.0

        return reward, complete_episodes

    def play_episodes(self, algo):
        complete_episodes = 0
        for episode in range(self.episode):
            state = self.env.reset()
            total_reward = 0.0
            for step in range(self.horizon):
                action = algo.compute_action(state, episode=episode)
                next_state, reward, done, _ = self.env.step(action)
                total_reward += reward
                my_reward, complete_episodes = self.compute_reward(
                    done, step, complete_episodes=complete_episodes
                )
                if episode >= 200:
                    self.episode_buffer[episode].append(
                        (state, action, next_state, my_reward)
                    )
                algo.update(state, action, my_reward, next_state)
                if done:
                    if episode % 100 == 0:
                        print(f"episode={episode}, total_reward={total_reward}")
                    break
                state = next_state

            if complete_episodes >= 10:
                print("10 times successes")
                break

env = gym.make("CartPole-v0")
q_learning = QLearning(env, num_digit=9, alpha=0.6, gamma=0.99)
agent = Agent(env, episode=1000)
agent.play_episodes(q_learning)

self.replay_buffer内に(状態、行動、次の状態、報酬)の軌跡(trajectory)を貯められるようにしています。 ステップ数が195回以下なら報酬が-1、それ以外は0、成功したら報酬が1となるようなスパースな報酬設定です。 ある程度、パフォーマンスの良い軌跡をバッファーに格納したいため、エピソードが200以上の軌跡を保存しています。

モデルの初期化

ここからは一つ一つ実行しながら見ていきます。 計算が高速なLightGBMを利用します。

初期のモデルには(状態、行動)から報酬を予測するモデルを作ります。

import pandas as pd
from lightgbm import LGBMRegressor

cols = ["state", "action", "next_state", "reward"]
episode_list = []
for i in range(len(agent.episode_buffer)):
    tmp = pd.DataFrame(agent.episode_buffer[i], columns=cols)
    episode_list.append(tmp)

data = pd.DataFrame(pd.concat(episode_list), columns=cols))
states = np.vstack(data["state"].to_numpy())
actions = data["action"].to_numpy()
X = np.c_[states, actions.reshape(-1, 1)]
y = data["reward"].to_numpy()
model = LGBMRegressor()
model.fit(X, y)

このモデルの出力結果は以下のようになります。(状態、行動)からは正しく報酬を予測できるモデルができているとは言い難い見た目になっています。

import matplotlib.pyplot as plt
plt.scatter(reward, model.predict(X))

reward=f(s, a)

FQIの実行

先ほど収集したデータを使ってQ targetを算出します。 得られたものを使ってQ関数(LightGBM)を学習します。

def compute_q_target(args):
    row, model, gamma = args
    next_state = np.array(row["next_state"])
    q_values_next_state = [
        model.predict(np.append(next_state, a).reshape(1, -1))[0] for a in [0, 1]
    ]
    max_q_value_next_state = max(q_values_next_state)
    return row["reward"] + gamma * max_q_value_next_state

def get_q_target(data, model, gamma=0.99, n_jobs=1):
    with Pool(n_jobs) as p:
        args_list = [(row, model, gamma) for _, row in data.iterrows()]
        q_targets = list(p.map(compute_q_target, args_list))
    return q_targets

# Fitted Q-iteration
q_target_list = []
rmse_list = []
for i in tqdm(range(10)):
    q_targets =get_q_target(data, model, gamma=0.99, n_jobs=8)  
    model.fit(X, q_targets)
    q_target_list.append(q_targets)
    if i > 0:
        rmse = np.sqrt(mean_squared_error(q_targets, q_target_list[i-1]))
        rmse_list.append(rmse)
        print("RMSE of Q:", round(rmse, 4))
plt.scatter(q_targets, model.predict(X))

結果を見るとQ関数を予測できていそうです。誤差(MSBE)も下がっています。

mode output

Loss of mean squared Bellman error

最後は得られたモデルを使って実際にCartPoleをプレイします。

total_reward_list = []
for i in tqdm(range(100)):
    env = gym.make("CartPole-v0")
    #env.seed(i)
    obs = env.reset()
    total_reward = 0.0
    q_values_list = []
    while True:
        q_values = [model.predict(np.append(obs, a).reshape(1, -1))[0] for a in [0, 1]]
        q_values_list.append(q_values)
        action = np.argmax(q_values)
        obs, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break
    total_reward_list.append(total_reward)    
    
print(np.mean(total_reward_list),"+/-", round(np.std(total_reward_list), 4))
> 95.23 +/- 3.187

実際に得られたiterationを20回程度行ったモデルを使うと、ある程度連続して成功できるようなQ関数(LightGBM)が学習できました。 初期の収集したデータに依存して結果はばらつきそうです。200回連続して得られるような場合もありましたが、大体は上記のような結果です。

課題

FQIは非常に単純ながらQ関数を近似することができ、環境とのやり取りが行えないよう状況では強力なオフライン強化学習手法の1つです。しかしながら、オフライン強化学習全般の欠点として、Q値の過大評価があります。Conservative Q-learing (CQL)のような方法だと正則化などを 導入し、この問題を軽減することができています。FQIではモデルが木構造のものを用いるので、見たことのない状況(O.O.D)では、ワークしないと思われます。ニューラルネットワークを利用したFQIの方が良いと思いますが、学習の不安定性が残ります。とりあえずオフライン強化学習をするならCQLを行うのが良いと思います。