Hyperparameter Optimization

We’ll be using oce’s hyperparameter optimization engine, powered by hyperopt, to create a stronger model of Caco-2 permeability.

# We'll first be downloading the data as described in 1A.

import olorenchemengine as oce

import requests
from sklearn import metrics
# Download the Caco-2 spreadsheet. The timeout keeps a stalled connection
# from hanging the notebook indefinitely.
r = requests.get("https://ndownloader.figstatic.com/files/4917022", timeout=60)
# Fail loudly on an HTTP error instead of saving an error page as an .xlsx file.
r.raise_for_status()
# The context manager guarantees the file is flushed and closed before
# pandas reads it back below.
with open("caco2_data.xlsx", "wb") as f:
    f.write(r.content)

import pandas as pd
import numpy as np

# Read the downloaded spreadsheet into a DataFrame.
df = pd.read_excel("caco2_data.xlsx")

# Spell out the abbreviated "Dataset" labels in a new "split" column, then
# keep only the structure, split and target columns and drop incomplete rows.
split_names = {"Tr": "train", "Te": "test"}
df = df.assign(split=df["Dataset"].replace(split_names))
df = df.loc[:, ["smi", "split", "logPapp"]].dropna()

import random

# Fraction of the original training rows that stay in "train"; the remainder
# is relabelled "valid".
p = 0.8

# Compute the training-row mask once and reuse it for counting and assignment.
is_train = df["split"] == "train"
n_train_rows = int(is_train.sum())
train_size = int(n_train_rows * p)
val_size = n_train_rows - train_size

# One label per training row, shuffled so the validation rows are picked at
# random. No seed is set, so the split differs from run to run.
split_labels = ["valid"] * val_size + ["train"] * train_size
random.shuffle(split_labels)
df.loc[is_train, "split"] = split_labels

def isfloat(num: object) -> bool:
    """Return True if ``float(num)`` succeeds, otherwise False.

    Used as a row filter before casting a column to float, so it must never
    raise: values of an unsupported type (``None``, lists, ...) and integers
    too large for a float are reported as "not a float" rather than
    propagating ``TypeError`` / ``OverflowError``.

    Note that strings such as ``"nan"`` and ``"inf"`` are accepted, because
    ``float()`` parses them.
    """
    try:
        float(num)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
# Keep only the rows whose logPapp entry can be parsed as a number ...
numeric_rows = df["logPapp"].apply(isfloat)
df = df[numeric_rows]

# ... and then store the column with a float dtype.
df["logPapp"] = df["logPapp"].astype(float)

# Wrap the cleaned table (serialized as CSV text) in an oce dataset, naming
# the structure column and the target property column.
dataset = oce.BaseDataset(
    data=df.to_csv(),
    structure_col="smi",
    property_col="logPapp",
)

# The manager is configured with RMSE as its metric and "1C_manager.oce" as
# its file path.
manager = oce.ModelManager(
    dataset,
    metrics="Root Mean Squared Error",
    file_path="1C_manager.oce",
)
# Start from a shortlist of good model architectures and hyperparameter
# values, which were found using the Therapeutic Data Commons ADMET dataset
# collection, and run all of them through the manager.

admet_shortlist = oce.TOP_MODELS_ADMET()
manager.run(admet_shortlist)
def _rf_search_space(index):
    """Build one random-forest learner whose settings are ``OptChoice`` options.

    ``index`` is appended to every ``OptChoice`` label ("descriptor1",
    "n_estimators1", ...) so that each learner in the ensemble gets its own
    independently named choices. Fresh descriptor objects are created on
    every call, so learners never share descriptor instances.

    NOTE(review): "auto" is not an accepted ``max_features`` value in recent
    scikit-learn releases; if oce forwards this value to scikit-learn
    unchanged, confirm the pinned version still accepts it.
    """
    return oce.RandomForestModel(
        oce.OptChoice(
            f"descriptor{index}",
            [
                oce.GobbiPharma2D(),
                oce.Mol2Vec(),
                oce.DescriptastorusDescriptor("rdkit2dnormalized"),
                oce.DescriptastorusDescriptor("morgan3counts"),
                oce.OlorenCheckpoint("default"),
            ],
        ),
        n_estimators=oce.OptChoice(f"n_estimators{index}", [10, 500, 1000, 2000]),
        max_features=oce.OptChoice(f"max_features{index}", ["log2", "auto"]),
    )


# Boosting ensemble of three random forests; each forest's descriptor,
# tree count and max_features are left open as choices for the optimizer.
model = oce.BaseBoosting([_rf_search_space(i) for i in (1, 2, 3)])
# Optimize the choices declared in `model`, using `manager`, for at most
# 100 evaluations.
best = oce.optimize(model, manager, max_evals=100)
# Display the manager's model database ordered by RMSE, lowest first.
# NOTE(review): the line below is repeated verbatim — presumably the start of
# a separate notebook cell; confirm before removing either copy.
manager.get_model_database().sort_values(by="Root Mean Squared Error", ascending=True)
manager.get_model_database().sort_values(by="Root Mean Squared Error", ascending=True)
Model Name Model Parameters Fitting Time Root Mean Squared Error
94 BaseBoosting 8K36OOLO {'BC_class_name': 'BaseBoosting', 'args': [[{'... 91.027555 0.335641
95 BaseBoosting k8-gHshR {'BC_class_name': 'BaseBoosting', 'args': [[{'... 677.343201 0.339610
10 BaseBoosting cZHDmmMV {'BC_class_name': 'BaseBoosting', 'args': [[{'... 262.673683 0.341740
93 BaseBoosting K4L62AtM {'BC_class_name': 'BaseBoosting', 'args': [[{'... 59.878670 0.343320
97 BaseBoosting 6v_E8N2p {'BC_class_name': 'BaseBoosting', 'args': [[{'... 66.435989 0.349018
... ... ... ... ...
8 KNN 20fz7vhA {'BC_class_name': 'KNN', 'args': [{'BC_class_n... 1.294361 0.525758
21 BaseBoosting m_RXV5dC {'BC_class_name': 'BaseBoosting', 'args': [[{'... 703.818117 0.533155
92 ChemPropModel TuX4IIuT {'BC_class_name': 'ChemPropModel', 'args': [],... 13.275340 0.538600
91 ChemPropModel TuX4IIuT {'BC_class_name': 'ChemPropModel', 'args': [],... 17.817980 5.149165
89 ChemPropModel 7l6982xF {'BC_class_name': 'ChemPropModel', 'args': [],... 4.267098 5.150783

98 rows × 4 columns

# Evaluate the manager's best model on the dataset's test split and display
# the resulting metrics (presumably "best" means lowest RMSE — confirm).
manager.best_model.test(*dataset.test_dataset)
255it [00:00, 439.32it/s]
100%|██████████| 6/6 [00:00<00:00, 72.42it/s]
{'r2': 0.7555024600133741,
 'Explained Variance': 0.7555024603158613,
 'Max Error': 1.3791490625619556,
 'Mean Absolute Error': 0.28752496358544677,
 'Mean Squared Error': 0.14721778747961853,
 'Root Mean Squared Error': 0.383689702076585}