From 2116ddf8fd9f6b4c0da2e53d859dc598268d2b53 Mon Sep 17 00:00:00 2001 From: egillax Date: Mon, 19 Jun 2023 17:10:47 +0200 Subject: [PATCH 1/3] wrap comparisons with python objects in py_bool --- R/SklearnToJson.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/SklearnToJson.R b/R/SklearnToJson.R index 95607dd3c..5b6b54063 100644 --- a/R/SklearnToJson.R +++ b/R/SklearnToJson.R @@ -54,17 +54,17 @@ sklearnFromJson <- function(path) { with(py$open(path, "r"), as=file, { model <- json$load(fp=file) }) - if (model["meta"] == "decision-tree") { + if (reticulate::py_bool(model["meta"] == "decision-tree")) { model <- deSerializeDecisionTree(model) - } else if (model["meta"] == "rf") { + } else if (reticulate::py_bool(model["meta"] == "rf")) { model <- deSerializeRandomForest(model) - } else if (model["meta"] == "adaboost") { + } else if (reticulate::py_bool(model["meta"] == "adaboost")) { model <- deSerializeAdaboost(model) - } else if (model["meta"] == "naive-bayes") { + } else if (reticulate::py_bool(model["meta"] == "naive-bayes")) { model <- deSerializeNaiveBayes(model) - } else if (model["meta"] == "mlp") { + } else if (reticulate::py_bool(model["meta"] == "mlp")) { model <- deSerializeMlp(model) - } else if (model["meta"] == "svm") { + } else if (reticulate::py_bool(model["meta"] == "svm")) { model <- deSerializeSVM(model) } else { stop("Unsupported model") @@ -181,7 +181,7 @@ serializeRandomForest <- function(model) { "params" = model$get_params(), "n_classes_" = model$n_classes_) - if (model$`__dict__`["oob_score_"] != reticulate::py_none()) { + if (reticulate::py_bool(model$`__dict__`["oob_score_"] != reticulate::py_none())) { serialized_model["oob_score_"] <- model$oob_score_ serialized_model["oob_decision_function_"] <- model$oob_decision_function_$tolist() } @@ -215,7 +215,7 @@ deSerializeRandomForest <- function(model_dict) { model$min_impurity_split <- model_dict["min_impurity_split"] model$n_classes_ <- model_dict["n_classes_"] - if (model_dict$oob_score_ != reticulate::py_none()){ + if (reticulate::py_bool(model_dict$oob_score_ != reticulate::py_none())){ model$oob_score_ <- model_dict["oob_score_"] model$oob_decision_function_ <- model_dict["oob_decision_function_"] } From c160182c84a0d61368afc1e5e2febbdda9a6c85a Mon Sep 17 00:00:00 2001 From: egillax Date: Tue, 20 Jun 2023 10:38:19 +0200 Subject: [PATCH 2/3] fix svm from json --- R/SklearnToJson.R | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/R/SklearnToJson.R b/R/SklearnToJson.R index 5b6b54063..9c1dba1ae 100644 --- a/R/SklearnToJson.R +++ b/R/SklearnToJson.R @@ -387,23 +387,23 @@ deSerializeSVM <- function(model_dict) { model$`_probB` <- np$array(model_dict["probB_"])$astype(np$float64) model$`_intercept_` <- np$array(model_dict["_intercept_"])$astype(np$float64) - if ((model_dict$support_vectors_["meta"] != reticulate::py_none()) & - (model_dict$support_vectors_["meta"] == "csr")) { + if (reticulate::py_bool((model_dict$support_vectors_["meta"] != reticulate::py_none())) & + (reticulate::py_bool(model_dict$support_vectors_["meta"] == "csr"))) { model$support_vectors_ <- deSerializeCsrMatrix(model_dict$support_vectors_) model$`_sparse` <- TRUE } else { model$support_vectors_ <- np$array(model_dict$support_vectors_)$astype(np$float64) model$`_sparse` <- FALSE } - if ((model_dict$dual_coef_["meta"] != reticulate::py_none()) & - (model_dict$dual_coef_["meta"] == "csr")) { + if (reticulate::py_bool((model_dict$dual_coef_["meta"] != reticulate::py_none())) & + (reticulate::py_bool(model_dict$dual_coef_["meta"] == "csr"))) { model$dual_coef_ <- deSerializeCsrMatrix(model_dict$dual_coef_) } else { model$dual_coef_ <- np$array(model_dict$dual_coef_)$astype(np$float64) } - if ((model_dict$`_dual_coef_`["meta"] != reticulate::py_none()) & - (model_dict$`_dual_coef_`["meta"] == "csr")) { + if (reticulate::py_bool((model_dict$`_dual_coef_`["meta"] != reticulate::py_none())) & + (reticulate::py_bool(model_dict$`_dual_coef_`["meta"] == "csr"))) { model$`_dual_coef_` <- deSerializeCsrMatrix(model_dict$`dual_coef_`) } else { model$`_dual_coef_` <- np$array(model_dict$`_dual_coef_`)$astype(np$float64) From ea01f2e4fa1d5df33936dc5cdf331edf50c1a158 Mon Sep 17 00:00:00 2001 From: egillax Date: Tue, 20 Jun 2023 10:50:21 +0200 Subject: [PATCH 3/3] remove sklearn-json from python env since its not used --- R/HelperFunctions.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/HelperFunctions.R b/R/HelperFunctions.R index 63c76c990..e3b939ff6 100644 --- a/R/HelperFunctions.R +++ b/R/HelperFunctions.R @@ -115,7 +115,7 @@ configurePython <- function(envname='PLP', envtype=NULL){ ParallelLogger::logInfo(paste0('Creating virtual conda environment called ', envname)) location <- reticulate::conda_create(envname=envname, packages = "python", conda = "auto") } - packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib', 'sklearn-json') + packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib') ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname)) reticulate::conda_install(envname=envname, packages = packages, forge = TRUE, pip = FALSE, pip_ignore_installed = TRUE, conda = "auto") @@ -128,7 +128,7 @@ configurePython <- function(envname='PLP', envtype=NULL){ ParallelLogger::logInfo(paste0('Creating virtual python environment called ', envname)) location <- reticulate::virtualenv_create(envname=envname) } - packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus','sklearn-json') + packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus') ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname)) reticulate::virtualenv_install(envname=envname, packages = packages, ignore_installed = TRUE)