ベイジアンネットワークの推論では良く用いられるアルゴリズムです。分子構造もグラフ構造なので本アルゴリズムは適用することができます。グラフを木分解(tree decomposition)することで得られる構造がjunction treeです。有向グラフの場合、無向グラフにする操作(モラル化など)が必要となりますが、分子構造の場合は無向グラフなので、この変換は必要ありません。
参考にしてる文献は下記です。
JT-VAEを使ってみた結果など、こちらも参考になります。
グラフ理論
改めてグラフ理論の定義をおさらいしてみます。
グラフは頂点もしくはノードの有限非空集合 (finite non-empty set) の非順序ペア (unordered pairs of vertices)からなるエッジの集合からなるデータ構造です。
\begin{align} G=(V, E) \end{align}
分子構造だと、ノードは原子(C, N, O, F)などでエッジが結合(単結合、二重結合、・・・)などです。データ同士に関係性を持つようなデータセットはグラフ構造で表すことができます。
グラフのデータ構造は
隣接リスト
隣接行列
形式の2通りで表現できます。隣接行列の方がわかりやすいのですが、スパースになるため大きい分子構造には対応が難しいかもしれません。ただ、GNNと生成モデルを組み合わせればグラフ生成をワンショットで行えるので面白いです。隣接リストの方がプログラミングする上では簡単です。どちらが良いかは状況によるので使い分けるの吉です。
サブグラフとは
グラフ
\begin{align} H=(V_H, E_H) \end{align}
を考えます。 ならば のサブグラフといいます。またはのスーパーグラフともいえます。頂点のサブセットが与えられれば、他のサブグラフは中の頂点間中に存在する全てのグラフから完全に構成されるので、より厳密に全ての頂点は
\begin{align} v_i, v_j \in V', \\ (v_i, v_j) \in E ⇔ (v_j, v_i) \in E \end{align}
です。言い換えれば、2ノードは に隣接し、Gに隣接してさえすればよいです。ノードの全ペア間のエッジが存在するならば、(サブ)グラフは完全(もしくはクリーク)と呼ばれます。このクリークが他のクリークに含まれないならばそれを極大クリーク (maximal clique )と呼びます。
Junction Tree(JT)
木分解によりグラフをjunction treeに写像できます。この操作によりサイクルフリーな構造(木構造)になります。 JTをはノード集合がでエッジ集合がと表せます。はサブグラフ構造のラベルです。各ノード、すなわちクラスターとなります。正確な定義は次の通りになります。
(1) 全クラスターのユニオンはGと等しい。すなわち
\begin{align} \cup_i V_i = V \text{かつ} \cup_i E_i = E \end{align}
(2) インターセクションがある。全クラスター間で以下となる。
\begin{align} V_i \cap V_j \subseteq V_k \end{align}
ジャンクションツリーにすることで、分子構造を生成する上で問題であった「環構造の発生」が問題なく行えるようになります。たとえばサブグラフのベンゼンが1つのクラスター(ノード)としてあらわすことができるようになります。
簡易的なコード
この化合物を木分解してみます。
クリークの作成
まずはRDkitの関数によりSSSR (smallest set of smallest rings)によりリングの個数を計算します。とりあえず、上記のpd1_inihibitor_datasetからsmilesを取得しておきます。
import sys import rdkit from rdkit import Chem def get_clique_mol(mol, atoms): smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True) new_mol = Chem.MolFromSmiles(smiles, sanitize=False) new_mol = copy_edit_mol(new_mol).GetMol() new_mol = sanitize(new_mol) return new_mol def get_cliques(mol): cliques = [] for bond in mol.GetBonds(): a1 = bond.GetBeginAtom().GetIdx() a2 = bond.GetEndAtom().GetIdx() if not bond.IsInRing(): cliques.append([a1,a2]) return cliques path = "data/pd1_inhibitor_dataset/" df = pd.read_csv(path+"PD1_inhibitor_dataset.csv") mol = Chem.MolFromSmiles(df.smiles[0]) n_atoms = mol.GetNumAtoms() cliques = get_cliques(mol) ssr = [list(x) for x in Chem.GetSymmSSSR(mol)] cliques.extend(ssr) print(cliques) for i in range(len(ssr)): print('{}番目の環のサイズ:{}'.format(i, len(ssr[i]))) # 隣接リスト nei_list = [[] for i in range(n_atoms)] for i in range(len(cliques)): for atom in cliques[i]: nei_list[atom].append(i) print(nei_list)
上記のコードを実行すると、クリークとなる番号が取得できます。
0番目の環のサイズ:45 1番目の環のサイズ:5 2番目の環のサイズ:5 3番目の環のサイズ:6 4番目の環のサイズ:5 5番目の環のサイズ:6 6番目の環のサイズ:5 7番目の環のサイズ:6 [[0, 1], [1, 2], [2, 3], [3, 4], [5, 6], [12, 13], [15, 16], [16, 17], [17, 18], [17, 19], [20, 21], [23, 24], [24, 25], [24, 26], [27, 28], [29, 30], [31, 32], [32, 33], [39, 40], [42, 43], [43, 44], [44, 45], [44, 46], [47, 48], [55, 56], [58, 59], [59, 60], [69, 70], [72, 73], [73, 74], [77, 78], [81, 82], [84, 85], [85, 86], [86, 87], [86, 88], [89, 90], [92, 93], [93, 94], [93, 95], [95, 96], [96, 97], [97, 98], [97, 99], [103, 104], [106, 107], [107, 108], [110, 111], [112, 113], [114, 115], [117, 118], [119, 120], [121, 122], [122, 123], [129, 130], [131, 132]]
スパニングツリー
スパニングツリーは、あるグラフの全ての頂点とそのグラフを構成する辺の一部分のみで構成される木です。クリークが得られれば、ジャンクションツリーにするのは簡単です。スパニングツリーにするためのアルゴリズムを用います。
もともとは最小経路を求めるアルゴリズムです。
有名どころだと
があります。プリム法はダイクストラ法とほぼ同じ方法なので実装は簡単です。ポイントはプライオリティキューを用いることです。最初の経路を選択していき、通った経路をヒープキューにプッシュしていきます。scipy
にはクラスカル法が実装されています。クラスカル法は適当なエッジを選択していきますが、1つノードに3つ以上のエッジが接続したとき、最小となる重みを選択していきます。英語版のwikiをみるとわかりやすいです。
プリム法の例
import heapq # minimum spanning tree # graph[vertex] = (weight, end, start) def prim(graph): n = len(graph) used = [True] * n edgelist = [] for edge in graph[0]: heapq.heappush(edgelist, edge) used[0] = False res_edge = [] res = 0 while edgelist: minedge = heapq.heappop(edgelist) if not used[minedge[1]]: continue v = minedge[1] used[v] = False for edge in graph[v]: if used[edge[1]]: heapq.heappush(edgelist, edge) res += minedge[0] res_edge.append((minedge[2], minedge[1]) return res, res_edge
※ 実際はscipy.sparse.csgraph
のminimum_spanning_tree()
を使えばよいです。
JTを作成する場合の注意点はMinimum Spanning TreeではなくMaximal Spanning Treeの点です。最大の重みを定義しておき、引いておく(逆の操作)で最大のスパニングツリーを取得します。
木分解
from collections import defaultdict from scipy.sparse import csr_matrix from scipy.sparse.csgraph import minimum_spanning_tree from collections import defaultdict from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions def tree_decomp(mol): n_atoms = mol.GetNumAtoms() cliques = get_cliques(mol) ssr = [list(x) for x in Chem.GetSymmSSSR(mol)] cliques.extend(ssr) nei_list = [[] for i in range(n_atoms)] for i in range(len(cliques)): for atom in cliques[i]: nei_list[atom].append(i) #Merge Rings with intersection > 2 atoms for i in range(len(cliques)): if len(cliques[i]) <= 2: continue for atom in cliques[i]: for j in nei_list[atom]: if i >= j or len(cliques[j]) <= 2: continue inter = set(cliques[i]) & set(cliques[j]) if len(inter) > 2: cliques[i].extend(cliques[j]) cliques[i] = list(set(cliques[i])) cliques[j] = [] cliques = [c for c in cliques if len(c) > 0] nei_list = [[] for i in range(n_atoms)] for i in range(len(cliques)): for atom in cliques[i]: nei_list[atom].append(i) #Build edges and add singleton cliques edges = defaultdict(int) for atom in range(n_atoms): if len(nei_list[atom]) <= 1: continue cnei = nei_list[atom] bonds = [c for c in cnei if len(cliques[c]) == 2] rings = [c for c in cnei if len(cliques[c]) > 4] if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): cliques.append([atom]) c2 = len(cliques) - 1 for c1 in cnei: edges[(c1,c2)] = 1 elif len(rings) > 2: #Multiple (n>2) complex rings cliques.append([atom]) c2 = len(cliques) - 1 for c1 in cnei: edges[(c1,c2)] = MST_MAX_WEIGHT - 1 else: for i in range(len(cnei)): for j in range(i + 1, len(cnei)): c1,c2 = cnei[i],cnei[j] inter = set(cliques[c1]) & set(cliques[c2]) if edges[(c1,c2)] < len(inter): edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()] if len(edges) == 0: return cliques, edges #Compute Maximum Spanning Tree row, col, data = zip(*edges) n_clique = len(cliques) clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) ) junc_tree = minimum_spanning_tree(clique_graph) row,col = junc_tree.nonzero() edges = [(row[i],col[i]) for i in range(len(row))] return (cliques, edges)
結果の可視化
import networkx as nx import matplotlib.pyplot as plt plt.figure(figsize=(8, 6), dpi= 80, facecolor='w', edgecolor='k') graph = nx.Graph() graph.add_edges_from(edges) import matplotlib.pyplot as plt nx.draw_networkx(graph) plt.show()
networkxで可視化すると次のような図が得られました。各ノードにはそれぞれサブグラフであるので、分子を生成するためには、木構造を拡大していけばよいです。
Next Action?
木構造に対してニューラルネットワークを行っていく ⇒ JT-VAEのアプローチです。
複数のアプローチがまだまだ考えられますね。
JT-VAEを試してみると、分子生成部分にはグラフニューラルネットワークにあまり依存しない感じなのと、復元率が論文程よくないため、まだまだ改善点がありそうですね。最適化の工夫もできそうです。時間があれば改良していきたい。他の論文もウォッチしなければ。DGLも試したい・・・。