グラフ機械学習と強化学習について

主にグラフ機械学習や強化学習手法を記載します。

ECFPとNeural-fingerprintの比較

はじめに

ケモインフォマティクスやマテリアルズインフォマティクスでは良く使用されるExtended Connectivity Circular Fingerprints (ECFP)及び Neural Graph Fingerprint (NFP) を実装してみました。ZincデータセットのlogP-SAを予測することで、両手法を比較しています。ECFPとNFPは下論文で提案された方法です。

ECFP自体は類似する方法がRDKitやCDKに実装されおり、正確かつ高速に使うことができます。NFPは実装上の都合上、CPUオンリーでバッチに対応していません。私の実装したものは、あくまで理解のためですので、アルゴリズムは多少間違っているかもしれません。ECFPについてはQiitaにも載せています。

フィンガープリント

化学の分野で用いられるフィンガープリントは、分子構造を表現する方法の1つで2次元構造に基づいています。主に、分子の「類似性検索」や定量的構造活性・物性相関(QSAR, Quantitative Structure-Activity Relationship or QSPR, Quantitative Structure-Property Relationship)モデル構築に使用されます。分子構造はSMILESと呼ばれる文字列情報で表現されます。また、分子はグラフ構造でも表せ、原子情報をノード、結合情報をエッジとして表現することもできます。原子=ノード、結合=エッジとして考えて構いません。私がよく用いるフィンガープリントはECFPですが、グラフニューラルネットワークによるフィンガープリントは多数ありますが、本記事では、ECFPとNFPについて記載します。

f:id:udnp:20200408183011p:plain

図1はECFPとNFPの概念図です。左図はECFPの概略図で分子グラフのメッセージパッシングによりノード(原子の)情報が伝達され、最終的に固定長の0, 1のビットに変換されます。右図はNFPで情報伝達はニューラルネットワークで重み付けされ情報が伝達されます。

アルゴリズム

実際にアルゴリズムを見るのがわかりやすいと思います。

f:id:udnp:20200408182955p:plain

Circular fingerprints (ECFP)

上図のアルゴリズムを詳しく見ていきます。まず、分子の半径Rとfingerprintの長さSを決めます。この半径は隣接ノードへ伝達する回数で、この操作はメッセージパッシングと呼ばれます。半径2なら2つ隣りのノードまで、R=3なら3分子離れたノード情報を伝達していく操作になります。fingerprintの配列fを0で初期化しておきます。固定長Sは1024や2048 bitsとなることが多いです。分子中の原子ごとにAtom Propertiesであるノード特徴量をリストに追加していきます。例えば、

  • 原子番号
  • 隣接重原子(水素を除いた原子)の数
  • 原子価
  • 環構造であるかどうか

などです。他にも追加していくことができます。隣接ノード情報を注目しているノードへ集め、concatenationしていきます。これをハッシュに通し、fingerprintの長さSで余りを求めます。この余りのインデックス部分に1を書き込みます。これを全原子に対して行い(for each atom in molecule)、何度も繰り返す(for 1 to R)ことで情報をメッセージパッシングさせることができます。この操作はMorganアルゴリズムとも呼ばれています。グラフ同型問題で非常に高い識別精度をもつWeisfeiler-Lehman graph isomorphism testと操作は非常に似ています。ECFPやNFPでは、この特徴量をどう決めるかによって結果が大きく変わるので、正しく決めることが大事です。

Neural graph fingerprints (NFP)

NFPとECFPとの違いは、concatenate以降のところがニューラルネットワークに決定される所になります。ECFPでは単純に結合し、ハッシュを通して、その余りのインデックスを1としている決定的な方法でした。NFPでは学習によってノード情報をランク付けし、ソフトマックスで確率的にインデックスを決定し、フィンガープリントの値としています。そのため、0, 1のベクトルではなく連続値となります。論文名のConvolutional Networks on Graphともあるように、 これはグラフ畳み込みニューラルネットワーク(GCN)の考え方です。現在はより一般的なフレームワーク(DGL, DeepChem, GraphNetsやPyTorch geometricなど)があるので、そちらで実装した方が良いと思います。NFPはGCNとほとんど同じアルゴリズムですが、NFPでは最終層でsoftmaxを取り、sum poolingを行っている点が異なっています。

ECFPの実装

分子を扱う場合はやはり、RDKitがとても便利です。RDKitでノード特徴量を作っていきます。

RDKit版 ECFP

from __future__ import print_function
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import numpy as np

smiles = 'c1ccccn1'

mol = Chem.MolFromSmiles(smiles)
bit_morgan1 = {}
fp1 = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=12, bitInfo=bit_morgan1)
bit1 = list(fp1)
print(bit1)
>> [0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0]

RDKitを用いれば、このように簡単にSMILESからECFPを行うことができます。このフィンガープリントを説明変数として目的変数を回帰したものが、QSAR / QSPRと呼ばれるものになります。

参考までに分子を描画します。

def mols2grid_image(mols, molsPerRow=1):
    mols = [e if e is not None else Chem.RWMol() for e in mols]
    for mol in mols:
        AllChem.Compute2DCoords(mol)

    return Draw.MolsToGridImage(mols, molsPerRow=molsPerRow, subImgSize=(150,150))

mols2grid_image([mol], molsPerRow=2)

f:id:udnp:20200408183032p:plain

自作版 ECFP

自作するためには、SMILESから原子の情報とエッジの情報を表す特徴を抽出しなければなりません。複数分子に対応できるようにmolオブジェクトをリストで渡すようにしていますが、その後の処理は行っていません。

  • 隣接行列
def GetAdjacencyMatrix(smols: list, connected=True):
    max_length = max(mol.GetNumAtoms() for mol in mols)
    bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType() for mol in mols for bond in mol.GetBonds())))
    bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}

    A = np.zeros(shape=(max_length, max_length), dtype=np.int32)
    begin, end = [b.GetBeginAtomIdx() for b in mol.GetBonds()], [b.GetEndAtomIdx() for b in mol.GetBonds()]
    bond_type = [bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
    A[begin, end] = bond_type
    A[end, begin] = bond_type
    return A 

GetAdjacencyMatrixで隣接行列を取得します。smilesからChem.MolFromSmilesによりmolオブジェクトにした後、結合種類によってラベル付けをし、分子のインデックス、結合情報が得られるので書き込んでいきます。なお、rdkit.Chem.rdmolops.GetAdjacencyMatrix()だと上記の操作が行えます。

m = Chem.MolFromSmiles('c1ccccc1')
Chem.rdmolops.GetAdjacencyMatrix(m)
  • 特徴量抽出

RDKitの便利な関数により、分子中の各原子ごとの情報を行列にしていきます。RDKitのECFPがどのような情報を用いているかわからなかったので、適当に選択しています。Module Hierarchyが参考になります。

def createNodeFeatures(mol) -> np.ndarray:
    """ノード特徴量を生成する

    :param mol: rdkit mol object
    :return: np.array
    """
    features = np.array([[
        *[a.GetDegree() == i for i in range(5)],
        *[a.GetExplicitValence() == i for i in range(9)],
        *[int(a.GetHybridization()) == i for i in range(1, 7)],
        *[a.GetImplicitValence() == i for i in range(9)],
        a.GetIsAromatic(),
        a.GetNoImplicit(),
        *[a.GetNumExplicitHs() == i for i in range(5)],
        *[a.GetNumImplicitHs() == i for i in range(5)],
        *[a.GetNumRadicalElectrons() == i for i in range(5)],
        a.IsInRing(),
        a.GetAtomicNum(),
        a.GetDegree(),
        a.GetExplicitValence(),
        a.GetImplicitValence(),
        a.GetFormalCharge(),
        a.GetTotalNumHs()] for a in mol.GetAtoms()], dtype=np.int32)

    return features

特徴抽出の結果は、次のようになります。rowがノード数、columnが特徴数です。

>> print(createNodeFeatures(mol))
[[0. 1. 0. ... 0. 3. 1.]
 [0. 0. 1. ... 0. 2. 2.]
 [0. 0. 0. ... 0. 0. 3.]
 ...
 [0. 0. 1. ... 0. 1. 2.]
 [0. 0. 0. ... 0. 0. 3.]
 [0. 1. 0. ... 0. 0. 1.]]
  • MyECFPクラス 上記の隣接行列と原子の特徴行列からECFPを実装します。
class ECFP:
    """Extended Connectivity Fingerpritnsを計算する

    :param mol: rdkit mol object
    :param radius: 隣接する原子情報をどこまで見るか. メッセージパッシングの回数に相当する
    :param nbits: 得られた部分構造の格納数するためのビット数
    :param n_feat: 特徴量行列 (ノード数 × 特徴量数 )
    """

    def __init__(self, mol, radius: int = 2, nbits: int = 2048, n_feat: np.array = None):
        self.mol = mol
        self.radius = radius
        self.nbits = nbits
        self.fps = np.zeros(shape=(self.nbits,), dtype=np.int32)

        if n_feat is None:
            n_feat = createNodeFeatures(mol)

        n_feat = np.array(n_feat, dtype=np.int32)
        n_atoms = n_feat.shape[0]
        self.adj = Chem.GetAdjacencyMatrix(mol)
        deg = np.array(np.sum(self.adj, axis=1), dtype=str)

        # concatenate node features.
        n_feat = np.array([''.join([str(f) for f in n_feat[atom]]) for atom in range(n_atoms)], dtype=str)
        self.n_feat = np.array([n_feat[i] + deg[i] for i in range(n_atoms)])

    def _concat_neighbor(self, atom: int, n_feat: list) -> list:
        """隣接情報を付与する。メッセージパッシングとaggregationに相当する。concatenationで次元を減らす。

        :param atom: 注目している原子のindex
        :param n_feat: 更新前の特徴量情報
        :return: ノード特徴量を返す
        """
        nei_id = np.nonzero(self.adj[atom])[0]
        new = [str(nid) + str(feat) + str(self.adj[atom][nei_id][nid]) for nid, feat in enumerate(n_feat[nei_id])]
        vec = ''.join([str(ind) for ind in new])
        return vec

    def calculate(self, nei_info: bool = False) -> np.ndarray:
        """ECFPを計算する

        :param nei_info: メッセージパッシング前の情報も加えるかどうか。無いほうが精度上がる。
        :return: result of fingerprint.
        """
        n_atoms = self.n_feat.shape[0]
        identifier = copy.deepcopy(self.n_feat)
        for _ in range(0, self.radius):
            for atom in range(n_atoms):
                if _ == 0:
                    v = self.n_feat[atom]
                else:
                    v = self.n_feat[atom] if not nei_info else self._concat_neighbor(atom, identifier)
                identifier[atom] = hash(v)
                index = int(identifier[atom]) % self.nbits
                self.fps[index] = 1

        self.identifier = identifier
        return self.fps

concat_neighbor()では、隣接する原子の情報を集めます。隣接のidと情報、結合次数をconcatenateしています。calculate()でECFPのアルゴリズム通り実装していきます。各原子ごとに隣接する特徴量を取得し、hash()を通し、これをその原子の識別子とします。これを固定長で割った余りが部分構造のindexになります。これを全ての原子に対して行います。メッセージパッシング操作を繰り返すほど、遠くの原子の情報を伝えることに相当します。

NFP実装

NFPに関してはDGL+PyTorchを用いて実装していきます。このライブラリは化学構造に特化していおり、JTNNやMPNNの学習モデルもあり、非常に使いやすいです。DGLはグラフニューラルネットワークを統一的なフレームワークで実装できるようになっています。

前処理

SMILESをグラフデータにする必要があります。簡単なのはmol_tobigraph()を用いることです。ノードやエッジの特徴量に関してはCanonicalAtomFeaturizer()`CanonicalBondFeaturizer()やを用いると自動的に特徴量を作ることができます。

mols_graph = [mol_to_bigraph(
    mol, 
    node_featurizer=CanonicalAtomFeaturizer(), 
    edge_featurizer=CanonicalBondFeaturizer()) for mol in mols]

ECFPと比較するために、上記createNodeFeatures()での特徴量を初期値とします。

モデル

import dgl.function as fn

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')


class NFP(nn.Module):
    def __init__(self, input_dim, hidden_dim=64, depth=2, nbits=16):
        super(NFP, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, nbits)
        self.softmax = nn.Softmax(dim=1)
        
        self.depth = depth
        self.nbits = nbits
        
        self.linear3 = nn.Linear(nbits, hidden_dim)        
        self.linear4 = nn.Linear(hidden_dim, 1)      
         
    def forward(self, g, n_feat):
        with g.local_scope():
            fps = torch.zeros([1, self.nbits])
            for _ in range(self.depth):
                g.ndata['h'] = n_feat
                g.update_all(gcn_message, gcn_reduce)
                h = g.ndata['h']
                
                r = F.relu(self.linear1(h))
                i = self.softmax(self.linear2(r))
                fps += torch.sum(i, dim=0)

            out = F.relu(self.linear3(fps))
            out = self.linear4(out).squeeze(0)
        return fps, out

モデルは非常にシンプルで、GCNを行った後に線形関数に通し、ソフトマックを通します。この状態だと(バッチ数×ノード数×特徴量)という次元なのでノードに対してsum()を取ってやることで次元を落としてやります。得られたフィンガープリントは、そのままMLPを用いて回帰に用いることができます。学習は、通常のニューラルネットの学習方法と同じです。なお、この実装ではミニバッチに対応していません。DGLではミニバッチを作るときは、グラフ同士の隣接行列を対角行列にし、大きなグラフとすることで対応しています。ミニバッチに対応させるのは少々面倒なので、実装はしていません。また、softmax後にsumを取る必要があるためcudaに対応できていません。

学習に関しては、通常のPyTorchの実装と同じですが、ここではearly stoppingは実装していません。

loss_func = nn.MSELoss(reduction="none")
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
epoch_losses = []
for epoch in range(10):
    epoch_loss = 0
    for ite, batch in enumerate(train_loader):
        _, bg, label, masks = batch
        n_feat = bg.ndata['h']
        fps, prediction = model(bg, n_feat)
        loss = (loss_func(prediction, label) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (ite + 1)
    print(f'Epoch {epoch}, loss {epoch_loss:.4f}'
    epoch_losses.append(epoch_loss)

実験条件

zincデータ22万のうち教師データ1800, テストデータ200としています。logP(溶解性の指標)に対して

  1. NFP + MLP
  2. NFP + Random Forest
  3. 特徴量初期値のノードに対する和 + Random Forest
  4. ECFP + Random Forest
  5. RDkit Morgan Fingerprint + Random Forest

で比較を行っています。NFPの条件は、radius=2, nbits = 16です。MLPは2層の隠れ層のノード数は64で、エポック数は10です。Random Forestはscikit-learnのデフォルトです。ECFP, RDKit Morgan fingerprint はともにradius=2, nbits=2048です。

結果

テストデータに対する予測精度のスコアを載せます。

model r2 RMSE
1 0.8945 0.3577
2 0.8902 0.3725
3 0.7827 0.7371
4 0.7505 0.8463
5 0.7346 0.9002

NFPは学習させている分、予測精度が上がっています。分子グラフの特徴量をうまく埋め込むことができているようです。途中で取り出した16次元のベクトルをRandom Forestで回帰しても精度は同じくらい出ていますが、データ数を増やしてみるとMLPの方が精度が高めに出ることの方が多いです。ECFPの精度は自作の方が良かったのですが、データセットに依存します。また、radiusを増加させていくと精度が落ちていきます。

実際には、エッジ情報を用いた方が精度がもう少し向上します。MPNNやSchnetはDGLでAlchemyの学習済みモデルもあるので、転移学習させても面白いかもしれません。

最後に

GNNを用いることは多いのですが、実際のタスクに適用すると大抵の場合は求めたい精度に達しません。これはグラフのノード・エッジ特徴量の初期値に依存し、そもそもこれらの特徴量がある程度、目的変数に対して相関がないと、回帰・識別タスクに適用しても悲惨なことになることが多いです。やはり楽な方法はなく、特徴量エンジニアリング、第一原理計算やMD計算など原理原則に基づいたパラメータの探索と色々試行錯誤が必要となってきます。低分子に対するGNNの研究はもうやりつくされた感があります。ポリマー、タンパク質など巨大な分子はDFT計算でも計算が難しいので、GNNをどう適用するかが大事になってきそうです。とりあえずGNN+NASや分散学習を試していきたい。

ソースコードはこちら。

github.com

参考文献

  • ChemAxon ECFP
  • RDKit Fingerprints Document
  • 化合物でもDeep Learningがしたい