グラフ機械学習と強化学習について

主にグラフ機械学習や強化学習手法を記載します。

Junction Treeアルゴリズム

ベイジアンネットワークの推論では良く用いられるアルゴリズムです。分子構造もグラフ構造なので本アルゴリズムは適用することができます。グラフを木分解(tree decomposition)することで得られる構造がjunction treeです。有向グラフの場合、無向グラフにする操作(モラル化など)が必要となりますが、分子構造の場合は無向グラフなので、この変換は必要ありません。

参考にしてる文献は下記です。

arxiv.org

JT-VAEを使ってみた結果など、こちらも参考になります。

qiita.com

グラフ理論

改めてグラフ理論の定義をおさらいしてみます。

グラフ Gは頂点もしくはノードの有限非空集合 (finite non-empty set)  Vと頂点Vの非順序ペア (unordered pairs of vertices)からなるエッジの集合 E \subseteq  V \times V からなるデータ構造です。

\begin{align} G=(V, E) \end{align}

分子構造だと、ノードは原子(C, N, O, F)などでエッジが結合(単結合、二重結合、・・・)などです。データ同士に関係性を持つようなデータセットはグラフ構造で表すことができます。

グラフのデータ構造は

  • 隣接リスト

  • 隣接行列

形式の2通りで表現できます。隣接行列の方がわかりやすいのですが、スパースになるため大きい分子構造には対応が難しいかもしれません。ただ、GNNと生成モデルを組み合わせればグラフ生成をワンショットで行えるので面白いです。隣接リストの方がプログラミングする上では簡単です。どちらが良いかは状況によるので使い分けるの吉です。

サブグラフとは

グラフ

\begin{align} H=(V_H, E_H) \end{align}

を考えます。 V_H \in V, E_H \in E ならば  G のサブグラフといいます。またG Hのスーパーグラフともいえます。頂点のサブセット V' \subseteq Vが与えられれば、他のサブグラフ G' = (V',E') V中の頂点間 G中に存在する全てのグラフから完全に構成されるので、より厳密に全ての頂点は

\begin{align} v_i, v_j \in V', \\ (v_i, v_j) \in E ⇔ (v_j, v_i) \in E \end{align}

です。言い換えれば、2ノードは G' に隣接し、Gに隣接してさえすればよいです。ノードの全ペア間のエッジが存在するならば、(サブ)グラフは完全(もしくはクリーク)と呼ばれます。このクリークが他のクリークに含まれないならばそれを極大クリーク (maximal clique )と呼びます。

Junction Tree(JT

木分解によりグラフ Gをjunction treeに写像できます。この操作によりサイクルフリーな構造(木構造)になります。 JT \mathcal{T}_G = (\mathcal{V, E, X})はノード集合が \mathcal{V} = \{C_1, \ldots, C_n\} でエッジ集合が \mathcal{E}と表せます。 \mathcal{X}はサブグラフ構造のラベルです。各ノード、すなわちクラスタ C_i = (V_i, E_i)となります。正確な定義は次の通りになります。

(1) 全クラスターのユニオンはGと等しい。すなわち

\begin{align} \cup_i V_i = V \text{かつ} \cup_i E_i = E \end{align}

(2) インターセクションがある。全クラスター間で以下となる。

\begin{align} V_i \cap V_j \subseteq V_k \end{align}

ジャンクションツリーにすることで、分子構造を生成する上で問題であった「環構造の発生」が問題なく行えるようになります。たとえばサブグラフのベンゼンが1つのクラスター(ノード)としてあらわすことができるようになります。

簡易的なコード

f:id:udnp:20191014120931p:plain

この化合物を木分解してみます。

https://github.com/masahiro-mochizuki/pd1_inhibitor_dataset

クリークの作成

まずはRDkitの関数によりSSSR (smallest set of smallest rings)によりリングの個数を計算します。とりあえず、上記のpd1_inihibitor_datasetからsmilesを取得しておきます。

import sys
import rdkit
from rdkit import Chem


def get_clique_mol(mol, atoms):
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
    new_mol = copy_edit_mol(new_mol).GetMol()
    new_mol = sanitize(new_mol)
    return new_mol


def get_cliques(mol):
    cliques = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1,a2])
    return cliques

path = "data/pd1_inhibitor_dataset/"
df = pd.read_csv(path+"PD1_inhibitor_dataset.csv")

mol = Chem.MolFromSmiles(df.smiles[0])
n_atoms = mol.GetNumAtoms()
cliques = get_cliques(mol)
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
print(cliques)
for i in range(len(ssr)):
    print('{}番目の環のサイズ:{}'.format(i, len(ssr[i])))

# 隣接リスト
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
    for atom in cliques[i]:
        nei_list[atom].append(i)

print(nei_list)

上記のコードを実行すると、クリークとなる番号が取得できます。

0番目の環のサイズ:45
1番目の環のサイズ:5
2番目の環のサイズ:5
3番目の環のサイズ:6
4番目の環のサイズ:5
5番目の環のサイズ:6
6番目の環のサイズ:5
7番目の環のサイズ:6

[[0, 1], [1, 2], [2, 3], [3, 4], [5, 6], [12, 13], [15, 16], [16, 17], [17, 18], [17, 19], [20, 21], [23, 24], [24, 25], [24, 26], [27, 28], [29, 30], [31, 32], [32, 33], [39, 40], [42, 43], [43, 44], [44, 45], [44, 46], [47, 48], [55, 56], [58, 59], [59, 60], [69, 70], [72, 73], [73, 74], [77, 78], [81, 82], [84, 85], [85, 86], [86, 87], [86, 88], [89, 90], [92, 93], [93, 94], [93, 95], [95, 96], [96, 97], [97, 98], [97, 99], [103, 104], [106, 107], [107, 108], [110, 111], [112, 113], [114, 115], [117, 118], [119, 120], [121, 122], [122, 123], [129, 130], [131, 132]]

スパニングツリー

スパニングツリーは、あるグラフの全ての頂点とそのグラフを構成する辺の一部分のみで構成される木です。クリークが得られれば、ジャンクションツリーにするのは簡単です。スパニングツリーにするためのアルゴリズムを用います。

もともとは最小経路を求めるアルゴリズムです。

有名どころだと

があります。プリム法はダイクストラ法とほぼ同じ方法なので実装は簡単です。ポイントはプライオリティキューを用いることです。最初の経路を選択していき、通った経路をヒープキューにプッシュしていきます。scipyにはクラスカル法が実装されています。クラスカル法は適当なエッジを選択していきますが、1つノードに3つ以上のエッジが接続したとき、最小となる重みを選択していきます。英語版のwikiをみるとわかりやすいです。

プリム法の例

import heapq
 
 
# minimum spanning tree
# graph[vertex] = (weight, end, start)
def prim(graph):
    n = len(graph)
    used = [True] * n
    edgelist = []
    for edge in graph[0]:
        heapq.heappush(edgelist, edge)
 
    used[0] = False
    res_edge = []
    res = 0
    while edgelist:
        minedge = heapq.heappop(edgelist)
        if not used[minedge[1]]:
            continue
 
        v = minedge[1]
        used[v] = False
        for edge in graph[v]:
            if used[edge[1]]:
                heapq.heappush(edgelist, edge)
        res += minedge[0]
        res_edge.append((minedge[2], minedge[1])
 
    return res, res_edge

※ 実際はscipy.sparse.csgraphminimum_spanning_tree()を使えばよいです。

JTを作成する場合の注意点はMinimum Spanning TreeではなくMaximal Spanning Treeの点です。最大の重みを定義しておき、引いておく(逆の操作)で最大のスパニングツリーを取得します。

木分解

from collections import defaultdict
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions


def tree_decomp(mol):
    n_atoms = mol.GetNumAtoms()
    cliques = get_cliques(mol)
    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)

    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)
    
    #Merge Rings with intersection > 2 atoms
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2: 
            continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2: 
                    continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []
    
    cliques = [c for c in cliques if len(c) > 0]
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)
    
    #Build edges and add singleton cliques
    edges = defaultdict(int)
    for atom in range(n_atoms):
        if len(nei_list[atom]) <= 1: 
            continue
        cnei = nei_list[atom]
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): 
      
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1,c2)] = 1
        elif len(rings) > 2: #Multiple (n>2) complex rings
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1,c2)] = MST_MAX_WEIGHT - 1
        else:
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1,c2 = cnei[i],cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1,c2)] < len(inter):
                        edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction

    edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()]
    
    if len(edges) == 0:
        return cliques, edges

    #Compute Maximum Spanning Tree
    row, col, data = zip(*edges)
    n_clique = len(cliques)
    clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) )
    junc_tree = minimum_spanning_tree(clique_graph)
    row,col = junc_tree.nonzero()
    edges = [(row[i],col[i]) for i in range(len(row))]
    return (cliques, edges)

結果の可視化

import networkx as nx
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 6), dpi= 80, facecolor='w', edgecolor='k')

graph = nx.Graph()
graph.add_edges_from(edges)
import matplotlib.pyplot as plt
nx.draw_networkx(graph)
plt.show()

networkxで可視化すると次のような図が得られました。各ノードにはそれぞれサブグラフであるので、分子を生成するためには、木構造を拡大していけばよいです。

f:id:udnp:20191014123920p:plain

Next Action?

木構造に対してニューラルネットワークを行っていく ⇒ JT-VAEのアプローチです。

  • リーフを拡張するのにRNN + MCTS
  • そもそもリーフをでなくて強化学習でノードを拡張していく。

複数のアプローチがまだまだ考えられますね。

JT-VAEを試してみると、分子生成部分にはグラフニューラルネットワークにあまり依存しない感じなのと、復元率が論文程よくないため、まだまだ改善点がありそうですね。最適化の工夫もできそうです。時間があれば改良していきたい。他の論文もウォッチしなければ。DGLも試したい・・・。