Skip to content
12 changes: 6 additions & 6 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
if (!is.null(bmr)) {
assert_benchmark_result(bmr)
if (private$.data$iterations() && self$task_type != bmr$task_type) {
stopf("BenchmarkResult is of task type '%s', but must be '%s'", bmr$task_type, self$task_type)
error_input("BenchmarkResult is of task type '%s', but must be '%s'", bmr$task_type, self$task_type)
}

private$.data$combine(get_private(bmr)$.data)
Expand Down Expand Up @@ -425,8 +425,8 @@ BenchmarkResult = R6Class("BenchmarkResult",
resample_result = function(i = NULL, uhash = NULL, task_id = NULL, learner_id = NULL,
resampling_id = NULL) {
uhash = private$.get_uhashes(i, uhash, learner_id, task_id, resampling_id)
if (length(uhash) != 1L) {
stopf("Method requires selecting exactly one ResampleResult, but got %s",
if (length(uhash) != 1) {
error_input("Method requires selecting exactly one ResampleResult, but got %s",
length(uhash))
}
ResampleResult$new(private$.data, view = uhash)
Expand Down Expand Up @@ -598,7 +598,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
resampling_ids = resampling_ids), is.null)

if (sum(!is.null(i), !is.null(uhashes), length(args) > 0L) > 1) {
stopf("At most one of `i`, `uhash`, or IDs can be provided.")
error_input("At most one of `i`, `uhash`, or IDs can be provided.")
}
if (!is.null(i)) {
uhashes = self$uhashes
Expand All @@ -609,7 +609,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
uhashes = invoke(match.fun("uhashes"), bmr = self, .args = args)
}
if (length(uhashes) == 0L) {
stopf("No resample results found for the given arguments.")
error_input("No resample results found for the given arguments.")
}
uhashes
},
Expand Down Expand Up @@ -714,7 +714,7 @@ uhash = function(bmr, learner_id = NULL, task_id = NULL, resampling_id = NULL) {
assert_string(resampling_id, null.ok = TRUE)
uhash = uhashes(bmr, learner_id, task_id, resampling_id)
if (length(uhash) != 1) {
stopf("Expected exactly one uhash, got %s", length(uhash))
error_input("Expected exactly one uhash, got %s", length(uhash))
}
uhash
}
2 changes: 1 addition & 1 deletion R/DataBackendCbind.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DataBackendCbind = R6Class("DataBackendCbind", inherit = DataBackend, cloneable
pk = b1$primary_key

if (pk != b2$primary_key) {
stopf("All backends to cbind must have the primary_key '%s'", pk)
error_input("All backends to cbind must have the primary_key '%s'", pk)
}

super$initialize(list(b1 = b1, b2 = b2), pk)
Expand Down
2 changes: 1 addition & 1 deletion R/DataBackendDataTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
super$initialize(setkeyv(data, primary_key), primary_key)
ii = match(primary_key, names(data))
if (is.na(ii)) {
stopf("Primary key '%s' not in 'data'", primary_key)
error_input("Primary key '%s' not in 'data'", primary_key)
}
private$.cache = set_names(replace(rep(NA, ncol(data)), ii, FALSE), names(data))
},
Expand Down
2 changes: 1 addition & 1 deletion R/DataBackendRbind.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DataBackendRbind = R6Class("DataBackendRbind", inherit = DataBackend, cloneable
pk = b1$primary_key

if (pk != b2$primary_key) {
stopf("All backends to rbind must have the primary_key '%s'", pk)
error_input("All backends to rbind must have the primary_key '%s'", pk)
}

super$initialize(list(b1 = b1, b2 = b2), pk)
Expand Down
4 changes: 2 additions & 2 deletions R/DataBackendRename.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
new = new[ii]

if (self$primary_key %chin% old) {
stopf("Renaming the primary key is not supported")
error_input("Renaming the primary key is not supported")
}


resulting_names = map_values(b$colnames, old, new)
dup = anyDuplicated(resulting_names)
if (dup > 0L) {
stopf("Duplicated column name after rename: %s", resulting_names[dup])
error_input("Duplicated column name after rename: %s", resulting_names[dup])
}

self$old = old
Expand Down
6 changes: 3 additions & 3 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ HotstartStack = R6Class("HotstartStack",

walk(learners, function(learner) {
if (!is.null(get0("validate", learner))) {
stopf("Hotstart learners that did validation is currently not supported.")
error_input("Hotstart learners that did validation is currently not supported.")
} else if (is.null(learner$model)) {
stopf("Learners must be trained before adding them to the hotstart stack.")
error_input("Learners must be trained before adding them to the hotstart stack.")
} else if (is_marshaled_model(learner$model)) {
stopf("Learners must be unmarshaled before adding them to the hotstart stack.")
error_input("Learners must be unmarshaled before adding them to the hotstart stack.")
}
})

Expand Down
22 changes: 11 additions & 11 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,14 @@ Learner = R6Class("Learner",
predict = function(task, row_ids = NULL) {
# improve error message for the common mistake of passing a data.frame here
if (is.data.frame(task)) {
stopf("To predict on data.frames, use the method `$predict_newdata()` instead of `$predict()`")
error_input("To predict on data.frames, use the method `$predict_newdata()` instead of `$predict()`")
}
task = assert_task(as_task(task))
assert_predictable(task, self)
row_ids = assert_row_ids(row_ids, task = task, null.ok = TRUE)

if (is.null(self$state$model) && is.null(self$state$fallback_state$model)) {
stopf("Cannot predict, Learner '%s' has not been trained yet", self$id)
error_input("Cannot predict, Learner '%s' has not been trained yet", self$id)
}

# we need to marshal for call-r prediction and parallel prediction, but afterwards we reset the model
Expand Down Expand Up @@ -452,7 +452,7 @@ Learner = R6Class("Learner",
predict_newdata = function(newdata, task = NULL) {
if (is.null(task)) {
if (is.null(self$state$train_task)) {
stopf("No task stored, and no task provided")
error_input("No task stored, and no task provided")
}
task = self$state$train_task$clone()
} else {
Expand Down Expand Up @@ -618,7 +618,7 @@ Learner = R6Class("Learner",
fallback$id, self$id, str_collapse(missing_properties), class = "Mlr3WarningConfigFallbackProperties")
}
} else if (method == "none" && !is.null(fallback)) {
stopf("Fallback learner must be `NULL` if encapsulation is set to `none`.")
error_config("Fallback learner must be `NULL` if encapsulation is set to `none`.")
}

private$.encapsulation = c(train = method, predict = method)
Expand Down Expand Up @@ -665,7 +665,7 @@ Learner = R6Class("Learner",
for (i in seq_along(new_values)) {
nn = ndots[[i]]
if (!exists(nn, envir = self, inherits = FALSE)) {
stopf("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
error_config("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
nn, class(self)[1L], did_you_mean(nn, c(param_ids, setdiff(names(self), ".__enclos_env__")))) # nolint
}
self[[nn]] = new_values[[i]]
Expand All @@ -681,10 +681,10 @@ Learner = R6Class("Learner",
#' If set to `"error"`, an error is thrown, otherwise all features are returned.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
if (private$.selected_features_impute == "error") {
stopf("Learner does not support feature selection")
error_learner("Learner does not support feature selection")
} else {
self$state$feature_names
}
Expand Down Expand Up @@ -790,15 +790,15 @@ Learner = R6Class("Learner",

assert_string(rhs, .var.name = "predict_type")
if (rhs %nin% self$predict_types) {
stopf("Learner '%s' does not support predict type '%s'", self$id, rhs)
error_input("Learner '%s' does not support predict type '%s'", self$id, rhs) # TODO error_learner?
}
private$.predict_type = rhs
},

#' @template field_param_set
param_set = function(rhs) {
if (!missing(rhs) && !identical(rhs, private$.param_set)) {
stopf("param_set is read-only.")
error_input("param_set is read-only.")
}
private$.param_set
},
Expand Down Expand Up @@ -866,7 +866,7 @@ Learner = R6Class("Learner",
# return: Numeric vector of weights or `no_weights_val` (default NULL)
.get_weights = function(task, no_weights_val = NULL) {
if ("weights" %nin% self$properties) {
stop("private$.get_weights should not be used in Learners that do not have the 'weights' property.")
error_mlr3("private$.get_weights should not be used in Learners that do not have the 'weights' property.")
}
if (self$use_weights == "use" && "weights_learner" %in% task$properties) {
task$weights_learner$weight
Expand Down Expand Up @@ -916,7 +916,7 @@ default_values.Learner = function(x, search_space, task, ...) { # nolint
values = default_values(x$param_set)

if (any(search_space$ids() %nin% names(values))) {
stopf("Could not find default values for the following parameters: %s",
error_learner("Could not find default values for the following parameters: %s",
str_collapse(setdiff(search_space$ids(), names(values))))
}

Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ LearnerClassif = R6Class("LearnerClassif", inherit = Learner,
#'
#' @return `list()` with elements `"response"` or `"prob"` depending on the predict type.
predict_newdata_fast = function(newdata, task = NULL) {
if (is.null(task) && is.null(self$state$train_task)) stopf("No task stored, and no task provided")
if (is.null(task) && is.null(self$state$train_task)) error_input("No task stored, and no task provided")
feature_names = self$state$train_task$feature_names %??% task$feature_names
class_names = self$state$train_task$class_names %??% task$class_names

Expand Down
16 changes: 8 additions & 8 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
fns = self$state$feature_names
set_names(rep(0, length(fns)), fns)
Expand All @@ -124,7 +124,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
character(0)
}
Expand Down Expand Up @@ -180,10 +180,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
message("Message from classif.debug->train()")
}
if (roll("warning_train")) {
warningf("Warning from classif.debug->train()")
warning_mlr3("Warning from classif.debug->train()")
}
if (roll("error_train")) {
stopf("Error from classif.debug->train()")
error_learner_train("Error from classif.debug->train()")
}
if (roll("segfault_train")) {
get("attach")(structure(list(), class = "UserDefinedDatabase"))
Expand All @@ -192,7 +192,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
valid_truth = if (!is.null(task$internal_valid_task)) task$internal_valid_task$truth()

if (isTRUE(pv$early_stopping) && is.null(valid_truth)) {
stopf("Early stopping is only possible when a validation task is present.")
error_config("Early stopping is only possible when a validation task is present.")
}

model = list(
Expand Down Expand Up @@ -248,7 +248,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
},
.predict = function(task) {
if (!is.null(self$model$marshal_pid) && self$model$marshal_pid != Sys.getpid()) {
stopf("Model was not unmarshaled correctly")
error_mlr3("Model was not unmarshaled correctly")
}
n = task$nrow
pv = self$param_set$get_values(tags = "predict")
Expand All @@ -265,10 +265,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
message("Message from classif.debug->predict()")
}
if (roll("warning_predict")) {
warningf("Warning from classif.debug->predict()")
warning_mlr3("Warning from classif.debug->predict()")
}
if (roll("error_predict")) {
stopf("Error from classif.debug->predict()")
error_learner_predict("Error from classif.debug->predict()")
}
if (roll("segfault_predict")) {
get("attach")(structure(list(), class = "UserDefinedDatabase"))
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClassifFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
fn = self$model$features
named_vector(fn, 0)
Expand All @@ -65,7 +65,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
#' @return `character(0)`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
character()
}
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
# importance is only present if there is at least on split
sort(self$model$variable.importance %??% set_names(numeric()), decreasing = TRUE)
Expand All @@ -66,7 +66,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
setdiff(self$model$frame$var, "<leaf>")
}
Expand Down
6 changes: 3 additions & 3 deletions R/LearnerRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
#'
#' @return `list()` with elements `"response"`, `"se"` or `"quantiles"` depending on the predict type.
predict_newdata_fast = function(newdata, task = NULL) {
if (is.null(task) && is.null(self$state$train_task)) stopf("No task stored, and no task provided")
if (is.null(task) && is.null(self$state$train_task)) error_input("No task stored, and no task provided")
feature_names = self$state$train_task$feature_names %??% task$feature_names

# add data and most common used meta data
Expand Down Expand Up @@ -134,7 +134,7 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
}

if ("quantiles" %nin% self$predict_types) {
stopf("Learner does not support predicting quantiles")
error_learner_predict("Learner does not support predicting quantiles")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Wrong error type for quantile configuration

The errors when setting quantiles and quantile_response fields use error_learner_predict but should use error_config. These errors occur during configuration when a user tries to set quantile-related fields on a learner that doesn't support quantile prediction, not during actual prediction. Using error_learner_predict incorrectly categorizes this as a prediction failure rather than a configuration error.

Additional Locations (1)

Fix in Cursor Fix in Web

}
private$.quantiles = assert_numeric(rhs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L, sorted = TRUE, .var.name = "quantiles")

Expand All @@ -151,7 +151,7 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
}

if ("quantiles" %nin% self$predict_types) {
stopf("Learner does not support predicting quantiles")
error_learner_predict("Learner does not support predicting quantiles")
}

private$.quantile_response = assert_number(rhs, lower = 0, upper = 1, .var.name = "response")
Expand Down
16 changes: 8 additions & 8 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
fns = self$state$feature_names
set_names(rep(0, length(fns)), fns)
Expand All @@ -75,7 +75,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
error_learner("No model stored")
}
character(0)
}
Expand All @@ -88,13 +88,13 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
}

if (roll("message_train")) {
message("Message from classif.debug->train()")
message("Message from regr.debug->train()")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Incorrect learner name in debug message

The training message for regr.debug incorrectly references classif.debug instead of regr.debug. The message says "Message from classif.debug->train()" but this is the regression debug learner, so it should say "Message from regr.debug->train()". The predict message at line 134 correctly uses "regr.debug", showing this is an inconsistency.

Fix in Cursor Fix in Web

}
if (roll("warning_train")) {
warningf("Warning from classif.debug->train()")
warning_mlr3("Warning from regr.debug->train()")
}
if (roll("error_train")) {
stopf("Error from classif.debug->train()")
error_learner_train("Error from regr.debug->train()")
}
if (roll("segfault_train")) {
get("attach")(structure(list(), class = "UserDefinedDatabase"))
Expand Down Expand Up @@ -131,13 +131,13 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
}

if (roll("message_predict")) {
message("Message from classif.debug->predict()")
message("Message from regr.debug->predict()")
}
if (roll("warning_predict")) {
warningf("Warning from classif.debug->predict()")
warning_mlr3("Warning from regr.debug->predict()")
}
if (roll("error_predict")) {
stopf("Error from classif.debug->predict()")
error_learner_predict("Error from regr.debug->predict()")
}
if (roll("segfault_predict")) {
get("attach")(structure(list(), class = "UserDefinedDatabase"))
Expand Down
Loading