Skip to content

Commit 335428c

Browse files
committed
Modularizing x and y processors
1 parent 2a90d2a commit 335428c

File tree

4 files changed

+111
-45
lines changed

4 files changed

+111
-45
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export(register_keras_loss)
1414
export(register_keras_metric)
1515
export(register_keras_optimizer)
1616
export(remove_keras_spec)
17+
importFrom(keras3,to_categorical)
1718
importFrom(parsnip,update_dot_check)
1819
importFrom(rlang,arg_match)
1920
importFrom(rlang,dots_list)

R/generic_functional_fit.R

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,31 +83,32 @@ generic_functional_fit <- function(
8383
learn_rate <- all_args$learn_rate %||% 0.01
8484
verbose <- all_args$verbose %||% 0
8585

86-
if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) {
87-
x_proc <- do.call(abind::abind, c(x[[1]], list(along = 0)))
88-
} else {
89-
x_proc <- as.matrix(x)
90-
}
91-
input_shape <- if (length(dim(x_proc)) > 2) dim(x_proc)[-1] else ncol(x_proc)
92-
is_classification <- is.factor(y)
93-
if (is_classification) {
94-
class_levels <- levels(y)
95-
num_classes <- length(class_levels)
96-
y_mat <- keras3::to_categorical(
97-
as.numeric(y) - 1,
98-
num_classes = num_classes
99-
)
100-
default_loss <- if (num_classes > 2) {
86+
# Process x input
87+
x_processed <- process_x(x)
88+
x_proc <- x_processed$x_proc
89+
input_shape <- x_processed$input_shape
90+
91+
# Process y input
92+
y_processed <- process_y(y)
93+
y_mat <- y_processed$y_proc
94+
is_classification <- y_processed$is_classification
95+
class_levels <- y_processed$class_levels
96+
num_classes <- y_processed$num_classes
97+
98+
# Determine default compile arguments based on mode
99+
default_loss <- if (is_classification) {
100+
if (num_classes > 2) {
101101
"categorical_crossentropy"
102102
} else {
103103
"binary_crossentropy"
104104
}
105-
default_metrics <- "accuracy"
106105
} else {
107-
class_levels <- NULL
108-
y_mat <- as.matrix(y)
109-
default_loss <- "mean_squared_error"
110-
default_metrics <- "mean_absolute_error"
106+
"mean_squared_error"
107+
}
108+
default_metrics <- if (is_classification) {
109+
"accuracy"
110+
} else {
111+
"mean_absolute_error"
111112
}
112113

113114
# --- 2. Dynamic Model Architecture Construction (DIFFERENT from sequential) ---

R/generic_sequential_fit.R

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -78,39 +78,32 @@ generic_sequential_fit <- function(
7878
learn_rate <- all_args$learn_rate %||% 0.01
7979
verbose <- all_args$verbose %||% 0
8080

81-
# Handle both standard tabular data (matrix) and list-columns of arrays
82-
# (for images/sequences) that come from recipes.
83-
if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) {
84-
# Assumes a single predictor column containing a list of arrays.
85-
# We stack them into a single higher-dimensional array.
86-
x_proc <- do.call(abind::abind, c(x[[1]], list(along = 0)))
87-
} else {
88-
x_proc <- as.matrix(x)
89-
}
81+
# Process x input
82+
x_processed <- process_x(x)
83+
x_proc <- x_processed$x_proc
84+
input_shape <- x_processed$input_shape
9085

91-
# Determine the correct input shape for the Keras model.
92-
input_shape <- if (length(dim(x_proc)) > 2) dim(x_proc)[-1] else ncol(x_proc)
86+
# Process y input
87+
y_processed <- process_y(y)
88+
y_mat <- y_processed$y_proc
89+
is_classification <- y_processed$is_classification
90+
class_levels <- y_processed$class_levels
91+
num_classes <- y_processed$num_classes
9392

9493
# Determine default compile arguments based on mode
95-
is_classification <- is.factor(y)
96-
if (is_classification) {
97-
class_levels <- levels(y)
98-
num_classes <- length(class_levels)
99-
y_mat <- keras3::to_categorical(
100-
as.numeric(y) - 1,
101-
num_classes = num_classes
102-
)
103-
default_loss <- if (num_classes > 2) {
94+
default_loss <- if (is_classification) {
95+
if (num_classes > 2) {
10496
"categorical_crossentropy"
10597
} else {
10698
"binary_crossentropy"
10799
}
108-
default_metrics <- "accuracy"
109100
} else {
110-
class_levels <- NULL
111-
y_mat <- as.matrix(y)
112-
default_loss <- "mean_squared_error"
113-
default_metrics <- "mean_absolute_error"
101+
"mean_squared_error"
102+
}
103+
default_metrics <- if (is_classification) {
104+
"accuracy"
105+
} else {
106+
"mean_absolute_error"
114107
}
115108

116109
# --- 2. Dynamic Model Architecture Construction ---

R/utils.R

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,74 @@ loss_function_keras <- function(values = NULL) {
175175
finalize = NULL
176176
)
177177
}
178+
179+
#' Process Predictor Input for Keras
180+
#'
181+
#' @description
182+
#' Preprocesses predictor data (`x`) into a format suitable for Keras models.
183+
#' Handles both tabular data and list-columns of arrays (e.g., for images).
184+
#'
185+
#' @param x A data frame or matrix of predictors.
186+
#' @return A list containing:
187+
#' - `x_proc`: The processed predictor data (matrix or array).
188+
#' - `input_shape`: The determined input shape for the Keras model.
189+
#' @noRd
190+
process_x <- function(x) {
  # Convert predictors into the array handed to Keras and report the
  # per-observation input shape alongside it.

  # A data frame holding exactly one list-column is treated as a collection
  # of arrays (one per row) that get stacked along a new leading dimension.
  is_array_column <- is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])

  predictors <- if (is_array_column) {
    stack_args <- c(x[[1]], list(along = 0))
    do.call(abind::abind, stack_args)
  } else {
    # Everything else is coerced to a plain matrix.
    as.matrix(x)
  }

  # Arrays with more than two dimensions: drop the leading dimension to get
  # the shape of a single sample. Otherwise use the number of columns.
  dims <- dim(predictors)
  if (length(dims) > 2) {
    shape <- dims[-1]
  } else {
    shape <- ncol(predictors)
  }

  list(x_proc = predictors, input_shape = shape)
}
201+
202+
#' Process Outcome Input for Keras
203+
#'
204+
#' @description
205+
#' Preprocesses outcome data (`y`) into a format suitable for Keras models.
206+
#' Handles both regression (numeric) and classification (factor) outcomes,
207+
#' including one-hot encoding for classification.
208+
#'
209+
#' @param y A vector of outcomes.
210+
#' @param is_classification Logical, optional. If `TRUE`, treats `y` as
211+
#' classification. If `FALSE`, treats as regression. If `NULL` (default),
212+
#' it's determined from `is.factor(y)`.
213+
#' @param class_levels Character vector, optional. The factor levels for
214+
#' classification outcomes. If `NULL` (default), determined from `levels(y)`.
215+
#' @return A list containing:
216+
#' - `y_proc`: The processed outcome data (matrix or one-hot encoded array).
217+
#' - `is_classification`: Logical, indicating if `y` was treated as classification.
218+
#' - `num_classes`: Integer, the number of classes for classification, or `NULL`.
219+
#' - `class_levels`: Character vector, the factor levels for classification, or `NULL`.
220+
#' @importFrom keras3 to_categorical
221+
#' @noRd
222+
process_y <- function(y, is_classification = NULL, class_levels = NULL) {
  # Convert an outcome vector into the structure handed to Keras: a one-hot
  # encoded matrix for classification, a plain matrix for regression.
  #
  # Returns a list with `y_proc`, `is_classification`, `num_classes`
  # (NULL for regression) and `class_levels` (as supplied or derived).

  # Unless the caller forces a mode, factors mean classification.
  if (is.null(is_classification)) {
    is_classification <- is.factor(y)
  }

  y_proc <- NULL
  num_classes <- NULL
  if (is_classification) {
    if (is.null(class_levels)) {
      # `levels()` is NULL for non-factors (e.g. a character vector when
      # `is_classification = TRUE` is forced), which would give zero classes
      # and an all-NA encoding. Derive the levels through `as.factor()`.
      class_levels <- if (is.factor(y)) levels(y) else levels(as.factor(y))
    }
    num_classes <- length(class_levels)
    y_factored <- factor(y, levels = class_levels)

    # `factor()` silently maps labels missing from `class_levels` to NA.
    # Fail early rather than passing NA class indices to the encoder.
    unmatched <- is.na(y_factored) & !is.na(y)
    if (any(unmatched)) {
      stop(
        "Outcome values not found in `class_levels`: ",
        paste(unique(as.character(y[unmatched])), collapse = ", "),
        call. = FALSE
      )
    }

    # Keras expects zero-based class indices.
    y_proc <- keras3::to_categorical(
      as.numeric(y_factored) - 1,
      num_classes = num_classes
    )
  } else {
    y_proc <- as.matrix(y)
  }

  list(
    y_proc = y_proc,
    is_classification = is_classification,
    num_classes = num_classes,
    class_levels = class_levels
  )
}

0 commit comments

Comments
 (0)