Hyperparameter Optimization
We’ll be using oce’s hyperparameter optimization engine, powered by hyperopt, to create a stronger model of Caco-2 permeability.
# We'll first be downloading the data as described in 1A.
import olorenchemengine as oce
import requests
from sklearn import metrics
r = requests.get("https://ndownloader.figstatic.com/files/4917022", timeout=60)
# Fail loudly on an HTTP error instead of writing an error page to disk as .xlsx.
r.raise_for_status()
# The context manager guarantees the file is flushed and closed before pandas reads it.
with open("caco2_data.xlsx", "wb") as f:
    f.write(r.content)
import pandas as pd
import numpy as np
df = pd.read_excel("caco2_data.xlsx")
# Map the spreadsheet's "Tr"/"Te" labels onto the split names used below.
df["split"] = df["Dataset"].replace({"Tr": "train", "Te": "test"})
df = df[["smi", "split", "logPapp"]].dropna()
import random
# Carve a validation set out of the original training rows: a fraction p stays
# "train" and the remainder becomes "valid".
p = 0.8
n_train_pool = len(df[df["split"] == "train"])
train_size = int(n_train_pool * p)
val_size = n_train_pool - train_size
l = ["valid"] * val_size + ["train"] * train_size
# A dedicated, seeded RNG makes the train/valid split (and therefore the
# hyperparameter search results) reproducible without touching global random state.
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(l)
df.loc[df["split"] == "train", "split"] = l
def isfloat(num):
    """Return True if ``num`` can be converted with ``float()``, otherwise False.

    Used to filter out spreadsheet cells that are not numeric. Both
    ``ValueError`` (unparseable strings such as "abc" or "") and ``TypeError``
    (objects ``float()`` does not accept at all, e.g. ``None``) count as
    "not a float" rather than propagating.
    """
    try:
        float(num)
    except (TypeError, ValueError):
        return False
    return True
# Keep only the rows whose logPapp value parses as a number, then store the
# column as float.
numeric_rows = df["logPapp"].apply(isfloat)
df = df[numeric_rows]
df["logPapp"] = df["logPapp"].astype("float")

# Hand the cleaned table to oce as CSV text; the SMILES column is the structure
# and logPapp is the regression target.
csv_text = df.to_csv()
dataset = oce.BaseDataset(
    data=csv_text,
    structure_col="smi",
    property_col="logPapp",
)
# Every model run through the manager is scored by RMSE and recorded in the
# given .oce file.
manager = oce.ModelManager(
    dataset,
    metrics="Root Mean Squared Error",
    file_path="1C_manager.oce",
)
# Here we will be using a short-list of good model architectures and hyperparameter
# values, which were found using the Therapeutic Data Commons ADMET dataset collection.
manager.run(oce.TOP_MODELS_ADMET())
# NOTE(review): scikit-learn >= 1.3 rejects max_features="auto". If
# oce.RandomForestModel forwards this value to sklearn unchanged, replace
# "auto" with 1.0 (its former meaning for regressors). Confirm against the
# installed oce/sklearn versions.
def _rf_candidate(i):
    """Return a RandomForestModel whose hyperparameters are hyperopt choices.

    The descriptor, n_estimators and max_features are each an ``oce.OptChoice``
    labelled with suffix ``i`` (e.g. "descriptor1", "n_estimators1",
    "max_features1"), so every ensemble member gets its own searchable parameters.
    """
    return oce.RandomForestModel(
        oce.OptChoice(f"descriptor{i}", [
            oce.GobbiPharma2D(),
            oce.Mol2Vec(),
            oce.DescriptastorusDescriptor("rdkit2dnormalized"),
            oce.DescriptastorusDescriptor("morgan3counts"),
            oce.OlorenCheckpoint("default"),
        ]),
        n_estimators=oce.OptChoice(f"n_estimators{i}", [10, 500, 1000, 2000]),
        max_features=oce.OptChoice(f"max_features{i}", ["log2", "auto"]),
    )


# Boosted ensemble of three independently-tuned random forests
# (search labels suffixed 1, 2, 3).
model = oce.BaseBoosting([_rf_candidate(i) for i in range(1, 4)])
# Run 100 hyperopt evaluations; each trial is scored and recorded by the manager.
best = oce.optimize(model, manager, max_evals=100)
# Show every model the manager has evaluated, best (lowest RMSE) first.
manager.get_model_database().sort_values(by="Root Mean Squared Error", ascending=True)
 | Model Name | Model Parameters | Fitting Time | Root Mean Squared Error |
---|---|---|---|---|
94 | BaseBoosting 8K36OOLO | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 91.027555 | 0.335641 |
95 | BaseBoosting k8-gHshR | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 677.343201 | 0.339610 |
10 | BaseBoosting cZHDmmMV | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 262.673683 | 0.341740 |
93 | BaseBoosting K4L62AtM | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 59.878670 | 0.343320 |
97 | BaseBoosting 6v_E8N2p | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 66.435989 | 0.349018 |
... | ... | ... | ... | ... |
8 | KNN 20fz7vhA | {'BC_class_name': 'KNN', 'args': [{'BC_class_n... | 1.294361 | 0.525758 |
21 | BaseBoosting m_RXV5dC | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 703.818117 | 0.533155 |
92 | ChemPropModel TuX4IIuT | {'BC_class_name': 'ChemPropModel', 'args': [],... | 13.275340 | 0.538600 |
91 | ChemPropModel TuX4IIuT | {'BC_class_name': 'ChemPropModel', 'args': [],... | 17.817980 | 5.149165 |
89 | ChemPropModel 7l6982xF | {'BC_class_name': 'ChemPropModel', 'args': [],... | 4.267098 | 5.150783 |
98 rows × 4 columns
# Score manager.best_model on the "test" split (the rows labelled "Te" in the
# spreadsheet). NOTE(review): presumably best_model is the lowest-RMSE model the
# manager has recorded, and test_dataset unpacks into (structures, targets);
# confirm against oce.
manager.best_model.test(*dataset.test_dataset)
255it [00:00, 439.32it/s]
100%|██████████| 6/6 [00:00<00:00, 72.42it/s]
{'r2': 0.7555024600133741,
'Explained Variance': 0.7555024603158613,
'Max Error': 1.3791490625619556,
'Mean Absolute Error': 0.28752496358544677,
'Mean Squared Error': 0.14721778747961853,
'Root Mean Squared Error': 0.383689702076585}