Putting it all together

We now demonstrate all of Oloren ChemEngine’s uncertainty features by training a production-level model and error model on the BACE dataset from DeepChem datasets.

import olorenchemengine as oce
import pandas as pd
import numpy as np

Creating the dataset

We will train the model on 90% of the data and leave 10% for testing.

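# Load the BACE data: the "mol" column holds the structures, "pIC50" the property to predict.
bace_dataset = oce.DatasetFromCSV("bace.csv", structure_col="mol", property_col="pIC50")
# 90% train, no validation set, 10% test.
splitter = oce.RandomSplit(split_proportions=[0.9, 0, 0.1])
bace_dataset = splitter.transform(bace_dataset)
# Save the split dataset so later cells reuse the same train/test partition.
oce.save(bace_dataset, "bace_dataset.oce")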
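To make the 90/0/10 proportions concrete, here is a minimal standard-library sketch of a random index split. It is not how oce.RandomSplit is implemented; the function name random_split, the seed argument, and the rounding rule are assumptions chosen for illustration.

```python
import random

def random_split(n_items, proportions, seed=0):
    """Shuffle indices 0..n_items-1 and cut them into train/validation/test.

    Hypothetical helper: the rounding and seeding here are assumptions,
    not the behaviour of oce.RandomSplit.
    """
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_train = round(proportions[0] * n_items)
    n_valid = round(proportions[1] * n_items)
    train = indices[:n_train]
    valid = indices[n_train:n_train + n_valid]
    test = indices[n_train + n_valid:]  # whatever remains goes to the test set
    return train, valid, test

train_idx, valid_idx, test_idx = random_split(1000, [0.9, 0, 0.1])
print(len(train_idx), len(valid_idx), len(test_idx))  # 900 0 100
```

The three index lists are disjoint and together cover every item, so no molecule appears in both the training and the test set.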

Training the production-level model

Production-level models are produced by running fit_cv, which fits the model on the full training set and, in addition, fits the error model via cross-validation.

bace_dataset = oce.load("bace_dataset.oce")

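model = oce.ZWK_XGBoostModel(oce.OlorenCheckpoint("default"))
# train_dataset[0] holds the structures, train_dataset[1] the property values.
# fit_cv fits the model and uses cross-validation to fit the SDCwRMSD1 error model.
model.fit_cv(bace_dataset.train_dataset[0], bace_dataset.train_dataset[1], error_model=oce.SDCwRMSD1())
oce.save(model, "bace_model.oce")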
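The general idea behind fitting an error model by cross-validation can be sketched as follows: each training point is predicted by a model that never saw it, and the resulting out-of-fold residuals are the data from which an error model can be built. This is only an illustration of that technique, not the internals of fit_cv or SDCwRMSD1; the helpers fit_line and out_of_fold_residuals, the one-dimensional least-squares stand-in model, and the synthetic data are all assumptions.

```python
import random

def fit_line(xs, ys):
    """Closed-form least-squares line; a stand-in for the real base model."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x

def out_of_fold_residuals(xs, ys, n_splits=5, seed=0):
    """Return |y - prediction| for every point, predicted by a held-out fold model."""
    indices = list(range(len(xs)))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::n_splits] for i in range(n_splits)]
    residuals = [0.0] * len(xs)
    for held_out in folds:
        held = set(held_out)
        train = [i for i in indices if i not in held]
        slope, intercept = fit_line([xs[i] for i in train], [ys[i] for i in train])
        for i in held_out:
            residuals[i] = abs(ys[i] - (slope * xs[i] + intercept))
    return residuals

# Synthetic data, for illustration only.
rng = random.Random(1)
xs = [rng.uniform(0, 10) for _ in range(200)]
ys = [2.0 * x + 1.0 + rng.gauss(0, 0.3) for x in xs]
residuals = out_of_fold_residuals(xs, ys)
print(len(residuals), sum(residuals) / len(residuals))
```

Because every residual comes from a model trained without that point, the residuals reflect error on unseen data rather than training error, which is what an error model needs to estimate.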

Visualizing results

We visualize the probable output range (80% confidence interval) and the true output for each test molecule. For the molecules plotted, each of the predicted values is within the error margin.

model = oce.load("bace_model.oce")
bace_dataset = oce.load("bace_dataset.oce")

# return_ci adds the "ci" column and return_vis adds the "vis" column to the results.
results_df = model.predict(bace_dataset.test_dataset[0], return_ci=True, return_vis=True)
results_df
152it [00:00, 349.85it/s]
     predicted        ci  vis
0     6.004463  1.052809  <olorenchemengine.visualizations.visualization.Vis...
1     7.084375  0.869544  <olorenchemengine.visualizations.visualization.Vis...
2     6.217337  1.473213  <olorenchemengine.visualizations.visualization.Vis...
3     5.523591  1.170892  <olorenchemengine.visualizations.visualization.Vis...
4     6.502657  0.511894  <olorenchemengine.visualizations.visualization.Vis...
...        ...       ...  ...
147   5.657271  1.333157  <olorenchemengine.visualizations.visualization.Vis...
148   4.210466  0.880119  <olorenchemengine.visualizations.visualization.Vis...
149   7.369242  0.833378  <olorenchemengine.visualizations.visualization.Vis...
150   7.733750  0.714240  <olorenchemengine.visualizations.visualization.Vis...
151   5.793466  0.860124  <olorenchemengine.visualizations.visualization.Vis...

152 rows × 3 columns

../_images/2E_All_Together_7_2.png
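One way to check an 80% interval across a whole test set is to measure its empirical coverage, the fraction of molecules whose true value falls inside the predicted range. The sketch below assumes the ci column is a half-width, so the range is predicted ± ci; that reading, the helper name empirical_coverage, and the toy numbers are assumptions rather than anything taken from the library.

```python
def empirical_coverage(true_values, predicted, ci):
    """Fraction of points with |true - predicted| <= ci.

    Assumes ci is the half-width of the interval (predicted +/- ci).
    """
    inside = [abs(t - p) <= c for t, p, c in zip(true_values, predicted, ci)]
    return sum(inside) / len(inside)

# Toy values for illustration only, not taken from the BACE results.
true_values = [6.1, 7.0, 5.0, 5.5, 6.4]
predicted = [6.0, 7.1, 6.2, 5.5, 6.5]
ci = [1.0, 0.9, 1.0, 1.2, 0.5]

print(empirical_coverage(true_values, predicted, ci))  # 0.8
```

For a well-calibrated 80% interval, the coverage measured on a large enough test set should be close to 0.8.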

Ground truth output

list(bace_dataset.test_dataset[1])[150]
7.1487417

Predicted output and error margin

results_df["vis"][150].render_ipynb()
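As a quick arithmetic check for molecule 150, the numbers printed in the results table and the ground truth value can be compared directly. The sketch assumes the ci value is a half-width around the prediction; the variable names are illustrative.

```python
# Values copied from the outputs above for test molecule 150.
predicted = 7.733750
ci = 0.714240          # assumed to be the half-width of the 80% interval
true_value = 7.1487417

lower = predicted - ci
upper = predicted + ci
abs_error = abs(predicted - true_value)
inside = lower <= true_value <= upper

print(f"interval: [{lower:.4f}, {upper:.4f}]")
print(f"absolute error: {abs_error:.4f}")
print(f"true value inside interval: {inside}")
```

Under that assumption the absolute error (about 0.585) is smaller than the margin (about 0.714), so the true value lies inside the predicted range.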