2 changes: 2 additions & 0 deletions NEWS.md
@@ -1,5 +1,7 @@
# mlr3pipelines 0.10.0-9000

* New method `$predict_newdata_fast()` for `GraphLearner`. It is currently only a thin wrapper around `$predict_newdata()`, provided for compatibility, and gives no speedup yet. It may be optimized in the future to enable faster predictions on new data.

# mlr3pipelines 0.10.0

* Pretty-printing some info using the `cli` package now.
23 changes: 22 additions & 1 deletion R/GraphLearner.R
@@ -118,7 +118,16 @@
#' corresponding [`PipeOpBranch`] is searched, and its hyperparameter configuration is used to select the base learner.
#' There may be multiple corresponding [`PipeOpBranch`]s, which are all considered.
#' If `resolve_branching` is `FALSE`, [`PipeOpUnbranch`] is treated as any other `PipeOp` with multiple inputs; all possible branch paths are considered equally.
#'
#' * `predict_newdata_fast(newdata, task = NULL)`\cr
#' (`data.frame`, [`Task`][mlr3::Task] | `NULL`) -> named `list()`\cr
#' Predicts outcomes for new data in `newdata` using the model fitted during `$train()`.\cr
#' For the moment, this is merely a thin wrapper around [`Learner$predict_newdata()`][mlr3::Learner] to ensure compatibility, meaning that *no speedup* is currently achieved.
#' In the future, this method may be optimized to be faster than `$predict_newdata()`.\cr
#' Unlike `$predict_newdata()`, this method does not return a [`Prediction`][mlr3::Prediction] object.
#' Instead, it returns a named list whose elements depend on `$task_type` and `$predict_type`:
#' * for `task_type = "classif"`: `response` and `prob`
#' * for `task_type = "regr"`: `response` and `se` (`se` is `NULL` if no standard errors are available), or only `quantiles` (if `predict_type = "quantiles"`)
#'
#' The following standard extractors as defined by the [`Learner`][mlr3::Learner] class are available.
#' Note that these typically only extract information from the `$base_learner()`.
#' This works well for simple [`Graph`]s that do not modify features too much, but may give unexpected results for `Graph`s that
@@ -324,6 +333,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
        pp = "non-sequential"
      }
      cat_cli(cli_h3("Pipeline: {.strong {pp}}"))
    },
    # TODO: Optimize this method to be actually faster than predict_newdata(), #968
    predict_newdata_fast = function(newdata, task = NULL) {
      pred = self$predict_newdata(newdata, task)
      if (self$task_type == "regr") {
        if (!is.null(pred$quantiles)) {
          return(list(quantiles = pred$quantiles))
        }
        list(response = pred$response, se = if (!all(is.na(pred$se))) pred$se else NULL)
      } else if (self$task_type == "classif") {
        list(response = pred$response, prob = pred$prob)
      }
    }
  ),
  active = list(
1 change: 1 addition & 0 deletions man/mlr3pipelines-package.Rd

11 changes: 11 additions & 0 deletions man/mlr_learners_graph.Rd

49 changes: 49 additions & 0 deletions tests/testthat/test_GraphLearner.R
@@ -1345,3 +1345,52 @@ test_that("GraphLearner other properties", {
  expect_true(all(c("loglik", "oob_error") %in% g_nested$properties))

})

test_that("GraphLearner - predict_newdata_fast", {
  # Classification
  task = tsk("iris")
  newdata = task$data(cols = task$feature_names, rows = 1:10)

  ## predict_type = "response"
  learner = lrn("classif.featureless")
  set.seed(20251117)
  lrn_pred = learner$train(task)$predict_newdata_fast(newdata)
  set.seed(20251117)
  glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata)

  expect_equal(as.character(glrn_pred$response), lrn_pred$response)
  expect_equal(glrn_pred$prob, lrn_pred$prob)

  ## predict_type = "prob"
  learner = lrn("classif.featureless", predict_type = "prob")
  set.seed(20251117)
  lrn_pred = learner$train(task)$predict_newdata_fast(newdata)
  set.seed(20251117)
  glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata)

  # Some Learners only return probs if predict_type is "prob" in .predict() and labels are inferred later,
  # so there is an acceptable difference between the two methods.
  expect_true(identical(glrn_pred$response, lrn_pred$response) || is.null(lrn_pred$response))
  expect_equal(glrn_pred$prob, lrn_pred$prob)

  # Regression
  task = tsk("mtcars")
  newdata = task$data(cols = task$feature_names, rows = 1:10)
  ## predict_type = "response"
  learner = lrn("regr.featureless")
  lrn_pred = learner$train(task)$predict_newdata_fast(newdata)
  glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata)
  expect_equal(glrn_pred, lrn_pred)

  ## predict_type = "se"
  learner = lrn("regr.featureless", predict_type = "se")
  lrn_pred = learner$train(task)$predict_newdata_fast(newdata)
  glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata)
  expect_equal(glrn_pred, lrn_pred)

  ## predict_type = "quantiles"
  learner = lrn("regr.featureless", predict_type = "quantiles", quantiles = 0.25)
  lrn_pred = learner$train(task)$predict_newdata_fast(newdata)
  glrn_pred = as_learner(po("learner", learner))$train(task)$predict_newdata_fast(newdata)
  expect_equal(glrn_pred$quantiles[, 1L], lrn_pred$quantiles[, 1L])
})
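
For readers who want to see the return shapes without installing the development version, the following is a minimal base-R sketch of the branching that `predict_newdata_fast()` performs in this diff. The helper name `prediction_to_fast_list()` and the mock prediction lists are hypothetical and exist only for illustration; plain lists stand in for mlr3 `Prediction` objects, and the numbers are made up.

```r
# Hypothetical helper (not part of mlr3pipelines): mirrors the branching of
# GraphLearner$predict_newdata_fast() from the diff, with a plain list
# standing in for the Prediction object.
prediction_to_fast_list = function(pred, task_type) {
  if (task_type == "regr") {
    # Quantile predictions take precedence and are returned on their own.
    if (!is.null(pred$quantiles)) {
      return(list(quantiles = pred$quantiles))
    }
    # `se` is set to NULL when every standard error is NA.
    list(response = pred$response, se = if (!all(is.na(pred$se))) pred$se else NULL)
  } else if (task_type == "classif") {
    list(response = pred$response, prob = pred$prob)
  }
}

# Mock regression prediction without standard errors (illustrative values).
regr_pred = list(response = c(20.1, 20.1), se = c(NA_real_, NA_real_), quantiles = NULL)
regr_out = prediction_to_fast_list(regr_pred, "regr")

# Mock regression prediction with a single quantile column.
q = matrix(c(15.4, 15.4), ncol = 1L, dimnames = list(NULL, "q0.25"))
quant_pred = list(response = c(15.4, 15.4), se = c(NA_real_, NA_real_), quantiles = q)
quant_out = prediction_to_fast_list(quant_pred, "regr")

# Mock classification prediction with a probability matrix.
prob = matrix(c(0.7, 0.2, 0.3, 0.8), ncol = 2L, dimnames = list(NULL, c("a", "b")))
classif_pred = list(response = factor(c("a", "b")), prob = prob)
classif_out = prediction_to_fast_list(classif_pred, "classif")

str(regr_out)
str(quant_out)
str(classif_out)
```

As in the diff, any task type other than `"regr"` or `"classif"` falls through both branches, so the sketch returns `NULL` in that case.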