MMACE Paper: Graph Neural Network for HIV Inhibition
Show code cell source
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import selfies as sf
import exmol
import skunk
import warnings
from rdkit import Chem
from rdkit.Chem.Draw import rdDepictor
rdDepictor.SetPreferCoordGen(True)
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG = True
# Global plotting look: seaborn notebook context with a dark style and muted
# grey tick/edge colours, plus a fixed five-colour cycle for all line plots.
sns.set_context("notebook")
# NOTE(review): seaborn's set_style appears to keep only its own style keys, so
# "axes.linewidth" and "figure.dpi" below may be silently ignored — confirm and,
# if they matter, set them through mpl.rcParams instead.
dark_style_overrides = {
    "xtick.bottom": True,
    "ytick.left": True,
    "xtick.color": "#666666",
    "ytick.color": "#666666",
    "axes.edgecolor": "#666666",
    "axes.linewidth": 0.8,
    "figure.dpi": 300,
}
sns.set_style("dark", dark_style_overrides)

color_cycle = ["#1BBC9B", "#F06060", "#F3B562", "#6e5687", "#5C4B51"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

# Seed NumPy's global RNG so random draws later in the notebook are repeatable.
np.random.seed(0)
2023-12-04 18:06:41.531313: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-12-04 18:06:41.568496: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-12-04 18:06:41.569696: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-04 18:06:42.305321: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[1], line 10
8 import tensorflow as tf
9 import selfies as sf
---> 10 import exmol
11 import skunk
12 import warnings
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/exmol/__init__.py:3
1 from .version import __version__
2 from . import stoned
----> 3 from .exmol import *
4 from .data import *
5 from .stoned import sanitize_smiles
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/exmol/exmol.py:30
28 from rdkit.Chem import rdchem # type: ignore
29 from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity, TanimotoSimilarity # type: ignore
---> 30 import langchain.llms as llms
31 import langchain.prompts as prompts
33 from . import stoned
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain/llms/__init__.py:22
1 """
2 **LLM** classes provide
3 access to the large language model (**LLM**) APIs and services.
(...)
18 AIMessage, BaseMessage
19 """ # noqa: E501
20 from typing import Any, Callable, Dict, Type
---> 22 from langchain.llms.base import BaseLLM
25 def _import_ai21() -> Any:
26 from langchain.llms.ai21 import AI21
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain/llms/base.py:2
1 # Backwards compatibility.
----> 2 from langchain_core.language_models import BaseLanguageModel
3 from langchain_core.language_models.llms import (
4 LLM,
5 BaseLLM,
(...)
9 update_cache,
10 )
12 __all__ = [
13 "create_base_retry_decorator",
14 "get_prompts",
(...)
19 "LLM",
20 ]
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/language_models/__init__.py:7
1 from langchain_core.language_models.base import (
2 BaseLanguageModel,
3 LanguageModelInput,
4 LanguageModelOutput,
5 get_tokenizer,
6 )
----> 7 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
8 from langchain_core.language_models.llms import LLM, BaseLLM
10 __all__ = [
11 "BaseLanguageModel",
12 "BaseChatModel",
(...)
18 "LanguageModelOutput",
19 ]
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/language_models/chat_models.py:20
7 from functools import partial
8 from typing import (
9 TYPE_CHECKING,
10 Any,
(...)
17 cast,
18 )
---> 20 from langchain_core.callbacks import (
21 AsyncCallbackManager,
22 AsyncCallbackManagerForLLMRun,
23 BaseCallbackManager,
24 CallbackManager,
25 CallbackManagerForLLMRun,
26 Callbacks,
27 )
28 from langchain_core.globals import get_llm_cache
29 from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/callbacks/__init__.py:13
1 from langchain_core.callbacks.base import (
2 AsyncCallbackHandler,
3 BaseCallbackHandler,
(...)
11 ToolManagerMixin,
12 )
---> 13 from langchain_core.callbacks.manager import (
14 AsyncCallbackManager,
15 AsyncCallbackManagerForChainGroup,
16 AsyncCallbackManagerForChainRun,
17 AsyncCallbackManagerForLLMRun,
18 AsyncCallbackManagerForRetrieverRun,
19 AsyncCallbackManagerForToolRun,
20 AsyncParentRunManager,
21 AsyncRunManager,
22 BaseRunManager,
23 CallbackManager,
24 CallbackManagerForChainGroup,
25 CallbackManagerForChainRun,
26 CallbackManagerForLLMRun,
27 CallbackManagerForRetrieverRun,
28 CallbackManagerForToolRun,
29 ParentRunManager,
30 RunManager,
31 )
32 from langchain_core.callbacks.stdout import StdOutCallbackHandler
33 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/callbacks/manager.py:26
9 from typing import (
10 TYPE_CHECKING,
11 Any,
(...)
22 cast,
23 )
24 from uuid import UUID
---> 26 from langsmith.run_helpers import get_run_tree_context
27 from tenacity import RetryCallState
29 from langchain_core.callbacks.base import (
30 BaseCallbackHandler,
31 BaseCallbackManager,
(...)
37 ToolManagerMixin,
38 )
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/__init__.py:10
6 except metadata.PackageNotFoundError:
7 # Case where package metadata is not available.
8 __version__ = ""
---> 10 from langsmith.client import Client
11 from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
12 from langsmith.run_helpers import trace, traceable
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/client.py:43
41 from langsmith import schemas as ls_schemas
42 from langsmith import utils as ls_utils
---> 43 from langsmith.evaluation import evaluator as ls_evaluator
45 if TYPE_CHECKING:
46 import pandas as pd
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/evaluation/__init__.py:4
1 """Evaluation Helpers."""
3 from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
----> 4 from langsmith.evaluation.string_evaluator import StringEvaluator
6 __all__ = ["EvaluationResult", "RunEvaluator", "StringEvaluator"]
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/evaluation/string_evaluator.py:3
1 from typing import Callable, Dict, Optional
----> 3 from pydantic import BaseModel
5 from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
6 from langsmith.schemas import Example, Run
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/__init__.py:372, in __getattr__(attr_name)
370 return import_module(f'.{attr_name}', package=package)
371 else:
--> 372 module = import_module(module_name, package=package)
373 return getattr(module, attr_name)
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/importlib/__init__.py:127, in import_module(name, package)
125 break
126 level += 1
--> 127 return _bootstrap._gcd_import(name[level:], package, level)
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/main.py:15
12 import typing_extensions
13 from pydantic_core import PydanticUndefined
---> 15 from ._internal import (
16 _config,
17 _decorators,
18 _fields,
19 _forward_ref,
20 _generics,
21 _mock_val_ser,
22 _model_construction,
23 _repr,
24 _typing_extra,
25 _utils,
26 )
27 from ._migration import getattr_migration
28 from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/_internal/_decorators.py:15
12 from typing_extensions import Literal, TypeAlias, is_typeddict
14 from ..errors import PydanticUserError
---> 15 from ._core_utils import get_type_ref
16 from ._internal_dataclass import slots_true
17 from ._typing_extra import get_function_type_hints
File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/_internal/_core_utils.py:16
14 from pydantic_core import CoreSchema, core_schema
15 from pydantic_core import validate_core_schema as _validate_core_schema
---> 16 from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
18 from . import _repr
19 from ._typing_extra import is_generic_alias
ImportError: cannot import name 'TypeAliasType' from 'typing_extensions' (/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/typing_extensions.py)
# Load the HIV dataset and draw a random subset of rows (in random order).
# REDUCED data for CI: only 10% of the rows are kept; the index is reset to 0..n-1.
hivdata = pd.read_csv("HIV.csv").sample(frac=0.1).reset_index(drop=True)
hivdata.head()
def gen_smiles2graph(sml: str, max_atoms: int = 440, n_features: int = 100):
    """Convert a SMILES string into zero-padded (node features, adjacency) arrays.

    The SMILES is sanitized with ``exmol.stoned.sanitize_smiles`` and explicit
    hydrogens are added with ``Chem.AddHs`` before the graph is built, so
    hydrogens count toward ``max_atoms``.

    Args:
        sml: SMILES string to convert.
        max_atoms: number of node rows (and adjacency rows/columns) every graph
            is padded to. Defaults to 440, the size used throughout this notebook.
        n_features: width of each node's one-hot encoding; atom ``i`` gets a 1
            in column ``atomic_number``. Defaults to 100, matching the model's
            node-input width.

    Returns:
        ``(nodes, adj)`` as float NumPy arrays. ``nodes`` has shape
        ``(max_atoms, n_features)``; ``adj`` has shape ``(max_atoms, max_atoms)``
        with 1 for each bond (both directions) plus 1 on the whole diagonal,
        padding rows included, so every row sum is positive for the
        ``1 / degree`` normalization in ``GCNLayer``. Bond order is not encoded.

    Raises:
        ValueError: if ``sml`` cannot be turned into a molecule, if the
            hydrogen-added molecule has more than ``max_atoms`` atoms, or if an
            atomic number does not fit in ``n_features`` columns.
    """
    # sanitize_smiles is expected to return (mol, canonical_smiles, ok); guard
    # against a missing result or a None molecule rather than letting
    # Chem.AddHs fail with an opaque error.
    sanitized = exmol.stoned.sanitize_smiles(sml)
    mol = sanitized[0] if sanitized else None
    if mol is None:
        raise ValueError(f"Could not build a molecule from SMILES {sml!r}")
    mol = Chem.AddHs(mol)

    n_atoms = mol.GetNumAtoms()
    if n_atoms > max_atoms:
        raise ValueError(
            f"Molecule has {n_atoms} atoms (with hydrogens); max_atoms is {max_atoms}"
        )

    # Bond types with a well-defined order; any other type only triggers a warning.
    known_bond_types = {
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    }

    nodes = np.zeros((max_atoms, n_features))
    for atom in mol.GetAtoms():
        atomic_num = atom.GetAtomicNum()
        if atomic_num >= n_features:
            raise ValueError(
                f"Atomic number {atomic_num} does not fit in {n_features} feature columns"
            )
        nodes[atom.GetIdx(), atomic_num] = 1

    adj = np.zeros((max_atoms, max_atoms))
    for bond in mol.GetBonds():
        bond_type = bond.GetBondType()
        if bond_type not in known_bond_types:
            # The adjacency is binary (order is never stored), so an unusual
            # bond type still connects its two atoms; just report it.
            warnings.warn(f"Ignoring bond order {bond_type}")
        u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        adj[u, v] = 1
        adj[v, u] = 1

    # Self-loops on every row (padding included) keep all degrees non-zero.
    adj += np.eye(max_atoms)
    return nodes, adj
class GCNLayer(tf.keras.layers.Layer):
    """Graph-convolution layer computing ``activation(D^-1 A X W)``.

    Called on a pair ``(nodes, adj)``; returns the updated node features together
    with the unchanged adjacency, so several layers can be chained directly.
    """

    def __init__(self, activation=None, **kwargs):
        super().__init__(**kwargs)
        # Accepts a name such as "relu", a callable, or None (identity).
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        node_shape, _adj_shape = input_shape
        n_feat = node_shape[2]
        # Square weight matrix: the node feature width is preserved.
        self.w = self.add_weight(shape=(n_feat, n_feat), name="w")

    def call(self, inputs):
        nodes, adj = inputs
        # Inverse row sums of the adjacency give the row normalization D^-1.
        inv_degree = 1 / tf.reduce_sum(adj, axis=-1)
        aggregated = tf.einsum("bi,bij,bjk,kl->bil", inv_degree, adj, nodes, self.w)
        return self.activation(aggregated), adj
class GRLayer(tf.keras.layers.Layer):
    """Readout layer: average the node features along the node axis (axis 1).

    Called on ``(nodes, adj)``; the adjacency is accepted for interface
    compatibility with ``GCNLayer`` outputs but is not used. Every row along
    axis 1 enters the mean, so zero-padded rows (if any) are averaged in too.
    """

    def __init__(self, name="GRLayer", **kwargs):
        super().__init__(name=name, **kwargs)

    def call(self, inputs):
        nodes, _adj = inputs
        return tf.reduce_mean(nodes, axis=1)
# Model inputs: node features (batch, n_nodes, 100) and adjacency (batch, n_nodes, n_nodes).
ninput = tf.keras.Input((None, 100))
ainput = tf.keras.Input((None, None))

# GCN block: four stacked ReLU graph convolutions; each passes (nodes, adj) on.
x = [ninput, ainput]
for gcn_depth in range(4):
    x = GCNLayer("relu")(x)

# Reduce to one feature vector per graph, then the classification head.
x = GRLayer()(x)
# NOTE(review): this Dense has no activation, so together with the next Dense it
# composes to a single linear map of the pooled features — confirm intended.
x = tf.keras.layers.Dense(256)(x)
x = tf.keras.layers.Dense(1, activation="sigmoid")(x)

gcnmodel = tf.keras.Model(inputs=(ninput, ainput), outputs=x)
# The head ends in a sigmoid, so the loss receives probabilities, not logits.
gcnmodel.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
gcnmodel.summary()
def gen_data():
    """Yield ``((nodes, adj), HIV_active)`` for every row of ``hivdata``, in row order."""
    smiles_col = hivdata.smiles.to_numpy()
    labels_col = hivdata.HIV_active.to_numpy()
    for sml, activity in zip(smiles_col, labels_col):
        yield gen_smiles2graph(sml), activity
# Each element is ((nodes, adjacency), label): nodes are (n_nodes, 100), the
# adjacency is (n_nodes, n_nodes), and the label is a scalar, all float32.
element_types = ((tf.float32, tf.float32), tf.float32)
element_shapes = (
    (tf.TensorShape([None, 100]), tf.TensorShape([None, None])),
    tf.TensorShape([]),
)
# NOTE(review): output_types/output_shapes are deprecated in recent TensorFlow
# releases in favour of output_signature; consider migrating.
data = tf.data.Dataset.from_generator(
    gen_data, output_types=element_types, output_shapes=element_shapes
)

# First 10% -> test, next 10% -> validation, the rest -> training (shuffled).
N = len(hivdata)
split = int(0.1 * N)
test_data = data.take(split)
nontest = data.skip(split)
val_data = nontest.take(split)
train_data = nontest.skip(split).shuffle(1000)

# Up-weight the positive (active) class to counter the class imbalance.
class_weight = {0: 1.0, 1: 30.0}
result = gcnmodel.fit(
    train_data.batch(128),
    validation_data=val_data.batch(128),
    epochs=30,
    verbose=0,
    class_weight=class_weight,
)
# Side-by-side training/validation curves: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
curve_panels = (
    (ax1, "loss", "val_loss", "Loss"),
    (ax2, "accuracy", "val_accuracy", "Accuracy"),
)
for panel, train_key, val_key, y_title in curve_panels:
    panel.plot(result.history[train_key], label="training")
    panel.plot(result.history[val_key], label="validation")
    panel.legend()
    panel.set_xlabel("Epoch")
    panel.set_ylabel(y_title)
fig.tight_layout()
fig.savefig("gnn-loss-acc.png", dpi=180)
fig.show()
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

# Score each held-out test graph individually (adding a batch axis of size 1)
# and collect predicted probabilities alongside the true labels.
prediction = []
test_y = []
for graph, label in test_data.as_numpy_iterator():
    node_feats, adj_mat = graph
    yhat = gcnmodel((node_feats[np.newaxis, ...], adj_mat[np.newaxis, ...]))
    prediction.append(yhat.numpy())
    test_y.append(label)
prediction = np.array(prediction).flatten()
test_y = np.array(test_y)

fpr_keras, tpr_keras, thresholds_keras = roc_curve(test_y, prediction)
auc_keras = auc(fpr_keras, tpr_keras)

plt.figure(figsize=(6, 4), dpi=100)
plt.plot(fpr_keras, tpr_keras, label="AUC = {:.3f}".format(auc_keras))
plt.plot([0, 1], [0, 1], linestyle="--")  # chance-level diagonal
# FPR is plotted on the x-axis and TPR on the y-axis, so label them that way.
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.savefig("gnn-roc.png", dpi=300)
plt.show()
Counterfactual (CF) explanation
The following example finds counterfactuals (CFs) for a given molecule whose HIV activity is zero.
def predictor_function(smiles, selfies):
    """Return a boolean array that is True where the GCN predicts probability > 0.5.

    ``selfies`` is accepted to match the predictor signature handed to
    ``exmol.sample_space`` but is not used. The target activity is 0, so True
    marks molecules whose predicted class differs from that target.
    """
    probs = np.array(
        [
            gcnmodel((nodes[np.newaxis, ...], adj[np.newaxis, ...])).numpy()
            for nodes, adj in map(gen_smiles2graph, smiles)
        ]
    ).flatten()
    # |1{p > 0.5} - 0| cast to bool is simply the comparison p > 0.5.
    return probs > 0.5
# SELFIES alphabet and STONED mutation settings used to sample chemical space.
basic = exmol.get_basic_alphabet()
stoned_kwargs = {"num_samples": 1500, "alphabet": basic, "max_mutations": 2}
# The molecule to explain.
example_base = "C=CCN(CC=C)C(=O)Nc1ccc(C(=O)NN=Cc2cccc(OC)c2OC)cc1"
space = exmol.sample_space(
    example_base,
    predictor_function,
    # Use the settings defined above instead of a duplicated literal; pass a
    # copy so sample_space cannot alter the shared dict if it modifies its input.
    stoned_kwargs=dict(stoned_kwargs),
    quiet=True,
)
exps = exmol.cf_explain(space)
# Draw the default counterfactual explanations in a single row and save PNG + SVG.
mpl.rc("axes", titlesize=12)
fkw = {"figsize": (8, 6)}
exmol.plot_cf(exps, figure_kwargs=fkw, mol_size=(450, 400), nrows=1)
plt.savefig("gnn-simple.png", dpi=180)
svg = exmol.insert_svg(exps, mol_fontsize=16)
with open("gnn-simple.svg", "w") as f:
    print(svg, end="", file=f)
# NOTE(review): `font` is not referenced in this cell; kept in case later code reads it.
font = {"family": "normal", "weight": "normal", "size": 22}
exmol.plot_space(
    space,
    exps,
    figure_kwargs=fkw,
    mol_size=(300, 200),
    offset=0,
    cartoon=True,
    rasterized=True,
)
# Empty scatters act as legend proxies for the two point colours.
viridis = plt.get_cmap("viridis")
for legend_label, shade in (("Counterfactual", 1.0), ("Same Class", 0.0)):
    plt.scatter([], [], label=legend_label, s=150, color=viridis(shade))
plt.legend(fontsize=22)
plt.tight_layout()
svg = exmol.insert_svg(exps, mol_fontsize=16)
with open("gnn-space.svg", "w") as f:
    print(svg, end="", file=f)
# Re-explain with nmols=19 and lay the molecules out on a 4 x 5 grid.
exps = exmol.cf_explain(space, nmols=19)
mpl.rc("axes", titlesize=10)
fkw = {"figsize": (12, 10)}
exmol.plot_cf(
    exps,
    figure_kwargs=fkw,
    mol_size=(450, 400),
    mol_fontsize=26,
    nrows=4,
    ncols=5,
)
plt.savefig("gnn-simple-20.png", bbox_inches="tight", dpi=300)
svg = exmol.insert_svg(exps, mol_fontsize=14)
with open("gnn-simple-20.svg", "w") as f:
    print(svg, end="", file=f)
fkw = {"figsize": (8, 6)}
font = {"family": "normal", "weight": "normal", "size": 22}
exmol.plot_space(space, exps, figure_kwargs=fkw, mol_size=(350, 300), mol_fontsize=22)
# Legend proxies. predictor_function returns True (1.0) when the predicted class
# differs from the inactive target, i.e. for counterfactuals, so — assuming
# plot_space colours points by that output — viridis(1.0) is "Counterfactual"
# and viridis(0.0) is "Same Label", matching the earlier space-plot legend.
plt.scatter([], [], label="Same Label", s=150, color=plt.get_cmap("viridis")(0.0))
plt.scatter([], [], label="Counterfactual", s=150, color=plt.get_cmap("viridis")(1.0))
plt.legend(fontsize=22)
plt.savefig("gnn-space.png", bbox_inches="tight", dpi=180)