Putting it all together#
We now demonstrate all of Oloren ChemEngine’s uncertainty features by training a production-level model and error model on the BACE dataset from DeepChem datasets.
import olorenchemengine as oce
import pandas as pd
import numpy as np
Creating the dataset#
We will train the model on 90% of the data and hold out the remaining 10% for testing; no validation set is used (the middle entry of split_proportions is 0).
bace_dataset = oce.DatasetFromCSV("bace.csv", structure_col = "mol", property_col = "pIC50")
splitter = oce.RandomSplit(split_proportions=[0.9,0,0.1])
bace_dataset = splitter.transform(bace_dataset)
oce.save(bace_dataset, "bace_dataset.oce")
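For intuition, a random 90/0/10 split amounts to shuffling the row indices and slicing them by proportion. The sketch below does this with NumPy and pandas only. `random_split` and the toy DataFrame are hypothetical, and this is not a description of how `oce.RandomSplit` is implemented.

```python
import numpy as np
import pandas as pd


def random_split(df, proportions=(0.9, 0.0, 0.1), seed=0):
    """Shuffle rows and slice them into train/valid/test by proportion.

    Hypothetical helper for illustration; not part of olorenchemengine.
    """
    if abs(sum(proportions) - 1.0) > 1e-9:
        raise ValueError("proportions must sum to 1")
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(df))
    n_train = int(round(proportions[0] * len(df)))
    n_valid = int(round(proportions[1] * len(df)))
    train = df.iloc[order[:n_train]]
    valid = df.iloc[order[n_train:n_train + n_valid]]
    test = df.iloc[order[n_train + n_valid:]]  # everything left over
    return train, valid, test


# Toy frame standing in for the BACE table (all values are placeholders).
toy = pd.DataFrame({"mol": [f"mol_{i}" for i in range(20)],
                    "pIC50": np.linspace(4.0, 8.0, 20)})
train, valid, test = random_split(toy)
print(len(train), len(valid), len(test))
```

Fixing the seed keeps the split reproducible, which serves the same purpose as saving the split dataset to disk in the cell above.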
Training the production-level model#
Production-level models can be produced by running fit_cv, which, in addition to fitting the whole model, fits the error model via cross validation.
bace_dataset = oce.load("bace_dataset.oce")
model = oce.ZWK_XGBoostModel(oce.OlorenCheckpoint("default"))
model.fit_cv(bace_dataset.train_dataset[0], bace_dataset.train_dataset[1], error_model = oce.SDCwRMSD1())
oce.save(model, "bace_model.oce")
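The general idea of fitting an error model via cross validation can be sketched as follows: each training point is predicted by a model that never saw it, and the resulting out-of-fold residuals become the data on which an error model is calibrated. The sketch uses a least-squares linear model and synthetic data purely for illustration. `out_of_fold_residuals` is a hypothetical helper, and nothing here describes what `fit_cv` or `SDCwRMSD1` do internally.

```python
import numpy as np


def out_of_fold_residuals(X, y, n_splits=5, seed=0):
    """Return |y - y_hat|, where each y_hat comes from a model fit without that row."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), n_splits)
    residuals = np.empty(len(y))
    for held_out in folds:
        train = np.setdiff1d(np.arange(len(y)), held_out)
        # Least-squares linear model with an intercept (stand-in for any regressor).
        A_train = np.column_stack([X[train], np.ones(len(train))])
        coef, *_ = np.linalg.lstsq(A_train, y[train], rcond=None)
        A_test = np.column_stack([X[held_out], np.ones(len(held_out))])
        residuals[held_out] = np.abs(y[held_out] - A_test @ coef)
    return residuals


# Synthetic regression data (not BACE): linear signal plus Gaussian noise.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.3, size=200)

residuals = out_of_fold_residuals(X, y)
# One simple calibration: take the 80th percentile of the residuals
# as the half-width of an 80% interval.
half_width_80 = np.quantile(residuals, 0.8)
print(residuals.shape, half_width_80)
```

A single global quantile is the simplest possible calibration; an error model such as the one passed to fit_cv can instead make the interval width depend on each molecule.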
Visualizing results#
We visualize the probable output range (80% confidence interval) and the true output for each test molecule. For the ones plotted, each predicted value is within the error margin of the true value.
model = oce.load("bace_model.oce")
bace_dataset = oce.load("bace_dataset.oce")
results_df = model.predict(bace_dataset.test_dataset[0], return_ci=True, return_vis=True)
results_df
| | predicted | ci | vis |
|---|---|---|---|
| 0 | 6.004463 | 1.052809 | <olorenchemengine.visualizations.visualization.Vis... |
| 1 | 7.084375 | 0.869544 | <olorenchemengine.visualizations.visualization.Vis... |
| 2 | 6.217337 | 1.473213 | <olorenchemengine.visualizations.visualization.Vis... |
| 3 | 5.523591 | 1.170892 | <olorenchemengine.visualizations.visualization.Vis... |
| 4 | 6.502657 | 0.511894 | <olorenchemengine.visualizations.visualization.Vis... |
| ... | ... | ... | ... |
| 147 | 5.657271 | 1.333157 | <olorenchemengine.visualizations.visualization.Vis... |
| 148 | 4.210466 | 0.880119 | <olorenchemengine.visualizations.visualization.Vis... |
| 149 | 7.369242 | 0.833378 | <olorenchemengine.visualizations.visualization.Vis... |
| 150 | 7.733750 | 0.714240 | <olorenchemengine.visualizations.visualization.Vis... |
| 151 | 5.793466 | 0.860124 | <olorenchemengine.visualizations.visualization.Vis... |
152 rows × 3 columns

Ground truth output#
list(bace_dataset.test_dataset[1])[150]
7.1487417
Predicted output and error margin#
results_df["vis"][150].render_ipynb()
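As a numeric counterpart to the rendered plot, the check below tests whether the ground truth for molecule 150 falls inside the predicted range, using the values shown above. It assumes the `ci` column is the half-width of a symmetric interval around `predicted`, which this page does not state explicitly. `within_interval` and `empirical_coverage` are hypothetical helpers, not part of olorenchemengine.

```python
import numpy as np


def within_interval(y_true, y_pred, half_width):
    """True where y_true lies in [y_pred - half_width, y_pred + half_width]."""
    y_true, y_pred, half_width = map(np.asarray, (y_true, y_pred, half_width))
    return np.abs(y_true - y_pred) <= half_width


def empirical_coverage(y_true, y_pred, half_width):
    """Fraction of points whose true value falls inside its interval."""
    return float(np.mean(within_interval(y_true, y_pred, half_width)))


# Values for test molecule 150, copied from the outputs above.
predicted, ci, truth = 7.733750, 0.714240, 7.1487417
print(bool(within_interval(truth, predicted, ci)))  # True: |7.1487417 - 7.733750| is about 0.585 <= 0.714
```

Under the same assumption, calling `empirical_coverage(list(bace_dataset.test_dataset[1]), results_df["predicted"], results_df["ci"])` would give the fraction of test molecules covered, which for a well-calibrated 80% interval should be close to 0.8.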