Counterfactuals are data points that try to answer the question “what is the smallest change to the features that would alter the prediction?” In chemical ML, this means determining the smallest change to an input molecule that would cause a model to change its prediction.
Counterfactual selection in oce is based on the Python package exmol. Detailed methods and analysis are described in the exmol paper; here we briefly outline how counterfactuals are identified and plotted with oce.
Selecting counterfactuals can be broken down into several steps:
1. Sample the chemical space around the target molecule. This is done with our PerturbationEngine classes.
2. Run the model and compute the similarity to the target molecule. We use the Tanimoto similarity of the molecules’ Morgan fingerprints.
3. Cluster and select counterfactuals. Density-based clustering is performed in the PCA-reduced space of the affinity matrix. The most similar counterfactual(s) from each cluster are saved.
4. Plot counterfactuals. This can be done either in the PCA-reduced space or in similarity-output space.
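Steps 2 and 3 can be sketched in plain Python. This is a minimal illustration, not oce's or exmol's implementation: fingerprints are assumed to be sets of on-bit indices (in practice these could be read off RDKit Morgan fingerprint bit vectors), cluster labels are assumed to be already computed (the density-based clustering and PCA are not shown), and the helper names `tanimoto` and `select_per_cluster` as well as the toy molecules are hypothetical.

```python
from __future__ import annotations


def tanimoto(fp_a: set[int], fp_b: set[int]) -> float:
    """Tanimoto similarity |A & B| / |A | B| of two on-bit sets."""
    union = len(fp_a | fp_b)
    if union == 0:
        return 0.0  # assumption: two empty fingerprints are treated as dissimilar
    return len(fp_a & fp_b) / union


def select_per_cluster(base_fp, base_pred, candidates, delta):
    """For each cluster, keep the candidate most similar to the base molecule
    whose prediction differs from base_pred by more than delta.

    candidates: iterable of (smiles, fingerprint, prediction, cluster_label).
    Returns {cluster_label: (smiles, similarity, prediction)}.
    """
    best = {}
    for smiles, fp, pred, label in candidates:
        if abs(pred - base_pred) <= delta:
            continue  # prediction barely moved: not a counterfactual
        sim = tanimoto(base_fp, fp)
        if label not in best or sim > best[label][1]:
            best[label] = (smiles, sim, pred)
    return best


# Toy data (hypothetical): base molecule with prediction 3.0
base_fp = {1, 2, 3, 4}
candidates = [
    ("mol_a", {1, 2, 3}, 2.9, 0),         # similar, but prediction barely changes
    ("mol_b", {1, 2, 3, 5}, 1.5, 0),      # similar and prediction changes a lot
    ("mol_c", {1, 9}, 1.0, 0),            # counterfactual, but less similar
    ("mol_d", {2, 3, 4, 7, 8}, 4.0, 1),   # only member of cluster 1
]
print(select_per_cluster(base_fp, 3.0, candidates, delta=0.5))
```

Keeping only the closest counterfactual per cluster is what lets the final plot show a few diverse modifications instead of many near-duplicates.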
Below, we demonstrate how to run this pipeline with an arbitrary model.
In this example, we use a model trained on the lipophilicity dataset from MoleculeNet, loading both the dataset and the trained model from saved .oce files.
import olorenchemengine as oce
from rdkit import Chem
import pandas as pd

dataset = oce.load('lipophilicity_dataset.oce')
model = oce.load('lipophilicity_model_rf.oce')
We will select a molecule from the testing dataset to use as our reference molecule.
smiles = dataset.test_dataset.loc[3780, 'smiles']
We can now build our visualization object using VisualizeCounterfactual. Here, we set delta = 0.5, indicating that every compound whose predicted value differs from the predicted value of the reference by more than 0.5 is a potential counterfactual.
vis = oce.VisualizeCounterfactual(smiles, model, delta=0.5)
Finally, we can plot the selected counterfactuals using the
The x-axis is the Tanimoto similarity between each sampled molecule and the reference (base) molecule. The y-axis is the molecule's predicted value from the trained model. For regression models such as this one, the y-axis is continuous, whereas for classification models it is discrete. High and low counterfactuals stand out clearly as having significantly higher or lower predicted outputs than the reference. For classification models, every molecule outside the reference class is simply labeled as a counterfactual.
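The labeling rule described above can be sketched as a small function. This is a hedged illustration, not oce's plotting code: the helper name `label_point`, the `task` switch, the label strings, and the toy (similarity, prediction) pairs are all hypothetical; only the rule itself (regression changes larger than delta split into higher and lower counterfactuals, any other class counts for classification) comes from the text.

```python
def label_point(pred, base_pred, task="regression", delta=0.5):
    """Label one sampled molecule relative to the reference (base) molecule.

    Hypothetical helper: the label strings and the `task` switch are illustrative.
    """
    if task == "classification":
        # Any predicted class other than the reference class is a counterfactual.
        return "base class" if pred == base_pred else "counterfactual"
    # Regression: only changes larger than delta count as counterfactuals.
    if pred - base_pred > delta:
        return "higher counterfactual"
    if base_pred - pred > delta:
        return "lower counterfactual"
    return "not a counterfactual"


base_pred = 2.0
# Toy (Tanimoto similarity to base, predicted value) pairs: x and y of the plot
samples = [(0.8, 2.2), (0.6, 2.6), (0.5, 1.3)]
for sim, pred in samples:
    print(f"x={sim:.1f}  y={pred:.1f}  {label_point(pred, base_pred)}")

print(label_point(1, 0, task="classification"))
```

Because the classification branch only checks class equality, there is no notion of "higher" or "lower" there, which matches the discrete y-axis described for classifiers.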
The number of points shown (default 40) can be varied by setting n:
vis = oce.VisualizeCounterfactual(smiles, model, delta=0.5, n=10)
Counterfactuals can also be plotted in their original PCA space by setting pca=True:
vis = oce.VisualizeCounterfactual(smiles, model, delta=0.5, pca=True)
In this space, molecules closer to each other have a greater similarity. Thus, different clusters of molecules correspond to different high-level modifications of the reference molecule. Subclusters correspond to more specific modifications within these groups.
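To see why nearby points in this view are similar molecules, here is a pure-Python sketch of PCA on an affinity matrix. It is an illustration under stated assumptions, not oce's code: each row of a pairwise Tanimoto similarity matrix is treated as a feature vector, the top two principal components are found by power iteration with deflation, and the 4×4 toy matrix (two groups of two mutually similar molecules) is made up. A real pipeline would more likely use NumPy or scikit-learn for this step.

```python
import math
import random


def pca_2d(rows, iters=200, seed=0):
    """Project rows (e.g. of a pairwise similarity matrix) onto their first
    two principal components using power iteration with deflation."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    X = [[r[j] - means[j] for j in range(d)] for r in rows]  # center each column
    # Sample covariance matrix C = X^T X / (n - 1)
    C = [[sum(X[k][i] * X[k][j] for k in range(n)) / (n - 1) for j in range(d)]
         for i in range(d)]
    rng = random.Random(seed)
    components = []
    for _ in range(2):
        v = [rng.random() + 0.1 for _ in range(d)]
        for _ in range(iters):
            w = [sum(C[i][j] * v[j] for j in range(d)) for i in range(d)]
            norm = math.sqrt(sum(x * x for x in w)) or 1.0
            v = [x / norm for x in w]
        # Rayleigh quotient gives the eigenvalue for the converged unit vector v
        lam = sum(v[i] * sum(C[i][j] * v[j] for j in range(d)) for i in range(d))
        components.append(v)
        # Deflate so the next power iteration finds the next component
        C = [[C[i][j] - lam * v[i] * v[j] for j in range(d)] for i in range(d)]
    return [[sum(x[j] * c[j] for j in range(d)) for c in components] for x in X]


# Toy Tanimoto affinity matrix: molecules 0/1 and 2/3 form two similar pairs
S = [
    [1.0, 0.9, 0.2, 0.1],
    [0.9, 1.0, 0.1, 0.2],
    [0.2, 0.1, 1.0, 0.8],
    [0.1, 0.2, 0.8, 1.0],
]
for i, (pc1, pc2) in enumerate(pca_2d(S)):
    print(f"molecule {i}: PC1={pc1:+.3f}  PC2={pc2:+.3f}")
```

In the projection, molecules 0 and 1 land on one side of PC1 and molecules 2 and 3 on the other, so mutually similar molecules form separate groups, which is the structure the density-based clustering then picks up.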