LIME paper: Random Forest for Solubility Prediction

Import packages

import pandas as pd
import matplotlib.pyplot as plt

# import seaborn as sns
import matplotlib as mpl
import rdkit, rdkit.Chem, rdkit.Chem.Draw
from rdkit.Chem.Draw import IPythonConsole
import numpy as np
import mordred, mordred.descriptors
from mordred import HydrogenBond, Polarizability
from mordred import SLogP, AcidBase, BertzCT, Aromatic, BondCount, AtomCount
from mordred import Calculator

import exmol as exmol
from rdkit.Chem.Draw import rdDepictor
import os
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, plot_roc_curve

# Runtime and display configuration for the notebook.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
rdDepictor.SetPreferCoordGen(True)
IPythonConsole.ipython_useSVG = True

# Colour cycle applied globally to every matplotlib axes created afterwards.
color_cycle = ["#F06060", "#1BBC9B", "#F3B562", "#6e5687", "#5C4B51"]
mpl.rcParams.update({"axes.prop_cycle": mpl.cycler(color=color_cycle)})

# Seed NumPy's global RNG; later random draws (e.g. the column pick in
# pick_features) use this global generator.
np.random.seed(0)

# Curated aqueous-solubility dataset (provides the SMILES and Solubility
# columns used below).
soldata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv"
)
# Position of the "MolWt" column; presumably the first precomputed
# descriptor column in the CSV — TODO confirm against the dataset schema.
features_start_at = soldata.columns.tolist().index("MolWt")
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 14
     11 from mordred import SLogP, AcidBase, BertzCT, Aromatic, BondCount, AtomCount
     12 from mordred import Calculator
---> 14 import exmol as exmol
     15 from rdkit.Chem.Draw import rdDepictor
     16 import os

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/exmol/__init__.py:3
      1 from .version import __version__
      2 from . import stoned
----> 3 from .exmol import *
      4 from .data import *
      5 from .stoned import sanitize_smiles

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/exmol/exmol.py:30
     28 from rdkit.Chem import rdchem  # type: ignore
     29 from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity, TanimotoSimilarity  # type: ignore
---> 30 import langchain.llms as llms
     31 import langchain.prompts as prompts
     33 from . import stoned

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain/llms/__init__.py:22
      1 """
      2 **LLM** classes provide
      3 access to the large language model (**LLM**) APIs and services.
   (...)
     18     AIMessage, BaseMessage
     19 """  # noqa: E501
     20 from typing import Any, Callable, Dict, Type
---> 22 from langchain.llms.base import BaseLLM
     25 def _import_ai21() -> Any:
     26     from langchain.llms.ai21 import AI21

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain/llms/base.py:2
      1 # Backwards compatibility.
----> 2 from langchain_core.language_models import BaseLanguageModel
      3 from langchain_core.language_models.llms import (
      4     LLM,
      5     BaseLLM,
   (...)
      9     update_cache,
     10 )
     12 __all__ = [
     13     "create_base_retry_decorator",
     14     "get_prompts",
   (...)
     19     "LLM",
     20 ]

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/language_models/__init__.py:7
      1 from langchain_core.language_models.base import (
      2     BaseLanguageModel,
      3     LanguageModelInput,
      4     LanguageModelOutput,
      5     get_tokenizer,
      6 )
----> 7 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
      8 from langchain_core.language_models.llms import LLM, BaseLLM
     10 __all__ = [
     11     "BaseLanguageModel",
     12     "BaseChatModel",
   (...)
     18     "LanguageModelOutput",
     19 ]

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/language_models/chat_models.py:20
      7 from functools import partial
      8 from typing import (
      9     TYPE_CHECKING,
     10     Any,
   (...)
     17     cast,
     18 )
---> 20 from langchain_core.callbacks import (
     21     AsyncCallbackManager,
     22     AsyncCallbackManagerForLLMRun,
     23     BaseCallbackManager,
     24     CallbackManager,
     25     CallbackManagerForLLMRun,
     26     Callbacks,
     27 )
     28 from langchain_core.globals import get_llm_cache
     29 from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/callbacks/__init__.py:13
      1 from langchain_core.callbacks.base import (
      2     AsyncCallbackHandler,
      3     BaseCallbackHandler,
   (...)
     11     ToolManagerMixin,
     12 )
---> 13 from langchain_core.callbacks.manager import (
     14     AsyncCallbackManager,
     15     AsyncCallbackManagerForChainGroup,
     16     AsyncCallbackManagerForChainRun,
     17     AsyncCallbackManagerForLLMRun,
     18     AsyncCallbackManagerForRetrieverRun,
     19     AsyncCallbackManagerForToolRun,
     20     AsyncParentRunManager,
     21     AsyncRunManager,
     22     BaseRunManager,
     23     CallbackManager,
     24     CallbackManagerForChainGroup,
     25     CallbackManagerForChainRun,
     26     CallbackManagerForLLMRun,
     27     CallbackManagerForRetrieverRun,
     28     CallbackManagerForToolRun,
     29     ParentRunManager,
     30     RunManager,
     31 )
     32 from langchain_core.callbacks.stdout import StdOutCallbackHandler
     33 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langchain_core/callbacks/manager.py:26
      9 from typing import (
     10     TYPE_CHECKING,
     11     Any,
   (...)
     22     cast,
     23 )
     24 from uuid import UUID
---> 26 from langsmith.run_helpers import get_run_tree_context
     27 from tenacity import RetryCallState
     29 from langchain_core.callbacks.base import (
     30     BaseCallbackHandler,
     31     BaseCallbackManager,
   (...)
     37     ToolManagerMixin,
     38 )

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/__init__.py:10
      6 except metadata.PackageNotFoundError:
      7     # Case where package metadata is not available.
      8     __version__ = ""
---> 10 from langsmith.client import Client
     11 from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
     12 from langsmith.run_helpers import trace, traceable

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/client.py:43
     41 from langsmith import schemas as ls_schemas
     42 from langsmith import utils as ls_utils
---> 43 from langsmith.evaluation import evaluator as ls_evaluator
     45 if TYPE_CHECKING:
     46     import pandas as pd

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/evaluation/__init__.py:4
      1 """Evaluation Helpers."""
      3 from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
----> 4 from langsmith.evaluation.string_evaluator import StringEvaluator
      6 __all__ = ["EvaluationResult", "RunEvaluator", "StringEvaluator"]

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/langsmith/evaluation/string_evaluator.py:3
      1 from typing import Callable, Dict, Optional
----> 3 from pydantic import BaseModel
      5 from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
      6 from langsmith.schemas import Example, Run

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/__init__.py:372, in __getattr__(attr_name)
    370     return import_module(f'.{attr_name}', package=package)
    371 else:
--> 372     module = import_module(module_name, package=package)
    373     return getattr(module, attr_name)

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/importlib/__init__.py:127, in import_module(name, package)
    125             break
    126         level += 1
--> 127 return _bootstrap._gcd_import(name[level:], package, level)

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/main.py:15
     12 import typing_extensions
     13 from pydantic_core import PydanticUndefined
---> 15 from ._internal import (
     16     _config,
     17     _decorators,
     18     _fields,
     19     _forward_ref,
     20     _generics,
     21     _mock_val_ser,
     22     _model_construction,
     23     _repr,
     24     _typing_extra,
     25     _utils,
     26 )
     27 from ._migration import getattr_migration
     28 from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/_internal/_decorators.py:15
     12 from typing_extensions import Literal, TypeAlias, is_typeddict
     14 from ..errors import PydanticUserError
---> 15 from ._core_utils import get_type_ref
     16 from ._internal_dataclass import slots_true
     17 from ._typing_extra import get_function_type_hints

File /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/pydantic/_internal/_core_utils.py:16
     14 from pydantic_core import CoreSchema, core_schema
     15 from pydantic_core import validate_core_schema as _validate_core_schema
---> 16 from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
     18 from . import _repr
     19 from ._typing_extra import is_generic_alias

ImportError: cannot import name 'TypeAliasType' from 'typing_extensions' (/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/typing_extensions.py)

Build and train a Random Forest model

# Mordred calculator limited to a small, hand-picked set of descriptors
# (registration order defines the column order of the feature matrix).
calc = Calculator()
calc.register(
    [
        HydrogenBond.HBondDonor,
        HydrogenBond.HBondAcceptor,
        AcidBase.AcidicGroupCount,
        AcidBase.BasicGroupCount,
        Aromatic.AromaticBondsCount,
        SLogP.SLogP,
        Polarizability.APol,
        BondCount.BondCount(type="double"),
        BondCount.BondCount(type="aromatic"),
        AtomCount.AtomCount("Hetero"),
    ]
)

# Featurize every molecule in the dataset (no subsampling): one row per
# molecule, one column per registered descriptor.
molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in soldata.SMILES]
raw_features = np.array(
    [list(result.values()) for result in calc.map(molecules, quiet=True)]
)
# Human-readable description for each descriptor column.
feature_names = np.array([d.description() for d in calc.descriptors])
# Regression target: measured solubility.
labels = soldata["Solubility"]
def pick_features(raw_features, all_names=None, n_features=3):
    """Standardize descriptors and keep a random subset of distinct columns.

    Every entry is shifted and scaled by one scalar mean and one scalar
    standard deviation computed over the whole matrix (not per column).
    Then ``n_features`` distinct columns are drawn with NumPy's global RNG.

    Args:
        raw_features: 2-D array of descriptor values, one row per sample.
            The array is not modified.
        all_names: Names of the columns of ``raw_features``. Defaults to the
            module-level ``feature_names``.
        n_features: Number of columns to keep (default 3).

    Returns:
        ``(features, names)``: the standardized float array restricted to the
        chosen columns, and the matching entries of ``all_names``.

    Raises:
        ValueError: if ``n_features`` exceeds the number of columns.
    """
    if all_names is None:
        all_names = feature_names
    # astype() returns a new array, so the caller's raw_features is no longer
    # normalized in place as a side effect.
    values = np.asarray(raw_features).astype(float)
    values = (values - values.mean()) / values.std()
    # Sample without replacement so the same descriptor cannot be chosen twice.
    columns = np.random.choice(values.shape[1], size=n_features, replace=False)
    return values[:, columns], np.asarray(all_names)[columns]
# Random subset of standardized descriptors that the forest is trained on.
features, names = pick_features(raw_features)
print(features.shape, names)

# 90/10 shuffled split; no random_state is given, so sklearn falls back to
# NumPy's global RNG.
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.1, shuffle=True
)

clf = RandomForestRegressor(max_depth=10, random_state=0)
clf.fit(X_train, y_train)
predicted = clf.predict(X_test)

# Held-out metrics: Pearson correlation and root-mean-square error.
pearson_r = np.corrcoef(y_test, predicted)[0, 1]
rmse = np.sqrt(np.mean((y_test - predicted) ** 2))
text_x = max(y_test) - 6

# Parity plot: predicted vs. measured, with a dotted y = x reference line.
plt.figure(figsize=(5, 4))
plt.plot(y_test, predicted, ".")
plt.plot(y_test, y_test, linestyle=":")
plt.text(text_x, min(y_test) + 1, f"correlation = {pearson_r:.3f}", fontsize=12)
plt.text(text_x, min(y_test), f"loss = {rmse:.3f}", fontsize=12)
# NOTE(review): the file name says "ROC" but this is a regression parity plot.
plt.savefig("RF-ROC.png")

Compute descriptor attributions

# Lazily filled cache for the training standardization statistics.
_model_eval_cache = {}


def _descriptor_matrix(molecules):
    """Return ``calc`` descriptors as a float array with one row per molecule.

    Rows for molecules RDKit could not parse (``None``) and descriptors that
    mordred could not compute are filled with NaN.
    """
    rows = np.full((len(molecules), len(calc.descriptors)), np.nan)
    valid = [i for i, mol in enumerate(molecules) if mol is not None]
    results = calc.map([molecules[i] for i in valid], quiet=True)
    for i, result in zip(valid, results):
        rows[i] = list(result.fill_missing(np.nan).values())
    return rows


def _training_stats():
    """Return the scalar ``(mean, std)`` that ``pick_features`` used in training.

    ``pick_features`` standardized with one mean/std over the full dataset
    descriptor matrix. The statistics are recomputed once from ``soldata``
    (slow: every dataset molecule is featurized) and then cached.
    """
    if "stats" not in _model_eval_cache:
        train_raw = _descriptor_matrix(
            [rdkit.Chem.MolFromSmiles(s) for s in soldata.SMILES]
        )
        _model_eval_cache["stats"] = (train_raw.mean(), train_raw.std())
    return _model_eval_cache["stats"]


def model_eval(smiles):
    """Predict solubility for each SMILES string with the trained random forest.

    Each input is featurized the same way as the training data:
    1. The module-level mordred ``calc`` computes raw descriptors.
    2. They are standardized with the training mean/std.
    3. The columns whose descriptions are in ``names`` are selected.
    4. The result is passed to ``clf.predict``.

    Args:
        smiles: Iterable of SMILES strings (a list or a pandas Series).

    Returns:
        numpy array of predictions, one per input SMILES, in input order.
    """
    smiles = list(smiles)
    if not smiles:
        # sklearn refuses to predict on zero samples.
        return np.empty(0)
    mean, std = _training_stats()
    all_names = list(feature_names)
    columns = [all_names.index(n) for n in names]
    raw = _descriptor_matrix([rdkit.Chem.MolFromSmiles(smi) for smi in smiles])
    standardized = (raw[:, columns] - mean) / std
    # Non-finite entries (uncomputable descriptors, unparsable SMILES) are
    # zero-filled before predicting, so every input still gets a value.
    return clf.predict(np.nan_to_num(standardized))


# Predictions for the whole dataset. Note: this rebinds the module-level
# `labels`, which previously held the measured solubilities.
labels = model_eval(soldata.SMILES)

# Reference molecule whose prediction is explained below.
smi = soldata.SMILES[1500]

# Settings forwarded to exmol's STONED sampler.
stoned_kwargs = dict(
    num_samples=2000,
    alphabet=exmol.get_basic_alphabet(),
    max_mutations=2,
)
space = exmol.sample_space(smi, model_eval, stoned_kwargs=stoned_kwargs, quiet=True)
def calc_feature_importance(descriptors, tstats):
    """Map descriptor names to t-statistics, ordered by decreasing magnitude.

    Pairs whose t-statistic is NaN are dropped. Sorting uses ``abs(t)``, so
    strongly negative and strongly positive attributions rank together. Ties
    keep their input order because ``sorted`` is stable.

    Args:
        descriptors: Iterable of descriptor names.
        tstats: Iterable of float t-statistics aligned with ``descriptors``.

    Returns:
        dict of ``{descriptor: t_statistic}`` in descending ``|t|`` order.
    """
    finite = {name: t for name, t in zip(descriptors, tstats) if not np.isnan(t)}
    return dict(sorted(finite.items(), key=lambda item: abs(item[1]), reverse=True))
# Explain the sampled space with exmol's LIME-style descriptor attribution.
descriptor_type = "Classic"
exmol.lime_explain(space, descriptor_type=descriptor_type)

# The results are read from the first example of the space (presumably the
# reference molecule — TODO confirm against exmol docs).
base_descriptors = space[0].descriptors
wls_attr = calc_feature_importance(
    list(base_descriptors.descriptor_names), list(base_descriptors.tstats)
)
wls_attr

Do we recover training features?

# Descriptors in descending |t| order. Those the forest was actually trained
# on (members of `names`) get purple tick labels; the rest are black.
x = list(wls_attr)
xaxis = np.arange(len(x))
x_colors = ["purple" if t in names else "black" for t in x]

# Random-forest impurity-based importances of the training features, aligned
# with the WLS ranking (0 where a descriptor was not a training feature).
# NOTE(review): rf_x / rf_y are computed but never plotted below.
rf_imp = dict(zip(names, clf.feature_importances_))
rf_x = np.array([i if name in rf_imp else 0 for i, name in enumerate(x)], dtype=float)
rf_y = np.array([rf_imp.get(name, 0.0) for name in x], dtype=float)

# Bar length is the t-statistic; red marks negative and green positive values.
width = [wls_attr[name] for name in x]
colors = ["#F06060" if w < 0 else "#1BBC9B" for w in width]

fig, ax = plt.subplots(figsize=(6, 5))
ax.barh(xaxis + 0.2, width, 0.75, label="WLS", color=colors)

plt.xticks(fontsize=12)
plt.xlabel("Feature t-statistics", fontsize=12)
plt.yticks(xaxis, x, fontsize=12)
# A plain loop, since set_color is called only for its side effect.
for label_color, tick in zip(x_colors, ax.yaxis.get_ticklabels()):
    tick.set_color(label_color)
plt.gca().invert_yaxis()
plt.title("Random Forest Regression", fontsize=12)