LIME paper: Recurrent Neural Network for Solubility Prediction
Import packages and set up RNN
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyBboxPatch
from matplotlib.offsetbox import AnnotationBbox
import seaborn as sns
import skunk
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import selfies as sf
import exmol
from dataclasses import dataclass
from rdkit.Chem.Draw import rdDepictor, MolsToGridImage
from rdkit.Chem import MolFromSmiles, MACCSkeys
import random
rdDepictor.SetPreferCoordGen(True)
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import urllib.request
urllib.request.urlretrieve(
"https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf",
"IBMPlexMono-Regular.ttf",
)
fe = font_manager.FontEntry(fname="IBMPlexMono-Regular.ttf", name="plexmono")
font_manager.fontManager.ttflist.append(fe)
plt.rcParams.update(
{
"axes.facecolor": "#f5f4e9",
"grid.color": "#AAAAAA",
"axes.edgecolor": "#333333",
"figure.facecolor": "#FFFFFF",
"axes.grid": False,
"axes.prop_cycle": plt.cycler("color", plt.cm.Dark2.colors),
"font.family": fe.name,
"figure.figsize": (3.5, 3.5 / 1.2),
"ytick.left": True,
"xtick.bottom": True,
}
)
mpl.rcParams["font.size"] = 12
soldata = pd.read_csv(
"https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv"
)
features_start_at = list(soldata.columns).index("MolWt")
np.random.seed(0)
random.seed(0)
# shuffle the rows and subsample to 1% for a quick demonstration run
soldata = soldata.sample(frac=0.01, random_state=0).reset_index(drop=True)
soldata.head()
from rdkit.Chem import MolToSmiles
def _randomize_smiles(mol, isomericSmiles=True):
return MolToSmiles(
mol,
canonical=False,
doRandom=True,
isomericSmiles=isomericSmiles,
kekuleSmiles=random.random() < 0.5,
)
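As a quick illustrative check (ethanol is just an example molecule, not drawn from the dataset), repeated calls should give different but chemically equivalent SMILES:
# illustrative sanity check: random but equivalent SMILES strings for ethanol
example_mol = MolFromSmiles("CCO")
for _ in range(3):
    print(_randomize_smiles(example_mol))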
smiles = list(soldata["SMILES"])
solubilities = list(soldata["Solubility"])
aug_data = 10
def largest_mol(smiles):
ss = smiles.split(".")
ss.sort(key=lambda a: len(a))
return ss[-1]
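For example (illustrative), largest_mol keeps the fragment that is longest as a string, which typically drops counterions from salts:
print(largest_mol("CCO.Cl"))  # -> CCO ("Cl" is the shorter fragment by string length)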
aug_smiles = []
aug_solubilities = []
for sml, sol in zip(smiles, solubilities):
sml = largest_mol(sml)
if len(sml) <= 4:
continue # ion or metal
new_smls = []
new_smls.append(sml)
aug_solubilities.append(sol)
for _ in range(aug_data):
try:
new_sml = _randomize_smiles(MolFromSmiles(sml))
if new_sml not in new_smls:
new_smls.append(new_sml)
aug_solubilities.append(sol)
        except Exception:  # skip SMILES that RDKit fails to randomize
continue
aug_smiles.extend(new_smls)
aug_df_AqSolDB = pd.DataFrame(
data={"SMILES": aug_smiles, "Solubility": aug_solubilities}
)
print(f"The dataset was augmented from {len(soldata)} to {len(aug_df_AqSolDB)}.")
selfies_list = []
for s in aug_df_AqSolDB.SMILES:
try:
selfies_list.append(sf.encoder(exmol.sanitize_smiles(s)[1]))
except sf.EncoderError:
selfies_list.append(None)
len(selfies_list)
basic = set(exmol.get_basic_alphabet())
data_vocab = set(
sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])
)
vocab = ["[nop]"]
vocab.extend(list(data_vocab.union(basic)))
vocab_stoi = {o: i for o, i in zip(vocab, range(len(vocab)))}
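Note that [nop] deliberately sits at index 0 so it can double as the padding value; a one-line check of that invariant:
assert vocab_stoi["[nop]"] == 0  # the padding token must map to 0 (see mask_zero below)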
def selfies2ints(s):
result = []
for token in sf.split_selfies(s):
if token in vocab_stoi:
result.append(vocab_stoi[token])
else:
            result.append(np.nan)  # unknown token; becomes NaN and is handled by callers
return result
def ints2selfies(v):
return "".join([vocab[i] for i in v])
# test them out
s = selfies_list[0]
print("selfies:", s)
v = selfies2ints(s)
print("selfies2ints:", v)
so = ints2selfies(v)
print("ints2selfies:", so)
# creating an object
@dataclass
class Config:
vocab_size: int
example_number: int
batch_size: int
buffer_size: int
embedding_dim: int
rnn_units: int
hidden_dim: int
drop_rate: float
config = Config(
vocab_size=len(vocab),
example_number=len(selfies_list),
batch_size=128,
buffer_size=10000,
embedding_dim=64,
hidden_dim=32,
rnn_units=64,
drop_rate=0.20,
)
# now get sequences
encoded = [selfies2ints(s) for s in selfies_list if s is not None]
# keep only the rows whose SELFIES encoding succeeded
dsolubilities = aug_df_AqSolDB.Solubility.values[[s is not None for s in selfies_list]]
padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding="post")
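pad_sequences right-pads every sequence with 0, i.e. the [nop] token; a minimal illustration with made-up token indices:
# illustrative: shorter sequences are right-padded with 0 ([nop])
print(tf.keras.preprocessing.sequence.pad_sequences([[5, 3], [7]], padding="post"))
# [[5 3]
#  [7 0]]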
# Rows were already shuffled when we subsampled the dataframe above, so a sequential split is fine
N = len(padded_seqs)
split = int(0.1 * N)
# Now build dataset
test_data = tf.data.Dataset.from_tensor_slices(
(padded_seqs[:split], dsolubilities[:split])
).batch(config.batch_size)
nontest = tf.data.Dataset.from_tensor_slices(
(
padded_seqs[split:],
dsolubilities[split:],
)
)
val_data, train_data = nontest.take(split).batch(config.batch_size), nontest.skip(
split
).shuffle(config.buffer_size).batch(config.batch_size).prefetch(
tf.data.experimental.AUTOTUNE
)
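A quick shape check on one training batch (illustrative; the exact sequence length depends on the padding above):
for xb, yb in train_data.take(1):
    print(xb.shape, yb.shape)  # (batch_size, max_seq_len) and (batch_size,)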
model = tf.keras.Sequential()
# make embedding and indicate that 0 should be treated as padding mask
model.add(
tf.keras.layers.Embedding(
input_dim=config.vocab_size, output_dim=config.embedding_dim, mask_zero=True
)
)
model.add(tf.keras.layers.Dropout(config.drop_rate))
# RNN layer
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(config.rnn_units)))
model.add(tf.keras.layers.Dropout(config.drop_rate))
# a dense hidden layer
model.add(tf.keras.layers.Dense(config.hidden_dim, activation="relu"))
model.add(tf.keras.layers.Dropout(config.drop_rate))
# regression, so no activation
model.add(tf.keras.layers.Dense(1))
model.summary()
model.compile(tf.optimizers.Adam(1e-3), loss="mean_squared_error")
# verbose=0 silences output, to get progress bar set verbose=1
result = model.fit(train_data, validation_data=val_data, epochs=50, verbose=0)
model.save("solubility-rnn-accurate")
# model = tf.keras.models.load_model('solubility-rnn-accurate/')
plt.figure(figsize=(5, 3.5))
plt.plot(result.history["loss"], label="training")
plt.plot(result.history["val_loss"], label="validation")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig("rnn-loss.png", bbox_inches="tight", dpi=300)
plt.show()
yhat = []
test_y = []
for x, y in test_data:
yhat.extend(model(x).numpy().flatten())
test_y.extend(y.numpy().flatten())
yhat = np.array(yhat)
test_y = np.array(test_y)
# plot test data
plt.figure(figsize=(5, 3.5))
plt.plot(test_y, test_y, ":")
plt.plot(test_y, yhat, ".")
plt.text(
max(test_y) - 6,
min(test_y) + 1,
f"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}",
)
plt.text(
max(test_y) - 6, min(test_y), f"loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}"
)
plt.xlabel(r"$y$")
plt.ylabel(r"$\hat{y}$")
plt.title("Testing Data")
plt.savefig("rnn-fit.png", dpi=300, bbox_inches="tight")
plt.show()
LIME explanations
In the following example, we find out which descriptors influence the solubility of a molecule. For example, say we have a molecule with LogS = 1.5. We create a perturbed chemical space around that molecule using the STONED method and then use LIME to find out which descriptors affect the solubility predictions for that molecule.
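For intuition, LIME fits a weighted linear surrogate to the model's predictions over the perturbed space. A minimal sketch of that weighted least-squares step follows (X, y, w are hypothetical arrays here; exmol performs this fit internally):
# minimal WLS sketch: solve (X^T W X) beta = X^T W y for the surrogate coefficients
def wls_beta(X, y, w):
    XtW = X.T * w  # multiply each column of X^T by its sample weight
    return np.linalg.solve(XtW @ X, XtW @ y)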
Wrapper function for the RNN, to use in STONED
# Predictor function is used as input to sample_space function
def predictor_function(smile_list, selfies):
encoded = [selfies2ints(s) for s in selfies]
# check for nans
valid = [1.0 if sum(e) > 0 else np.nan for e in encoded]
encoded = [np.nan_to_num(e, nan=0) for e in encoded]
padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding="post")
labels = np.reshape(model(padded_seqs, training=False), (-1))
return labels * valid
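As an illustrative check, the wrapper should reproduce the model's prediction for the SELFIES round-trip example from earlier (the first argument is unused inside the function):
print(predictor_function(None, [s]))  # s is the test SELFIES encoded above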
Descriptor explanations
# Make sure SMILES doesn't contain multiple fragments
smi = "CCCCC(=O)N(CC1=CC=C(C=C1)C2=C(C=CC=C2)C3=NN=N[NH]3)C(C(C)C)C(O)=O" # mol1 - not soluble
# smi = "CC(CC(=O)NC1=CC=CC=C1)C(=O)O" #mol2 - highly soluble
af = exmol.get_basic_alphabet()
stoned_kwargs = {
"num_samples": 5000,
"alphabet": af,
"max_mutations": 2,
}
space = exmol.sample_space(
smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
print(len(space))
from IPython.display import display, SVG
desc_type = ["Classic", "ecfp", "maccs"]
for d in desc_type:
beta = exmol.lime_explain(space, descriptor_type=d)
if d == "ecfp":
display(
SVG(
exmol.plot_descriptors(
space, output_file=f"{d}_mol2.svg", return_svg=True
)
)
)
plt.close()
else:
exmol.plot_descriptors(space, output_file=f"{d}_mol2.svg")
Text explanations
exmol.lime_explain(space, "ecfp")
s1_ecfp = exmol.text_explain(space, "ecfp")
explanation = exmol.text_explain_generate(s1_ecfp, "aqueous solubility")
print(explanation)
Similarity map
beta = exmol.lime_explain(space, "ecfp")
svg = exmol.plot_utils.similarity_map_using_tstats(space[0], return_svg=True)
display(SVG(svg))
# Write figure to file
with open("ecfp_similarity_map_mol2.svg", "w") as f:
f.write(svg)
# Inspect space
MolsToGridImage(
[MolFromSmiles(m.smiles) for m in space],
legends=[f"yhat = {m.yhat:.3}" for m in space],
molsPerRow=10,
maxMols=100,
)
How’s the fit?
fkw = {"figsize": (6, 4)}
font = {"family": "normal", "weight": "normal", "size": 16}
fig = plt.figure(figsize=(10, 5))
mpl.rc("axes", titlesize=12)
mpl.rc("font", size=16)
ax_dict = fig.subplot_mosaic("AABBB")
# Plot space by fit
svg = exmol.plot_utils.plot_space_by_fit(
space,
[space[0]],
figure_kwargs=fkw,
mol_size=(200, 200),
offset=1,
ax=ax_dict["B"],
beta=beta,
)
# Compute y_wls, the weighted least-squares surrogate predictions
w = np.array([1 / (1 + (1 / (e.similarity + 0.000001) - 1) ** 5) for e in space])  # similarity-kernel weights
non_zero = w > 10 ** (-6)
w = w[non_zero]
N = w.shape[0]
ys = np.array([e.yhat for e in space])[non_zero].reshape(N).astype(float)
x_mat = np.array([list(e.descriptors.descriptors) for e in space])[non_zero].reshape(
N, -1
)
y_wls = x_mat @ beta
y_wls += np.mean(ys)
lower = np.min(ys)
higher = np.max(ys)
# set transparency using w
norm = plt.Normalize(min(w), max(w))
cmap = plt.cm.Oranges(w)
cmap[:, -1] = w
def weighted_mean(x, w):
return np.sum(x * w) / np.sum(w)
def weighted_cov(x, y, w):
return np.sum(w * (x - weighted_mean(x, w)) * (y - weighted_mean(y, w))) / np.sum(w)
def weighted_correlation(x, y, w):
return weighted_cov(x, y, w) / np.sqrt(
weighted_cov(x, x, w) * weighted_cov(y, y, w)
)
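Sanity check (illustrative): with uniform weights the weighted correlation reduces to the ordinary Pearson correlation:
_a = np.linspace(0, 1, 50)
_b = _a + 0.1 * np.sin(10 * _a)
print(weighted_correlation(_a, _b, np.ones(50)), np.corrcoef(_a, _b)[0, 1])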
corr = weighted_correlation(ys, y_wls, w)
ax_dict["A"].plot(
np.linspace(lower, higher, 100), np.linspace(lower, higher, 100), "--", linewidth=2
)
sc = ax_dict["A"].scatter(ys, y_wls, s=50, marker=".", c=cmap, cmap=cmap)
ax_dict["A"].text(max(ys) - 3, min(ys) + 1, f"weighted \ncorrelation = {corr:.3f}")
ax_dict["A"].set_xlabel(r"$\hat{y}$")
ax_dict["A"].set_ylabel(r"$g$")
ax_dict["A"].set_title("Weighted Least Squares Fit")
ax_dict["A"].set_xlim(lower, higher)
ax_dict["A"].set_ylim(lower, higher)
ax_dict["A"].set_aspect(1.0 / ax_dict["A"].get_data_ratio(), adjustable="box")
sm = plt.cm.ScalarMappable(cmap=plt.cm.Oranges, norm=norm)
cbar = plt.colorbar(sm, orientation="horizontal", pad=0.15, ax=ax_dict["A"])
cbar.set_label("Chemical similarity")
plt.tight_layout()
plt.savefig("weighted_fit.svg", dpi=300, bbox_inches="tight", transparent=False)
Robustness to incomplete sampling
We first sample a reference chemical space, and then subsample smaller chemical spaces from this reference. Rank correlation is computed between important descriptors for the smaller subspaces and the reference space.
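As a reminder, Spearman rank correlation compares orderings rather than raw values; a minimal example with hypothetical ranks:
from scipy.stats import spearmanr
print(spearmanr([1, 2, 3, 4, 5], [1, 3, 2, 4, 5]).correlation)  # 0.9: one adjacent pair swapped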
# Sample a big space
stoned_kwargs = {
"num_samples": 5000,
"alphabet": exmol.get_basic_alphabet(),
"max_mutations": 2,
}
space = exmol.sample_space(
smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
len(space)
# get descriptor attributions
exmol.lime_explain(space, "MACCS", return_beta=False)
# Assign feature ids for rank comparison
features = {
a: b
for a, b in zip(
space[0].descriptors.descriptor_names,
np.arange(len(space[0].descriptors.descriptors)),
)
}
# Get set of ranks for the reference space
baseline_imp = {
a: b
for a, b in zip(space[0].descriptors.descriptor_names, space[0].descriptors.tstats)
if not np.isnan(b)
}
baseline_imp = dict(
sorted(baseline_imp.items(), key=lambda item: abs(item[1]), reverse=True)
)
baseline_set = [features[x] for x in baseline_imp.keys()]
# Subsample the space, compute LIME importances, and compare ranks with the reference
from scipy.stats import spearmanr
plt.figure(figsize=(4, 3))
N = len(space)
size = np.arange(500, N, 1000)
rank_corr = {N: 1}
for i, f in enumerate(size):
# subsample space
rank_corr[f] = []
for _ in range(10):
# subsample space of size f
idx = np.random.choice(np.arange(N), size=f, replace=False)
subspace = [space[i] for i in idx]
# get desc attributions
ss_beta = exmol.lime_explain(subspace, descriptor_type="MACCS")
ss_imp = {
a: b
for a, b in zip(
subspace[0].descriptors.descriptor_names, subspace[0].descriptors.tstats
)
if not np.isnan(b)
}
ss_imp = dict(
sorted(ss_imp.items(), key=lambda item: abs(item[1]), reverse=True)
)
ss_set = [features[x] for x in ss_imp.keys()]
# Get ranks for subsampled space and compare with reference
ranks = {a: [b] for a, b in zip(baseline_set[:5], np.arange(1, 6))}
for j, s in enumerate(ss_set):
if s in ranks:
ranks[s].append(j + 1)
# compute rank correlation
r = spearmanr(np.arange(1, 6), [ranks[x][1] for x in ranks])
rank_corr[f].append(r.correlation)
plt.scatter(f, np.mean(rank_corr[f]), color="#13254a", marker="o")
plt.scatter(N, 1.0, color="red", marker="o")
plt.axvline(x=N, linestyle=":", color="red")
plt.xlabel("Size of chemical space")
plt.ylabel("Rank correlation")
plt.tight_layout()
plt.savefig("rank correlation.svg", dpi=300, bbox_inches="tight")
Effect of mutation number, alphabet and size of chemical space
# Mutation
desc_type = ["Classic"]
muts = [1, 2, 3]
for i in muts:
stoned_kwargs = {
"num_samples": 2500,
"alphabet": exmol.get_basic_alphabet(),
"min_mutations": i,
"max_mutations": i,
}
space = exmol.sample_space(
smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
for d in desc_type:
exmol.lime_explain(space, descriptor_type=d)
exmol.plot_descriptors(space, title=f"Mutations={i}")
# Alphabet
basic = exmol.get_basic_alphabet()
train = sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])
wide = sf.get_semantic_robust_alphabet()
desc_type = ["MACCS"]
alphs = {"Basic": basic, "Training Data": train, "SELFIES": wide}
for a in alphs:
stoned_kwargs = {"num_samples": 2500, "alphabet": alphs[a], "max_mutations": 2}
space = exmol.sample_space(
smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
for d in desc_type:
exmol.lime_explain(space, descriptor_type=d)
exmol.plot_descriptors(space, title=f"Alphabet: {a}")
# Size of space
desc_type = ["MACCS"]
space_size = [1500, 2000, 2500]
for s in space_size:
stoned_kwargs = {
"num_samples": s,
"alphabet": exmol.get_basic_alphabet(),
"max_mutations": 2,
}
space = exmol.sample_space(
smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
for d in desc_type:
exmol.lime_explain(space, descriptor_type=d)
exmol.plot_descriptors(
space,
title=f"Chemical space size={s}",
)