@@ -99,7 +99,7 @@ test_that("E2E: Functional spec tuning (including repetition) works", {
9999 tune_wf <- workflows :: workflow(rec , tune_spec )
100100
101101 folds <- rsample :: vfold_cv(iris , v = 2 )
102- params <- extract_parameter_set_dials(tune_wf ) | >
102+ params <- extract_parameter_set_dials(tune_wf ) | >
103103 update(
104104 num_dense_path = num_terms(c(1 , 2 )),
105105 dense_path_units = hidden_units(c(4 , 8 ))
@@ -192,4 +192,107 @@ test_that("E2E: Multi-input, single-output functional classification works", {
192192 expect_equal(names(preds ), c(" .pred_class" ))
193193 expect_equal(nrow(preds ), 5 )
194194 expect_true(is.factor(preds $ .pred_class ))
195+ })
196+
197+ test_that(" E2E: Functional spec with pre-constructed optimizer works" , {
198+ skip_if_no_keras()
199+
200+ # Define blocks for a simple forked functional model
201+ input_block <- function (input_shape ) keras3 :: layer_input(shape = input_shape )
202+ path_block <- function (tensor , units = 16 ) {
203+ tensor | > keras3 :: layer_dense(units = units , activation = " relu" )
204+ }
205+ concat_block <- function (input_a , input_b ) {
206+ keras3 :: layer_concatenate(list (input_a , input_b ))
207+ }
208+ output_block_class <- function (tensor , num_classes ) {
209+ tensor | > keras3 :: layer_dense(units = num_classes , activation = " softmax" )
210+ }
211+
212+ model_name <- " e2e_func_class_optimizer"
213+ on.exit(suppressMessages(remove_keras_spec(model_name )), add = TRUE )
214+
215+ # Create a spec with two parallel paths that are then concatenated
216+ create_keras_functional_spec(
217+ model_name = model_name ,
218+ layer_blocks = list (
219+ main_input = input_block ,
220+ path_a = inp_spec(path_block , " main_input" ),
221+ path_b = inp_spec(path_block , " main_input" ),
222+ concatenated = inp_spec(
223+ concat_block ,
224+ c(path_a = " input_a" , path_b = " input_b" )
225+ ),
226+ output = inp_spec(output_block_class , " concatenated" )
227+ ),
228+ mode = " classification"
229+ )
230+
231+ # Define a pre-constructed optimizer
232+ my_optimizer <- keras3 :: optimizer_sgd(learning_rate = 0.001 )
233+
234+ spec <- e2e_func_class_optimizer(
235+ path_a_units = 8 ,
236+ path_b_units = 4 ,
237+ fit_epochs = 2 ,
238+ compile_optimizer = my_optimizer
239+ ) | >
240+ set_engine(" keras" )
241+
242+ data <- iris
243+ rec <- recipe(Species ~ . , data = data )
244+ wf <- workflows :: workflow(rec , spec )
245+
246+ expect_no_error(fit_obj <- parsnip :: fit(wf , data = data ))
247+ expect_s3_class(fit_obj , " workflow" )
248+ })
249+
250+ test_that(" E2E: Functional spec with string loss works" , {
251+ skip_if_no_keras()
252+
253+ # Define blocks for a simple forked functional model
254+ input_block <- function (input_shape ) keras3 :: layer_input(shape = input_shape )
255+ path_block <- function (tensor , units = 16 ) {
256+ tensor | > keras3 :: layer_dense(units = units , activation = " relu" )
257+ }
258+ concat_block <- function (input_a , input_b ) {
259+ keras3 :: layer_concatenate(list (input_a , input_b ))
260+ }
261+ output_block_class <- function (tensor , num_classes ) {
262+ tensor | > keras3 :: layer_dense(units = num_classes , activation = " softmax" )
263+ }
264+
265+ model_name <- " e2e_func_class_loss_string"
266+ on.exit(suppressMessages(remove_keras_spec(model_name )), add = TRUE )
267+
268+ # Create a spec with two parallel paths that are then concatenated
269+ create_keras_functional_spec(
270+ model_name = model_name ,
271+ layer_blocks = list (
272+ main_input = input_block ,
273+ path_a = inp_spec(path_block , " main_input" ),
274+ path_b = inp_spec(path_block , " main_input" ),
275+ concatenated = inp_spec(
276+ concat_block ,
277+ c(path_a = " input_a" , path_b = " input_b" )
278+ ),
279+ output = inp_spec(output_block_class , " concatenated" )
280+ ),
281+ mode = " classification"
282+ )
283+
284+ spec <- e2e_func_class_loss_string(
285+ path_a_units = 8 ,
286+ path_b_units = 4 ,
287+ fit_epochs = 2 ,
288+ compile_loss = " categorical_crossentropy"
289+ ) | >
290+ set_engine(" keras" )
291+
292+ data <- iris
293+ rec <- recipe(Species ~ . , data = data )
294+ wf <- workflows :: workflow(rec , spec )
295+
296+ expect_no_error(fit_obj <- parsnip :: fit(wf , data = data ))
297+ expect_s3_class(fit_obj , " workflow" )
195298})
0 commit comments