Skip to content

Commit b0fdf7b

Browse files
committed
Fixing issue with register model args
1 parent ff74419 commit b0fdf7b

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

R/register_model_args.R

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,30 @@ register_model_args <- function(model_name, parsnip_names) {
6666
if (startsWith(arg, "num_")) {
6767
dials_fun <- "num_terms"
6868
} else {
69-
base_arg <- sub(".*_", "", arg)
70-
idx <- match(base_arg, keras_dials_map$keras_arg)
71-
dials_fun <- if (!is.na(idx)) keras_dials_map$dials_fun[idx] else arg
69+
# First, try to match the full argument name
70+
idx <- match(arg, keras_dials_map$keras_arg)
71+
if (!is.na(idx)) {
72+
dials_fun <- keras_dials_map$dials_fun[idx]
73+
} else {
74+
# If no full match, try to match the base name (e.g., "units" from "dense_units")
75+
base_arg <- sub(".*_", "", arg)
76+
idx <- match(base_arg, keras_dials_map$keras_arg)
77+
dials_fun <- if (!is.na(idx)) keras_dials_map$dials_fun[idx] else arg
78+
}
79+
}
80+
81+
pkg <- if (dials_fun %in% c("loss_function_keras", "optimizer_function")) {
82+
"kerasnip"
83+
} else {
84+
"dials"
7285
}
7386

7487
parsnip::set_model_arg(
7588
model = model_name,
7689
eng = "keras",
7790
parsnip = arg,
7891
original = arg,
79-
func = list(pkg = "dials", fun = dials_fun),
92+
func = list(pkg = pkg, fun = dials_fun),
8093
has_submodel = FALSE
8194
)
8295
}

0 commit comments

Comments
 (0)