Model Searching
In this example, we use the same Caco-2 permeability data as in 1A, taken from “ADME Properties Evaluation in Drug Discovery: Prediction of Caco-2 Cell Permeability Using a Combination of NSGA-II and Boosting” (https://pubs.acs.org/doi/10.1021/acs.jcim.5b00642). Instead of defining a model ourselves, we search over a list of top model architectures to find the one that performs best on this dataset.
# We'll first load in the data, same as in 1A
import requests
import os
if not os.path.exists("caco2_data.xlsx"):
    r = requests.get("https://ndownloader.figstatic.com/files/4917022")
    with open("caco2_data.xlsx", "wb") as f:
        f.write(r.content)
# Read the data into a dataframe
# Map the dataset labels to train/test splits
# Subset the data to the molecule, split, and property columns
# Keep only rows whose property value can be converted to a float
import pandas as pd
import numpy as np
df = pd.read_excel("caco2_data.xlsx")
df["split"] = df["Dataset"].replace({"Tr": "train", "Te": "test"})
df = df[["smi", "split", "logPapp"]].dropna()
def isfloat(num):
    try:
        float(num)
        return True
    except ValueError:
        return False
df = df[df["logPapp"].apply(isfloat)]
df["logPapp"] = df["logPapp"].astype('float')
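The `isfloat` filter above can also be written with pandas' built-in coercion. The following is a minimal sketch on a made-up toy dataframe (the rows are illustrative and not taken from the Caco-2 file): `pd.to_numeric` with `errors="coerce"` turns unparseable entries into NaN, and `dropna` then removes them.

```python
import pandas as pd

# Toy stand-in for the real data; these rows are illustrative only.
toy = pd.DataFrame({
    "smi": ["CCO", "c1ccccc1", "CC(=O)O", "CCN"],
    "logPapp": ["-4.5", "n/a", -5.2, None],
})

# Coerce anything that cannot be parsed as a number to NaN ...
toy["logPapp"] = pd.to_numeric(toy["logPapp"], errors="coerce")
# ... then drop those rows, leaving a float64 column.
clean = toy.dropna(subset=["logPapp"]).reset_index(drop=True)

print(clean)
```

One difference worth knowing: a literal string such as "nan" passes `isfloat` and survives as a NaN value, whereas the coercion approach drops it along with the other missing entries.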
# Now we use the dataframe to create a BaseDataset object.
# We will generate it from the pd.DataFrame object.
# We have defined our own split column, which will be used by the dataset object.
import olorenchemengine as oce
dataset = oce.BaseDataset(data = df.to_csv(), structure_col="smi", property_col="logPapp")
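Because the split column is defined by hand, it can be worth confirming that it contains only the expected labels before building the dataset. The sketch below uses a toy `Dataset` column with invented rows and assumes the intended labels are `train` and `test`, as produced by the `replace` call earlier.

```python
import pandas as pd

# Toy stand-in for the "Dataset" column; the rows are illustrative only.
toy = pd.DataFrame({"Dataset": ["Tr", "Tr", "Te", "Tr", "Te"]})
toy["split"] = toy["Dataset"].replace({"Tr": "train", "Te": "test"})

# replace() leaves unmapped labels unchanged, so look for unexpected values.
unexpected = set(toy["split"]) - {"train", "test"}
counts = toy["split"].value_counts().to_dict()

print(counts)
print(unexpected)
```

An empty `unexpected` set means every row was mapped to one of the two assumed labels.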
# We'll now get our list of top model architectures.
# Each of these models has certain situations where it outperforms the others,
# so we test all of them to see which model is best for this specific situation.
models = oce.TOP_MODELS_ADMET()
# We'll also create a ModelManager object to keep track of our experiments
mm = oce.ModelManager(dataset, metrics = ["Root Mean Squared Error"], file_path="mm_1B_results.oce")
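For reference, root mean squared error is the square root of the mean squared difference between predicted and measured values, so it is expressed in the same units as `logPapp` and lower is better. A minimal NumPy sketch with made-up numbers follows; it is not the library's own implementation, and the `rmse` helper is a name chosen here for illustration.

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error between two equal-length sequences."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

# Illustrative values only, not from the Caco-2 dataset.
measured = [-5.0, -4.5, -6.0]
predicted = [-5.0, -4.0, -6.5]
print(rmse(measured, predicted))
```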
# This will now use our model manager to test our top models
# and will take around 1-4 hours to run in total for this dataset,
# though it will take more or less time depending on the machine.
mm.run(models)
# Get the model database and sort it by RMSE (lower is better).
# In the run shown here, the best model (RMSE of about 0.31) outperforms the published models.
mm.get_model_database().sort_values(by="Root Mean Squared Error")
Index | Model Name | Model Parameters | Fitting Time | Root Mean Squared Error |
---|---|---|---|---|
0 | ZWK_XGBoostModel 8t4Lbm1C | {'BC_class_name': 'ZWK_XGBoostModel', 'args': ... | 583.755889 | 0.306506 |
4 | RFStacker ZObB1n2V | {'BC_class_name': 'RFStacker', 'args': [[{'BC_... | 1370.884091 | 0.326439 |
2 | BaseBoosting sSOI0-2O | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 80.049373 | 0.332818 |
6 | RFStacker Dg3XrFow | {'BC_class_name': 'RFStacker', 'args': [[{'BC_... | 1077.793208 | 0.344461 |
9 | ZWK_XGBoostModel u3zq9AAV | {'BC_class_name': 'ZWK_XGBoostModel', 'args': ... | 583.068085 | 0.350327 |
1 | BaseBoosting GDDXgNxr | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 98.140028 | 0.354108 |
13 | RFStacker J-KhwR5S | {'BC_class_name': 'RFStacker', 'args': [[{'BC_... | 2523.654553 | 0.378332 |
10 | BaseBoosting 1zpI0dIb | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 31.493627 | 0.390516 |
11 | BaseBoosting ADkCCrwJ | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 4.215762 | 0.402029 |
15 | RFStacker kHyqmLCI | {'BC_class_name': 'RFStacker', 'args': [[{'BC_... | 4003.144378 | 0.450928 |
7 | ResampleAdaboost rw2YnX2a | {'BC_class_name': 'ResampleAdaboost', 'args': ... | 150.905021 | 0.452858 |
14 | BaseBoosting Q-ko4Uuj | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 70.730023 | 0.460666 |
16 | ResampleAdaboost rw2YnX2a | {'BC_class_name': 'ResampleAdaboost', 'args': ... | 432.079354 | 0.474882 |
12 | SPGNN TWy3l_kb | {'BC_class_name': 'SPGNN', 'args': [], 'kwargs... | 9.815435 | 0.487068 |
5 | BaseBoosting Px-cadEt | {'BC_class_name': 'BaseBoosting', 'args': [[{'... | 54.437475 | 0.492785 |
8 | KNN 20fz7vhA | {'BC_class_name': 'KNN', 'args': [{'BC_class_n... | 1.745354 | 0.526132 |
17 | KNN 20fz7vhA | {'BC_class_name': 'KNN', 'args': [{'BC_class_n... | 0.016418 | 0.926369 |
3 | SPGNN 8PvbRqPX | {'BC_class_name': 'SPGNN', 'args': [], 'kwargs... | 10.181282 | 2.647019 |
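Once the results are sorted, the best entry can be pulled out programmatically. The sketch below uses a small hand-built dataframe with the same column names as the table above and three of its rows (the parameters column is omitted); in practice the dataframe would come from `mm.get_model_database()`.

```python
import pandas as pd

# Three rows copied from the results table above.
results = pd.DataFrame(
    {
        "Model Name": [
            "ZWK_XGBoostModel 8t4Lbm1C",
            "BaseBoosting sSOI0-2O",
            "SPGNN 8PvbRqPX",
        ],
        "Fitting Time": [583.755889, 80.049373, 10.181282],
        "Root Mean Squared Error": [0.306506, 0.332818, 2.647019],
    },
    index=[0, 2, 3],
)

metric = "Root Mean Squared Error"
ranked = results.sort_values(by=metric)  # ascending: lowest RMSE first
best = ranked.iloc[0]

print(best["Model Name"], best[metric])
```

Ranking by RMSE alone ignores cost: in the table above, BaseBoosting sSOI0-2O comes within about 0.03 RMSE of the top model while fitting in roughly one seventh of the time, which may matter when the search has to be repeated.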