From ccdbf784765c1a4746fdda930182d4f5d697ea3f Mon Sep 17 00:00:00 2001 From: advieser Date: Sun, 16 Nov 2025 12:35:47 +0100 Subject: [PATCH 1/4] implemented predict_newdata_fast as thin wrapper around predict_newdata --- DESCRIPTION | 2 +- R/GraphLearner.R | 10 +++++++++- man/mlr_graphs.Rd | 2 +- man/mlr_graphs_bagging.Rd | 2 +- man/mlr_graphs_ovr.Rd | 2 +- man/mlr_graphs_robustify.Rd | 2 +- man/mlr_graphs_stacking.Rd | 2 +- man/mlr_graphs_targettrafo.Rd | 2 +- man/mlr_learners_graph.Rd | 7 ++++++- man/mlr_pipeops.Rd | 2 +- man/mlr_pipeops_adas.Rd | 2 +- man/mlr_pipeops_blsmote.Rd | 2 +- man/mlr_pipeops_boxcox.Rd | 2 +- man/mlr_pipeops_classifavg.Rd | 2 +- man/mlr_pipeops_encodelmer.Rd | 2 +- man/mlr_pipeops_encodepltree.Rd | 2 +- man/mlr_pipeops_filter.Rd | 2 +- man/mlr_pipeops_ica.Rd | 2 +- man/mlr_pipeops_imputelearner.Rd | 2 +- man/mlr_pipeops_kernelpca.Rd | 2 +- man/mlr_pipeops_learner.Rd | 2 +- man/mlr_pipeops_learner_cv.Rd | 2 +- man/mlr_pipeops_learner_pi_cvplus.Rd | 2 +- man/mlr_pipeops_nearmiss.Rd | 2 +- man/mlr_pipeops_nmf.Rd | 2 +- man/mlr_pipeops_ovrsplit.Rd | 2 +- man/mlr_pipeops_ovrunite.Rd | 2 +- man/mlr_pipeops_proxy.Rd | 2 +- man/mlr_pipeops_randomresponse.Rd | 2 +- man/mlr_pipeops_regravg.Rd | 2 +- man/mlr_pipeops_smote.Rd | 2 +- man/mlr_pipeops_smotenc.Rd | 2 +- man/mlr_pipeops_targetmutate.Rd | 2 +- man/mlr_pipeops_targettrafoscalerange.Rd | 2 +- man/mlr_pipeops_textvectorizer.Rd | 2 +- man/mlr_pipeops_threshold.Rd | 2 +- man/mlr_pipeops_tomek.Rd | 2 +- man/mlr_pipeops_tunethreshold.Rd | 2 +- man/mlr_pipeops_updatetarget.Rd | 2 +- man/mlr_pipeops_vtreat.Rd | 2 +- man/mlr_pipeops_yeojohnson.Rd | 2 +- man/po.Rd | 2 +- man/ppl.Rd | 2 +- man/preproc.Rd | 2 +- 44 files changed, 57 insertions(+), 44 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7a54a3604..4f636716d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -107,7 +107,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: true NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = FALSE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 VignetteBuilder: knitr, rmarkdown Collate: 'CnfAtom.R' diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 0fa72cad6..e3090b910 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -118,7 +118,12 @@ #' corresponding [`PipeOpBranch`] is searched, and its hyperparameter configuration is used to select the base learner. #' There may be multiple corresponding [`PipeOpBranch`]s, which are all considered. #' If `resolve_branching` is `FALSE`, [`PipeOpUnbranch`] is treated as any other `PipeOp` with multiple inputs; all possible branch paths are considered equally. -#' +#' * `predict_newdata_fast(newdata, task = NULL)`\cr +#' (`data.frame`, [`Task`][mlr3::Task] | `NULL`) -> [`Prediction`][mlr3::Prediction]\cr +#' Predicts outcomes for new data in `newdata` using the model fitted during `$train()`. +#' For the moment, this is merely a thin wrapper around [`Learner$predict_newdata()`][mlr3::Learner] to ensure compatibility, meaning that *no speedup* is currently achieved. +#' In the future, this method may be optimized to be faster than `$predict_newdata()`. +#' #' The following standard extractors as defined by the [`Learner`][mlr3::Learner] class are available. #' Note that these typically only extract information from the `$base_learner()`. #' This works well for simple [`Graph`]s that do not modify features too much, but may give unexpected results for `Graph`s that @@ -312,6 +317,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, }, plot = function(html = FALSE, horizontal = FALSE, ...) { private$.graph$plot(html = html, horizontal = horizontal, ...) + }, + predict_newdata_fast = function(newdata, task = NULL) { + self$predict_newdata(newdata, task) } ), active = list( diff --git a/man/mlr_graphs.Rd b/man/mlr_graphs.Rd index 8ed2e1b58..8a736fe9d 100644 --- a/man/mlr_graphs.Rd +++ b/man/mlr_graphs.Rd @@ -35,7 +35,7 @@ Returns a \code{data.table} with column \code{key} (\code{character}). } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) lrn = lrn("regr.rpart") task = mlr_tasks$get("boston_housing") diff --git a/man/mlr_graphs_bagging.Rd b/man/mlr_graphs_bagging.Rd index 45883f0b7..e88f7f205 100644 --- a/man/mlr_graphs_bagging.Rd +++ b/man/mlr_graphs_bagging.Rd @@ -54,7 +54,7 @@ This is done as follows: All input arguments are cloned and have no references in common with the returned \code{\link{Graph}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} \donttest{ library(mlr3) lrn_po = po("learner", lrn("regr.rpart")) diff --git a/man/mlr_graphs_ovr.Rd b/man/mlr_graphs_ovr.Rd index bf0f42924..ab4e2d303 100644 --- a/man/mlr_graphs_ovr.Rd +++ b/man/mlr_graphs_ovr.Rd @@ -23,7 +23,7 @@ perform "One vs. Rest" classification. All input arguments are cloned and have no references in common with the returned \code{\link{Graph}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("wine") diff --git a/man/mlr_graphs_robustify.Rd b/man/mlr_graphs_robustify.Rd index 4cfd40044..f9f79d7bd 100644 --- a/man/mlr_graphs_robustify.Rd +++ b/man/mlr_graphs_robustify.Rd @@ -86,7 +86,7 @@ factor variables, no encoding is performed. All input arguments are cloned and have no references in common with the returned \code{\link{Graph}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} \donttest{ library(mlr3) lrn = lrn("regr.rpart") diff --git a/man/mlr_graphs_stacking.Rd b/man/mlr_graphs_stacking.Rd index 11dbc6344..c864a52e9 100644 --- a/man/mlr_graphs_stacking.Rd +++ b/man/mlr_graphs_stacking.Rd @@ -43,7 +43,7 @@ features in order to predict the outcome. All input arguments are cloned and have no references in common with the returned \code{\link{Graph}}. } \examples{ -\dontshow{if (mlr3misc::require_namespaces("rpart", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (mlr3misc::require_namespaces("rpart", quietly = TRUE)) withAutoprint(\{ # examplesIf} library(mlr3) library(mlr3learners) diff --git a/man/mlr_graphs_targettrafo.Rd b/man/mlr_graphs_targettrafo.Rd index c858596f1..ecd47e6a5 100644 --- a/man/mlr_graphs_targettrafo.Rd +++ b/man/mlr_graphs_targettrafo.Rd @@ -40,7 +40,7 @@ parameters \code{trafo} and \code{inverter} of the \code{param_set} of the resul All input arguments are cloned and have no references in common with the returned \code{\link{Graph}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") tt = pipeline_targettrafo(PipeOpLearner$new(LearnerRegrRpart$new())) diff --git a/man/mlr_learners_graph.Rd b/man/mlr_learners_graph.Rd index bf684365b..8b3df844e 100644 --- a/man/mlr_learners_graph.Rd +++ b/man/mlr_learners_graph.Rd @@ -131,6 +131,11 @@ If \code{resolve_branching} is \code{TRUE}, and when a \code{\link{PipeOpUnbranc corresponding \code{\link{PipeOpBranch}} is searched, and its hyperparameter configuration is used to select the base learner. There may be multiple corresponding \code{\link{PipeOpBranch}}s, which are all considered. If \code{resolve_branching} is \code{FALSE}, \code{\link{PipeOpUnbranch}} is treated as any other \code{PipeOp} with multiple inputs; all possible branch paths are considered equally. +\item \code{predict_newdata_fast(newdata, task = NULL)}\cr +(\code{data.frame}, \code{\link[mlr3:Task]{Task}} | \code{NULL}) -> \code{\link[mlr3:Prediction]{Prediction}}\cr +Predicts outcomes for new data in \code{newdata} using the model fitted during \verb{$train()}. +For the moment, this is merely a thin wrapper around \code{\link[mlr3:Learner]{Learner$predict_newdata()}} to ensure compatibility, meaning that \emph{no speedup} is currently achieved. +In the future, this method may be optimized to be faster than \verb{$predict_newdata()}. } The following standard extractors as defined by the \code{\link[mlr3:Learner]{Learner}} class are available. @@ -175,7 +180,7 @@ recommended. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") graph = po("pca") \%>>\% lrn("classif.rpart") diff --git a/man/mlr_pipeops.Rd b/man/mlr_pipeops.Rd index c41449bbe..ffd87f9c1 100644 --- a/man/mlr_pipeops.Rd +++ b/man/mlr_pipeops.Rd @@ -67,7 +67,7 @@ values enclosed by square brackets ("\code{[}", "\verb{]}"), then the respective } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") mlr_pipeops$get("learner", lrn("classif.rpart")) diff --git a/man/mlr_pipeops_adas.Rd b/man/mlr_pipeops_adas.Rd index 33a094003..01e21c463 100644 --- a/man/mlr_pipeops_adas.Rd +++ b/man/mlr_pipeops_adas.Rd @@ -69,7 +69,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("smotefamily")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("smotefamily")) withAutoprint(\{ # examplesIf} library("mlr3") # Create example task diff --git a/man/mlr_pipeops_blsmote.Rd b/man/mlr_pipeops_blsmote.Rd index 08c6b0d8d..98a859917 100644 --- a/man/mlr_pipeops_blsmote.Rd +++ b/man/mlr_pipeops_blsmote.Rd @@ -77,7 +77,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("smotefamily")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("smotefamily")) withAutoprint(\{ # examplesIf} library("mlr3") # Create example task diff --git a/man/mlr_pipeops_boxcox.Rd b/man/mlr_pipeops_boxcox.Rd index 1b672f56c..2b94a75a5 100644 --- a/man/mlr_pipeops_boxcox.Rd +++ b/man/mlr_pipeops_boxcox.Rd @@ -72,7 +72,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("bestNormalize")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("bestNormalize")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/mlr_pipeops_classifavg.Rd b/man/mlr_pipeops_classifavg.Rd index beb232e55..0f4443fa5 100644 --- a/man/mlr_pipeops_classifavg.Rd +++ b/man/mlr_pipeops_classifavg.Rd @@ -75,7 +75,7 @@ Only methods inherited from \code{\link{PipeOpEnsemble}}/\code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} \donttest{ library("mlr3") diff --git a/man/mlr_pipeops_encodelmer.Rd b/man/mlr_pipeops_encodelmer.Rd index 33cb75759..a4c3a9984 100644 --- a/man/mlr_pipeops_encodelmer.Rd +++ b/man/mlr_pipeops_encodelmer.Rd @@ -94,7 +94,7 @@ Only methods inherited \code{\link{PipeOpTaskPreprocSimple}}/\code{\link{PipeOpT } \examples{ -\dontshow{if (mlr3misc::require_namespaces(c("nloptr", "lme4"), quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (mlr3misc::require_namespaces(c("nloptr", "lme4"), quietly = TRUE)) withAutoprint(\{ # examplesIf} library("mlr3") poe = po("encodelmer") diff --git a/man/mlr_pipeops_encodepltree.Rd b/man/mlr_pipeops_encodepltree.Rd index 2053eb2b5..c535a8ac0 100644 --- a/man/mlr_pipeops_encodepltree.Rd +++ b/man/mlr_pipeops_encodepltree.Rd @@ -73,7 +73,7 @@ Only methods inherited from \code{\link{PipeOpEncodePL}}/\code{\link{PipeOpTaskP } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) # For classification task diff --git a/man/mlr_pipeops_filter.Rd b/man/mlr_pipeops_filter.Rd index 308e66011..aeedae8dc 100644 --- a/man/mlr_pipeops_filter.Rd +++ b/man/mlr_pipeops_filter.Rd @@ -98,7 +98,7 @@ Methods inherited from \code{\link{PipeOpTaskPreprocSimple}}/\code{\link{PipeOpT } \examples{ -\dontshow{if (mlr3misc::require_namespaces(c("mlr3filters", "rpart"), quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (mlr3misc::require_namespaces(c("mlr3filters", "rpart"), quietly = TRUE)) withAutoprint(\{ # examplesIf} library("mlr3") library("mlr3filters") \dontshow{data.table::setDTthreads(1)} diff --git a/man/mlr_pipeops_ica.Rd b/man/mlr_pipeops_ica.Rd index 6de1101d2..bbc1132b4 100644 --- a/man/mlr_pipeops_ica.Rd +++ b/man/mlr_pipeops_ica.Rd @@ -98,7 +98,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("fastICA")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("fastICA")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/mlr_pipeops_imputelearner.Rd b/man/mlr_pipeops_imputelearner.Rd index fb7c920a9..72d4e7ba7 100644 --- a/man/mlr_pipeops_imputelearner.Rd +++ b/man/mlr_pipeops_imputelearner.Rd @@ -90,7 +90,7 @@ Only methods inherited from \code{\link{PipeOpImpute}}/\code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("pima") diff --git a/man/mlr_pipeops_kernelpca.Rd b/man/mlr_pipeops_kernelpca.Rd index b18f5806f..0c5eb8a91 100644 --- a/man/mlr_pipeops_kernelpca.Rd +++ b/man/mlr_pipeops_kernelpca.Rd @@ -75,7 +75,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("kernlab")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("kernlab")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/mlr_pipeops_learner.Rd b/man/mlr_pipeops_learner.Rd index aab63683b..fff28b33f 100644 --- a/man/mlr_pipeops_learner.Rd +++ b/man/mlr_pipeops_learner.Rd @@ -100,7 +100,7 @@ Methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/mlr_pipeops_learner_cv.Rd b/man/mlr_pipeops_learner_cv.Rd index 4dfd38111..26a6d0d4d 100644 --- a/man/mlr_pipeops_learner_cv.Rd +++ b/man/mlr_pipeops_learner_cv.Rd @@ -112,7 +112,7 @@ Methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/mlr_pipeops_learner_pi_cvplus.Rd b/man/mlr_pipeops_learner_pi_cvplus.Rd index 188718a80..71c29e9b7 100644 --- a/man/mlr_pipeops_learner_pi_cvplus.Rd +++ b/man/mlr_pipeops_learner_pi_cvplus.Rd @@ -100,7 +100,7 @@ Methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("mtcars") diff --git a/man/mlr_pipeops_nearmiss.Rd b/man/mlr_pipeops_nearmiss.Rd index 8d4898a66..6340acf23 100644 --- a/man/mlr_pipeops_nearmiss.Rd +++ b/man/mlr_pipeops_nearmiss.Rd @@ -69,7 +69,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("themis")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("themis")) withAutoprint(\{ # examplesIf} library("mlr3") # Create example task diff --git a/man/mlr_pipeops_nmf.Rd b/man/mlr_pipeops_nmf.Rd index 7f0cfc265..cb542325c 100644 --- a/man/mlr_pipeops_nmf.Rd +++ b/man/mlr_pipeops_nmf.Rd @@ -111,7 +111,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (mlr3misc::require_namespaces(c("NMF", "MASS"), quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (mlr3misc::require_namespaces(c("NMF", "MASS"), quietly = TRUE)) withAutoprint(\{ # examplesIf} \dontshow{ # NMF attaches these packages to search path on load, #929 lapply(c("package:Biobase", "package:BiocGenerics", "package:generics"), detach, character.only = TRUE) diff --git a/man/mlr_pipeops_ovrsplit.Rd b/man/mlr_pipeops_ovrsplit.Rd index 062e086f8..3aa506282 100644 --- a/man/mlr_pipeops_ovrsplit.Rd +++ b/man/mlr_pipeops_ovrsplit.Rd @@ -81,7 +81,7 @@ Only methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) task = tsk("iris") po = po("ovrsplit") diff --git a/man/mlr_pipeops_ovrunite.Rd b/man/mlr_pipeops_ovrunite.Rd index 3ce033902..ba513e4ee 100644 --- a/man/mlr_pipeops_ovrunite.Rd +++ b/man/mlr_pipeops_ovrunite.Rd @@ -74,7 +74,7 @@ Only methods inherited from \code{\link{PipeOpEnsemble}}/\code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) task = tsk("iris") gr = po("ovrsplit") \%>>\% lrn("classif.rpart") \%>>\% po("ovrunite") diff --git a/man/mlr_pipeops_proxy.Rd b/man/mlr_pipeops_proxy.Rd index d1b92cf65..b472ba83d 100644 --- a/man/mlr_pipeops_proxy.Rd +++ b/man/mlr_pipeops_proxy.Rd @@ -74,7 +74,7 @@ Only methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") set.seed(1234) diff --git a/man/mlr_pipeops_randomresponse.Rd b/man/mlr_pipeops_randomresponse.Rd index e3331d268..570d9e20f 100644 --- a/man/mlr_pipeops_randomresponse.Rd +++ b/man/mlr_pipeops_randomresponse.Rd @@ -79,7 +79,7 @@ Only methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) library(mlr3learners) diff --git a/man/mlr_pipeops_regravg.Rd b/man/mlr_pipeops_regravg.Rd index 76fac8eb0..62243efe9 100644 --- a/man/mlr_pipeops_regravg.Rd +++ b/man/mlr_pipeops_regravg.Rd @@ -70,7 +70,7 @@ Only methods inherited from \code{\link{PipeOpEnsemble}}/\code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") # Simple Bagging diff --git a/man/mlr_pipeops_smote.Rd b/man/mlr_pipeops_smote.Rd index 0bfd310db..412e0bed5 100644 --- a/man/mlr_pipeops_smote.Rd +++ b/man/mlr_pipeops_smote.Rd @@ -69,7 +69,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("smotefamily")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("smotefamily")) withAutoprint(\{ # examplesIf} library("mlr3") # Create example task diff --git a/man/mlr_pipeops_smotenc.Rd b/man/mlr_pipeops_smotenc.Rd index 5a231bb2a..b9dfa69a9 100644 --- a/man/mlr_pipeops_smotenc.Rd +++ b/man/mlr_pipeops_smotenc.Rd @@ -78,7 +78,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("themis")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("themis")) withAutoprint(\{ # examplesIf} library("mlr3") # Create example task diff --git a/man/mlr_pipeops_targetmutate.Rd b/man/mlr_pipeops_targetmutate.Rd index b348d1ce3..eb1d10b11 100644 --- a/man/mlr_pipeops_targetmutate.Rd +++ b/man/mlr_pipeops_targetmutate.Rd @@ -78,7 +78,7 @@ Only methods inherited from \code{\link{PipeOpTargetTrafo}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) task = tsk("boston_housing") po = PipeOpTargetMutate$new("logtrafo", param_vals = list( diff --git a/man/mlr_pipeops_targettrafoscalerange.Rd b/man/mlr_pipeops_targettrafoscalerange.Rd index 8400551c5..9f4b95778 100644 --- a/man/mlr_pipeops_targettrafoscalerange.Rd +++ b/man/mlr_pipeops_targettrafoscalerange.Rd @@ -66,7 +66,7 @@ Only methods inherited from \code{\link{PipeOpTargetTrafo}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library(mlr3) task = tsk("boston_housing") po = PipeOpTargetTrafoScaleRange$new() diff --git a/man/mlr_pipeops_textvectorizer.Rd b/man/mlr_pipeops_textvectorizer.Rd index d40503694..726573712 100644 --- a/man/mlr_pipeops_textvectorizer.Rd +++ b/man/mlr_pipeops_textvectorizer.Rd @@ -167,7 +167,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (mlr3misc::require_namespaces(c("stopwords", "quanteda"), quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (mlr3misc::require_namespaces(c("stopwords", "quanteda"), quietly = TRUE)) withAutoprint(\{ # examplesIf} library("mlr3") library("data.table") # create some text data diff --git a/man/mlr_pipeops_threshold.Rd b/man/mlr_pipeops_threshold.Rd index d8aa2fa5c..65ef1fd9e 100644 --- a/man/mlr_pipeops_threshold.Rd +++ b/man/mlr_pipeops_threshold.Rd @@ -67,7 +67,7 @@ Only methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") t = tsk("german_credit") gr = po(lrn("classif.rpart", predict_type = "prob")) \%>>\% diff --git a/man/mlr_pipeops_tomek.Rd b/man/mlr_pipeops_tomek.Rd index 7a3bee4bd..7fa699d8c 100644 --- a/man/mlr_pipeops_tomek.Rd +++ b/man/mlr_pipeops_tomek.Rd @@ -61,7 +61,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("themis")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("themis")) withAutoprint(\{ # examplesIf} library("mlr3") # Create example task diff --git a/man/mlr_pipeops_tunethreshold.Rd b/man/mlr_pipeops_tunethreshold.Rd index f2707ef05..3ffd5e43f 100644 --- a/man/mlr_pipeops_tunethreshold.Rd +++ b/man/mlr_pipeops_tunethreshold.Rd @@ -89,7 +89,7 @@ Only methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (mlr3misc::require_namespaces(c("bbotk", "rpart", "GenSA"), quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (mlr3misc::require_namespaces(c("bbotk", "rpart", "GenSA"), quietly = TRUE)) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/mlr_pipeops_updatetarget.Rd b/man/mlr_pipeops_updatetarget.Rd index 263b41ff3..434e3ec96 100644 --- a/man/mlr_pipeops_updatetarget.Rd +++ b/man/mlr_pipeops_updatetarget.Rd @@ -75,7 +75,7 @@ Only methods inherited from \code{\link{PipeOp}}. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} \dontrun{ # Create a binary class task from iris library(mlr3) diff --git a/man/mlr_pipeops_vtreat.Rd b/man/mlr_pipeops_vtreat.Rd index 81514f09d..110fe7084 100644 --- a/man/mlr_pipeops_vtreat.Rd +++ b/man/mlr_pipeops_vtreat.Rd @@ -128,7 +128,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("vtreat")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("vtreat")) withAutoprint(\{ # examplesIf} library("mlr3") set.seed(2020) diff --git a/man/mlr_pipeops_yeojohnson.Rd b/man/mlr_pipeops_yeojohnson.Rd index f1823e343..1293846c7 100644 --- a/man/mlr_pipeops_yeojohnson.Rd +++ b/man/mlr_pipeops_yeojohnson.Rd @@ -73,7 +73,7 @@ Only methods inherited from \code{\link{PipeOpTaskPreproc}}/\code{\link{PipeOp}} } \examples{ -\dontshow{if (requireNamespace("bestNormalize")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("bestNormalize")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") diff --git a/man/po.Rd b/man/po.Rd index cdb941a25..8b1b13562 100644 --- a/man/po.Rd +++ b/man/po.Rd @@ -48,7 +48,7 @@ it to a \code{\link{PipeOp}}. \code{pos()} (with plural-s) takes either a \code{ list of objects, and creates a \code{list} of \code{\link{PipeOp}}s. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") po("learner", lrn("classif.rpart"), cp = 0.3) diff --git a/man/ppl.Rd b/man/ppl.Rd index 190eef7d1..77ec8d1c4 100644 --- a/man/ppl.Rd +++ b/man/ppl.Rd @@ -32,7 +32,7 @@ Creates a \code{\link{Graph}} from \code{\link{mlr_graphs}} from given ID vector of any list and returns a \code{list} of possibly muliple \code{\link{Graph}}s. } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") gr = ppl("bagging", graph = po(lrn("regr.rpart")), diff --git a/man/preproc.Rd b/man/preproc.Rd index 2b6340637..7fa3b3e43 100644 --- a/man/preproc.Rd +++ b/man/preproc.Rd @@ -56,7 +56,7 @@ of \code{\link[mlr3:TaskSupervised]{TaskSupervised}} will not work with these in } \examples{ -\dontshow{if (requireNamespace("rpart")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (requireNamespace("rpart")) withAutoprint(\{ # examplesIf} library("mlr3") task = tsk("iris") From 0e78cd5eed893ce6079c5ae08266c135adfba3eb Mon Sep 17 00:00:00 2001 From: advieser Date: Sun, 16 Nov 2025 17:24:11 +0100 Subject: [PATCH 2/4] WIP tests --- R/GraphLearner.R | 11 +++++++++- tests/testthat/test_GraphLearner.R | 35 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 4d049c732..03cbc4e61 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -330,8 +330,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, } cat_cli(cli_h3("Pipeline: {.strong {pp}}")) }, + # TODO: Optimize this method to be actually faster than predict_newdata(), #968 predict_newdata_fast = function(newdata, task = NULL) { - self$predict_newdata(newdata, task) + pred = self$predict_newdata(newdata, task) + if (self$task_type == "regr") { + if (!is.null(pred$quantiles)) { + return(list(quantiles = pred$quantiles)) + } + list(response = pred$response, se = if (!all(is.na(pred$se))) pred$se else NULL) + } else if (self$task_type == "classif") { + list(response = pred$response, prob = pred$prob) + } } ), active = list( diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index c8ebe4e8f..b3198ad56 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -1345,3 +1345,38 @@ test_that("GraphLearner other properties", { expect_true(all(c("loglik", "oob_error") %in% g_nested$properties)) }) + +test_that("GraphLearner - predict_newdata_fast", { + # Classification + task = tsk("iris") + newdata = task$data(cols = task$feature_names, rows = 1:10) + # predict_type = "response" + learner = lrn("classif.featureless") + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) + # predict_type = "prob" + learner = lrn("classif.featureless", predict_type = "prob") + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) + + # Regression + task = tsk("mtcars") + newdata = task$data(cols = task$feature_names, rows = 1:10) + # predict_type = "response" + learner = lrn("regr.featureless") + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) + # predict_type = "se" + learner = lrn("regr.featureless", predict_type = "se") + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) + # predict_type = "quantiles" + learner = lrn("regr.featureless", predict_type = "quantiles", quantiles = 0.25) + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) +}) From a01cfc7a13a4d5f094b96c0f41e67c8b108940b7 Mon Sep 17 00:00:00 2001 From: advieser Date: Mon, 17 Nov 2025 17:45:31 +0100 Subject: [PATCH 3/4] docs and test --- R/GraphLearner.R | 8 ++++++-- man/mlr3pipelines-package.Rd | 1 + man/mlr_learners_graph.Rd | 10 ++++++++-- tests/testthat/test_GraphLearner.R | 30 ++++++++++++++++++++++-------- 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 03cbc4e61..b13fdae7a 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -120,9 +120,13 @@ #' If `resolve_branching` is `FALSE`, [`PipeOpUnbranch`] is treated as any other `PipeOp` with multiple inputs; all possible branch paths are considered equally. #' * `predict_newdata_fast(newdata, task = NULL)`\cr #' (`data.frame`, [`Task`][mlr3::Task] | `NULL`) -> [`Prediction`][mlr3::Prediction]\cr -#' Predicts outcomes for new data in `newdata` using the model fitted during `$train()`. +#' Predicts outcomes for new data in `newdata` using the model fitted during `$train()`.\cr #' For the moment, this is merely a thin wrapper around [`Learner$predict_newdata()`][mlr3::Learner] to ensure compatibility, meaning that *no speedup* is currently achieved. -#' In the future, this method may be optimized to be faster than `$predict_newdata()`. +#' In the future, this method may be optimized to be faster than `$predict_newdata()`.\cr +#' Unlike `$predict_newdata()`, this method does not return a [Prediction] object. +#' Instead, it returns a list with elements depending on `$task_type` and `$predict_type`: +#' * for `task_type = "classif"`: `response` and `prob`, or `quantiles` (if `predict_type = "quantiles"`) +#' * for `task_type = "regr"`: `response` and `se` #' #' The following standard extractors as defined by the [`Learner`][mlr3::Learner] class are available. #' Note that these typically only extract information from the `$base_learner()`. diff --git a/man/mlr3pipelines-package.Rd b/man/mlr3pipelines-package.Rd index dfdffc8db..3561fabe8 100644 --- a/man/mlr3pipelines-package.Rd +++ b/man/mlr3pipelines-package.Rd @@ -37,6 +37,7 @@ Other contributors: \item Keno Mersmann \email{keno.mersmann@gmail.com} [contributor] \item Maximilian Mücke \email{muecke.maximilian@gmail.com} (\href{https://orcid.org/0009-0000-9432-9795}{ORCID}) [contributor] \item Lona Koers \email{lona.koers@gmail.com} [contributor] + \item Alexander Winterstetter \email{alexander.winterstetter@gmail.com} [contributor] } } diff --git a/man/mlr_learners_graph.Rd b/man/mlr_learners_graph.Rd index 8b3df844e..5dfb989d7 100644 --- a/man/mlr_learners_graph.Rd +++ b/man/mlr_learners_graph.Rd @@ -133,9 +133,15 @@ There may be multiple corresponding \code{\link{PipeOpBranch}}s, which are all c If \code{resolve_branching} is \code{FALSE}, \code{\link{PipeOpUnbranch}} is treated as any other \code{PipeOp} with multiple inputs; all possible branch paths are considered equally. \item \code{predict_newdata_fast(newdata, task = NULL)}\cr (\code{data.frame}, \code{\link[mlr3:Task]{Task}} | \code{NULL}) -> \code{\link[mlr3:Prediction]{Prediction}}\cr -Predicts outcomes for new data in \code{newdata} using the model fitted during \verb{$train()}. +Predicts outcomes for new data in \code{newdata} using the model fitted during \verb{$train()}.\cr For the moment, this is merely a thin wrapper around \code{\link[mlr3:Learner]{Learner$predict_newdata()}} to ensure compatibility, meaning that \emph{no speedup} is currently achieved. -In the future, this method may be optimized to be faster than \verb{$predict_newdata()}. +In the future, this method may be optimized to be faster than \verb{$predict_newdata()}.\cr +Unlike \verb{$predict_newdata()}, this method does not return a \link[mlr3:Prediction]{mlr3::Prediction} object. +Instead, it returns a list with elements depending on \verb{$task_type} and \verb{$predict_type}: +\itemize{ +\item for \code{task_type = "classif"}: \code{response} and \code{prob}, or \code{quantiles} (if \code{predict_type = "quantiles"}) +\item for \code{task_type = "regr"}: \code{response} and \code{se} +} } The following standard extractors as defined by the \code{\link[mlr3:Learner]{Learner}} class are available. diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index b3198ad56..2292e54a3 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -1350,33 +1350,47 @@ test_that("GraphLearner - predict_newdata_fast", { # Classification task = tsk("iris") newdata = task$data(cols = task$feature_names, rows = 1:10) - # predict_type = "response" + + ## predict_type = "response" learner = lrn("classif.featureless") + set.seed(20251117) lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + set.seed(20251117) glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) - expect_equal(glrn_pred, lrn_pred) - # predict_type = "prob" + + expect_equal(as.character(glrn_pred$response), lrn_pred$response) + expect_equal(glrn_pred$prob, lrn_pred$prob) + + ## predict_type = "prob" learner = lrn("classif.featureless", predict_type = "prob") + set.seed(20251117) lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + set.seed(20251117) glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) - expect_equal(glrn_pred, lrn_pred) + + # Some Learners only return probs if predict_type is "prob" in .predict() and labels are inferred later, + # so there is an acceptable difference between the two methods. + expect_true(identical(glrn_pred$response, lrn_pred$response) || is.null(lrn_pred$response)) + expect_equal(glrn_pred$prob, lrn_pred$prob) # Regression task = tsk("mtcars") newdata = task$data(cols = task$feature_names, rows = 1:10) - # predict_type = "response" + ## predict_type = "response" learner = lrn("regr.featureless") lrn_pred = learner$train(task)$predict_newdata_fast(newdata) glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) expect_equal(glrn_pred, lrn_pred) - # predict_type = "se" + + ## predict_type = "se" learner = lrn("regr.featureless", predict_type = "se") lrn_pred = learner$train(task)$predict_newdata_fast(newdata) glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) expect_equal(glrn_pred, lrn_pred) - # predict_type = "quantiles" + + ## predict_type = "quantiles" learner = lrn("regr.featureless", predict_type = "quantiles", quantiles = 0.25) lrn_pred = learner$train(task)$predict_newdata_fast(newdata) glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) - expect_equal(glrn_pred, lrn_pred) + expect_equal(glrn_pred$quantiles[, 1L], lrn_pred$quantiles[, 1L]) }) From e9fbacc2a2b0d250efc26b8e4e971ea1c0253da1 Mon Sep 17 00:00:00 2001 From: advieser Date: Sat, 22 Nov 2025 13:55:39 +0100 Subject: [PATCH 4/4] Updated NEWS with cautionary note about predict_newdata_fast --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index beb75a2e9..fce9ab3b3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # mlr3pipelines 0.10.0-9000 +* New method `$predict_newdata_fast()` for `GraphLearner`. Note that currently this is only a thin wrapper around `$predict_newdata()` to maintain compatibility, but in the future it may get optimized to enable faster predictions on new data. + # mlr3pipelines 0.10.0 * Pretty-printing some info using the `cli` package now.