From 460ea0971109783a9cda5451562af37fd994030d Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 6 Feb 2025 22:31:19 +0100 Subject: [PATCH 1/5] feat: add hotstarting --- R/Graph.R | 6 ++++-- R/GraphLearner.R | 18 ++++++++++++++++++ R/PipeOp.R | 3 +++ R/PipeOpLearner.R | 7 +++++++ 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/R/Graph.R b/R/Graph.R index 366f2332b..eddc655ee 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -449,16 +449,18 @@ Graph = R6Class("Graph", self$set_names(ids, sprintf("%s%s%s", assert_string(prefix), ids, assert_string(postfix))) invisible(self) }, - train = function(input, single_input = TRUE) { graph_load_namespaces(self, "train") graph_reduce(self, input, "train", single_input) }, - predict = function(input, single_input = TRUE) { graph_load_namespaces(self, "predict") graph_reduce(self, input, "predict", single_input) }, + hotstart = function(input, single_input = TRUE) { + graph_load_namespaces(self, "train") + graph_reduce(self, input, "hotstart", single_input) + }, help = function(help_type = getOption("help_type")) { parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]] match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type) diff --git a/R/GraphLearner.R b/R/GraphLearner.R index f519e5007..a19f65bb9 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -454,6 +454,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, } on.exit({self$graph$state = NULL}) + self$graph$train(task) state = self$graph$state class(state) = c("graph_learner_model", class(state)) @@ -466,6 +467,23 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, assert_list(prediction, types = "Prediction", len = 1, .var.name = sprintf("Prediction returned by Graph %s", self$id)) prediction[[1]] + }, + + .hotstart = function(task) { + if (!is.null(get0("validate", self))) { + some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate)) + if (!some_pipeops_validate) { + 
lg$warn("GraphLearner '%s' specifies a validation set, but none of its PipeOps use it.", self$id) + } + } + + on.exit({self$graph$state = NULL}) + # copy hotstart state to graph + self$graph$state = self$state$model + self$graph$hotstart(task) + state = self$graph$state + class(state) = c("graph_learner_model", class(state)) + state } ) ) diff --git a/R/PipeOp.R b/R/PipeOp.R index b3487b0d6..b0583eb24 100644 --- a/R/PipeOp.R +++ b/R/PipeOp.R @@ -348,6 +348,9 @@ PipeOp = R6Class("PipeOp", output = check_types(self, output, "output", "predict") output }, + hotstart = function(input) { + self$train(input) + }, help = function(help_type = getOption("help_type")) { parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]] match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type) diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index 0849480af..42f0a96d9 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -112,6 +112,12 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, output = data.table(name = "output", train = "NULL", predict = out_type), tags = "learner", packages = learner$packages, properties = properties ) + }, + hotstart = function(input) { + # copy model from state to learner + private$.learner$state$model = self$state$model + output = get_private(private$.learner)$.hotstart(input[[1]]) + list(NULL) } ), active = list( @@ -199,6 +205,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, private$.learner$state = self$state list(private$.learner$predict(task)) }, + .additional_phash_input = function() private$.learner$phash ) ) From 3a887ad400d4e19615ba4d407e76abe35cac8cbb Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 6 Feb 2025 22:52:15 +0100 Subject: [PATCH 2/5] ... 
--- R/PipeOpLearner.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index 42f0a96d9..6e5690e62 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -114,8 +114,8 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, ) }, hotstart = function(input) { - # copy model from state to learner - private$.learner$state$model = self$state$model + # copy state to learner + private$.learner$state = self$state output = get_private(private$.learner)$.hotstart(input[[1]]) list(NULL) } From 660f15949be82e55185899cfc8e699e3d9ec3a68 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 7 Feb 2025 10:09:57 +0100 Subject: [PATCH 3/5] ... --- R/PipeOpLearner.R | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index 6e5690e62..f7b57656a 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -114,9 +114,14 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, ) }, hotstart = function(input) { + on.exit({private$.learner$state = NULL}) # copy state to learner - private$.learner$state = self$state - output = get_private(private$.learner)$.hotstart(input[[1]]) + learner = private$.learner + learner$state = self$state + + train_result = mlr3:::learner_train(learner, task = input[[1]], train_row_ids = NULL, mode = "hotstart") + self$state = train_result$learner$state + list(NULL) } ), From 7bf6c1cedcb08b1a42722e0fa7b83d8ab4034fef Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 7 Feb 2025 10:34:02 +0100 Subject: [PATCH 4/5] ... 
--- R/Graph.R | 2 ++ R/PipeOp.R | 2 ++ R/PipeOpLearner.R | 1 + 3 files changed, 5 insertions(+) diff --git a/R/Graph.R b/R/Graph.R index eddc655ee..483634de5 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -449,10 +449,12 @@ Graph = R6Class("Graph", self$set_names(ids, sprintf("%s%s%s", assert_string(prefix), ids, assert_string(postfix))) invisible(self) }, + train = function(input, single_input = TRUE) { graph_load_namespaces(self, "train") graph_reduce(self, input, "train", single_input) }, + predict = function(input, single_input = TRUE) { graph_load_namespaces(self, "predict") graph_reduce(self, input, "predict", single_input) diff --git a/R/PipeOp.R b/R/PipeOp.R index b0583eb24..456561387 100644 --- a/R/PipeOp.R +++ b/R/PipeOp.R @@ -349,6 +349,8 @@ PipeOp = R6Class("PipeOp", output }, hotstart = function(input) { + # default for all pipeops is to just train them + # pipeops that can do hotstarting should override this method self$train(input) }, help = function(help_type = getOption("help_type")) { diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index f7b57656a..59c277da6 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -119,6 +119,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, learner = private$.learner learner$state = self$state + # train learner with hotstarting train_result = mlr3:::learner_train(learner, task = input[[1]], train_row_ids = NULL, mode = "hotstart") self$state = train_result$learner$state From 73ffa8922ab586859cd808171f049f3a88368d2d Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 7 Feb 2025 10:34:54 +0100 Subject: [PATCH 5/5] ... 
--- R/PipeOpLearner.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index 59c277da6..a5fcc8e1b 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -211,7 +211,6 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, private$.learner$state = self$state list(private$.learner$predict(task)) }, - .additional_phash_input = function() private$.learner$phash ) )