diff --git a/NEWS.md b/NEWS.md index beb75a2e9..fce9ab3b3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # mlr3pipelines 0.10.0-9000 +* New method `$predict_newdata_fast()` for `GraphLearner`. Note that currently this is only a thin wrapper around `$predict_newdata()` to maintain compatibility, but in the future it may get optimized to enable faster predictions on new data. + # mlr3pipelines 0.10.0 * Pretty-printing some info using the `cli` package now. diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 8214c179f..b13fdae7a 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -118,7 +118,16 @@ #' corresponding [`PipeOpBranch`] is searched, and its hyperparameter configuration is used to select the base learner. #' There may be multiple corresponding [`PipeOpBranch`]s, which are all considered. #' If `resolve_branching` is `FALSE`, [`PipeOpUnbranch`] is treated as any other `PipeOp` with multiple inputs; all possible branch paths are considered equally. -#' +#' * `predict_newdata_fast(newdata, task = NULL)`\cr +#' (`data.frame`, [`Task`][mlr3::Task] | `NULL`) -> [`Prediction`][mlr3::Prediction]\cr +#' Predicts outcomes for new data in `newdata` using the model fitted during `$train()`.\cr +#' For the moment, this is merely a thin wrapper around [`Learner$predict_newdata()`][mlr3::Learner] to ensure compatibility, meaning that *no speedup* is currently achieved. +#' In the future, this method may be optimized to be faster than `$predict_newdata()`.\cr +#' Unlike `$predict_newdata()`, this method does not return a [Prediction] object. +#' Instead, it returns a list with elements depending on `$task_type` and `$predict_type`: +#' * for `task_type = "classif"`: `response` and `prob`, or `quantiles` (if `predict_type = "quantiles"`) +#' * for `task_type = "regr"`: `response` and `se` +#' #' The following standard extractors as defined by the [`Learner`][mlr3::Learner] class are available. #' Note that these typically only extract information from the `$base_learner()`. #' This works well for simple [`Graph`]s that do not modify features too much, but may give unexpected results for `Graph`s that @@ -324,6 +333,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, pp = "non-sequential" } cat_cli(cli_h3("Pipeline: {.strong {pp}}")) + }, + # TODO: Optimize this method to be actually faster than predict_newdata(), #968 + predict_newdata_fast = function(newdata, task = NULL) { + pred = self$predict_newdata(newdata, task) + if (self$task_type == "regr") { + if (!is.null(pred$quantiles)) { + return(list(quantiles = pred$quantiles)) + } + list(response = pred$response, se = if (!all(is.na(pred$se))) pred$se else NULL) + } else if (self$task_type == "classif") { + list(response = pred$response, prob = pred$prob) + } } ), active = list( diff --git a/man/mlr3pipelines-package.Rd b/man/mlr3pipelines-package.Rd index dfdffc8db..3561fabe8 100644 --- a/man/mlr3pipelines-package.Rd +++ b/man/mlr3pipelines-package.Rd @@ -37,6 +37,7 @@ Other contributors: \item Keno Mersmann \email{keno.mersmann@gmail.com} [contributor] \item Maximilian Mücke \email{muecke.maximilian@gmail.com} (\href{https://orcid.org/0009-0000-9432-9795}{ORCID}) [contributor] \item Lona Koers \email{lona.koers@gmail.com} [contributor] + \item Alexander Winterstetter \email{alexander.winterstetter@gmail.com} [contributor] } } diff --git a/man/mlr_learners_graph.Rd b/man/mlr_learners_graph.Rd index 95df4a8db..5dfb989d7 100644 --- a/man/mlr_learners_graph.Rd +++ b/man/mlr_learners_graph.Rd @@ -131,6 +131,17 @@ If \code{resolve_branching} is \code{TRUE}, and when a \code{\link{PipeOpUnbranc corresponding \code{\link{PipeOpBranch}} is searched, and its hyperparameter configuration is used to select the base learner. There may be multiple corresponding \code{\link{PipeOpBranch}}s, which are all considered. If \code{resolve_branching} is \code{FALSE}, \code{\link{PipeOpUnbranch}} is treated as any other \code{PipeOp} with multiple inputs; all possible branch paths are considered equally. +\item \code{predict_newdata_fast(newdata, task = NULL)}\cr +(\code{data.frame}, \code{\link[mlr3:Task]{Task}} | \code{NULL}) -> \code{\link[mlr3:Prediction]{Prediction}}\cr +Predicts outcomes for new data in \code{newdata} using the model fitted during \verb{$train()}.\cr +For the moment, this is merely a thin wrapper around \code{\link[mlr3:Learner]{Learner$predict_newdata()}} to ensure compatibility, meaning that \emph{no speedup} is currently achieved. +In the future, this method may be optimized to be faster than \verb{$predict_newdata()}.\cr +Unlike \verb{$predict_newdata()}, this method does not return a \link[mlr3:Prediction]{mlr3::Prediction} object. +Instead, it returns a list with elements depending on \verb{$task_type} and \verb{$predict_type}: +\itemize{ +\item for \code{task_type = "classif"}: \code{response} and \code{prob}, or \code{quantiles} (if \code{predict_type = "quantiles"}) +\item for \code{task_type = "regr"}: \code{response} and \code{se} +} } The following standard extractors as defined by the \code{\link[mlr3:Learner]{Learner}} class are available. diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index c8ebe4e8f..2292e54a3 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -1345,3 +1345,52 @@ test_that("GraphLearner other properties", { expect_true(all(c("loglik", "oob_error") %in% g_nested$properties)) }) + +test_that("GraphLearner - predict_newdata_fast", { + # Classification + task = tsk("iris") + newdata = task$data(cols = task$feature_names, rows = 1:10) + + ## predict_type = "response" + learner = lrn("classif.featureless") + set.seed(20251117) + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + set.seed(20251117) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + + expect_equal(as.character(glrn_pred$response), lrn_pred$response) + expect_equal(glrn_pred$prob, lrn_pred$prob) + + ## predict_type = "prob" + learner = lrn("classif.featureless", predict_type = "prob") + set.seed(20251117) + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + set.seed(20251117) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + + # Some Learners only return probs if predict_type is "prob" in .predict() and labels are inferred later, + # so there is an acceptable difference between the two methods. + expect_true(identical(glrn_pred$response, lrn_pred$response) || is.null(lrn_pred$response)) + expect_equal(glrn_pred$prob, lrn_pred$prob) + + # Regression + task = tsk("mtcars") + newdata = task$data(cols = task$feature_names, rows = 1:10) + ## predict_type = "response" + learner = lrn("regr.featureless") + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) + + ## predict_type = "se" + learner = lrn("regr.featureless", predict_type = "se") + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred, lrn_pred) + + ## predict_type = "quantiles" + learner = lrn("regr.featureless", predict_type = "quantiles", quantiles = 0.25) + lrn_pred = learner$train(task)$predict_newdata_fast(newdata) + glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata) + expect_equal(glrn_pred$quantiles[, 1L], lrn_pred$quantiles[, 1L]) +})