Skip to content

Commit d9b6df7

Browse files
committed
Improving tests
1 parent 0944b0f commit d9b6df7

File tree

3 files changed

+319
-0
lines changed

3 files changed

+319
-0
lines changed
File renamed without changes.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Mock get_keras_object to isolate the logic of collect_compile_args
2+
mock_get_keras_object <- function(name, type, ...) {
3+
# Return a simple string representation for testing purposes
4+
paste0("mocked_", type, "_", name)
5+
}
6+
7+
# Mock optimizer to avoid keras dependency
8+
mock_optimizer_adam <- function(...) {
9+
"mocked_optimizer_adam"
10+
}
11+
12+
test_that("collect_compile_args handles single-output cases correctly", {
13+
# Mock the keras3::optimizer_adam function
14+
testthat::with_mocked_bindings(
15+
.env = as.environment("package:kerasnip"),
16+
get_keras_object = mock_get_keras_object,
17+
{
18+
# Case 1: Single output, non-character loss and metrics
19+
dummy_loss_obj <- structure(list(), class = "dummy_loss")
20+
dummy_metric_obj <- structure(list(), class = "dummy_metric")
21+
22+
args <- collect_compile_args(
23+
all_args = list(
24+
compile_loss = dummy_loss_obj,
25+
compile_metrics = list(dummy_metric_obj)
26+
),
27+
learn_rate = 0.01,
28+
default_loss = "mse",
29+
default_metrics = "mae"
30+
)
31+
expect_equal(args$loss, dummy_loss_obj)
32+
expect_equal(args$metrics, list(dummy_metric_obj))
33+
}
34+
)
35+
})
36+
37+
test_that("collect_compile_args handles multi-output cases correctly", {
38+
testthat::with_mocked_bindings(
39+
.env = as.environment("package:kerasnip"),
40+
get_keras_object = mock_get_keras_object,
41+
{
42+
# Case 2: Multi-output, single string for loss and metrics
43+
args <- collect_compile_args(
44+
all_args = list(
45+
compile_loss = "categorical_crossentropy",
46+
compile_metrics = "accuracy"
47+
),
48+
learn_rate = 0.01,
49+
default_loss = list(out1 = "mse", out2 = "mae"),
50+
default_metrics = list(out1 = "mse", out2 = "mae")
51+
)
52+
expect_equal(args$loss, "mocked_loss_categorical_crossentropy")
53+
expect_equal(args$metrics, "mocked_metric_accuracy")
54+
55+
# Case 3: Multi-output, named list with mixed types
56+
dummy_loss_obj_2 <- structure(list(), class = "dummy_loss_2")
57+
args_mixed <- collect_compile_args(
58+
all_args = list(
59+
compile_loss = list(out1 = "mae", out2 = dummy_loss_obj_2)
60+
),
61+
learn_rate = 0.01,
62+
default_loss = list(out1 = "mse", out2 = "mae"),
63+
default_metrics = list(out1 = "mse", out2 = "mae")
64+
)
65+
expect_equal(args_mixed$loss$out1, "mocked_loss_mae")
66+
expect_equal(args_mixed$loss$out2, dummy_loss_obj_2)
67+
}
68+
)
69+
})
70+
71+
test_that("collect_compile_args handles named list of metrics (multi-output) correctly", {
72+
testthat::with_mocked_bindings(
73+
.env = as.environment("package:kerasnip"),
74+
get_keras_object = mock_get_keras_object,
75+
{
76+
# Test case: Named list of metrics with mixed types (character and object)
77+
dummy_metric_obj_3 <- structure(list(), class = "dummy_metric_3")
78+
args_mixed_metrics <- collect_compile_args(
79+
all_args = list(
80+
compile_metrics = list(out1 = "accuracy", out2 = dummy_metric_obj_3)
81+
),
82+
learn_rate = 0.01,
83+
default_loss = list(out1 = "mse", out2 = "mae"),
84+
default_metrics = list(out1 = "mse", out2 = "mae") # Important: default_metrics must be a named list for this path
85+
)
86+
expect_equal(args_mixed_metrics$metrics$out1, "mocked_metric_accuracy")
87+
expect_equal(args_mixed_metrics$metrics$out2, dummy_metric_obj_3)
88+
89+
# Test case: Named list of metrics with all characters
90+
args_all_char_metrics <- collect_compile_args(
91+
all_args = list(
92+
compile_metrics = list(out1 = "accuracy", out2 = "mse")
93+
),
94+
learn_rate = 0.01,
95+
default_loss = list(out1 = "mse", out2 = "mae"),
96+
default_metrics = list(out1 = "mse", out2 = "mae")
97+
)
98+
expect_equal(args_all_char_metrics$metrics$out1, "mocked_metric_accuracy")
99+
expect_equal(args_all_char_metrics$metrics$out2, "mocked_metric_mse")
100+
101+
# Test case: Named list of metrics with all objects
102+
dummy_metric_obj_4 <- structure(list(), class = "dummy_metric_4")
103+
dummy_metric_obj_5 <- structure(list(), class = "dummy_metric_5")
104+
args_all_obj_metrics <- collect_compile_args(
105+
all_args = list(
106+
compile_metrics = list(out1 = dummy_metric_obj_4, out2 = dummy_metric_obj_5)
107+
),
108+
learn_rate = 0.01,
109+
default_loss = list(out1 = "mse", out2 = "mae"),
110+
default_metrics = list(out1 = "mse", out2 = "mae")
111+
)
112+
expect_equal(args_all_obj_metrics$metrics$out1, dummy_metric_obj_4)
113+
expect_equal(args_all_obj_metrics$metrics$out2, dummy_metric_obj_5)
114+
}
115+
)
116+
})
117+
118+
test_that("collect_compile_args throws errors for invalid multi-output args", {
119+
# Case 4: Multi-output, invalid loss argument
120+
expect_error(
121+
collect_compile_args(
122+
all_args = list(compile_loss = list("a", "b")), # Unnamed list
123+
learn_rate = 0.01,
124+
default_loss = list(out1 = "mse", out2 = "mae"),
125+
default_metrics = list(out1 = "mse", out2 = "mae")
126+
),
127+
"For multiple outputs, 'compile_loss' must be a single string or a named list of losses."
128+
)
129+
130+
# Case 5: Multi-output, invalid metrics argument
131+
expect_error(
132+
collect_compile_args(
133+
all_args = list(compile_metrics = list("a", "b")), # Unnamed list
134+
learn_rate = 0.01,
135+
default_loss = list(out1 = "mse", out2 = "mae"),
136+
default_metrics = list(out1 = "mse", out2 = "mae")
137+
),
138+
"For multiple outputs, 'compile_metrics' must be a single string or a named list of metrics."
139+
)
140+
})
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
skip_if_no_keras()
2+
3+
# Mock object for post-processing functions
4+
mock_object_single_output <- list(
5+
fit = list(
6+
lvl = c("setosa", "versicolor", "virginica") # For classification levels
7+
)
8+
)
9+
class(mock_object_single_output) <- "model_fit"
10+
11+
mock_object_multi_output <- list(
12+
fit = list(
13+
lvl = list(
14+
output1 = c("classA", "classB"),
15+
output2 = c("typeX", "typeY", "typeZ")
16+
)
17+
)
18+
)
19+
class(mock_object_multi_output) <- "model_fit"
20+
21+
# --- Tests for keras_postprocess_numeric ---
22+
23+
test_that("keras_postprocess_numeric handles single output (matrix) correctly", {
24+
results <- matrix(c(0.1, 0.2, 0.3), ncol = 1)
25+
processed <- keras_postprocess_numeric(results, mock_object_single_output)
26+
expect_s3_class(processed, "tbl_df")
27+
expect_equal(names(processed), ".pred")
28+
expect_equal(processed$.pred, c(0.1, 0.2, 0.3))
29+
})
30+
31+
test_that("keras_postprocess_numeric handles single output (named list with one element) correctly", {
32+
results <- list(output1 = matrix(c(0.1, 0.2, 0.3), ncol = 1))
33+
names(results) <- "output1"
34+
processed <- keras_postprocess_numeric(results, mock_object_multi_output)
35+
expect_s3_class(processed, "tbl_df")
36+
expect_equal(names(processed), ".pred")
37+
expect_equal(processed$.pred, matrix(c(0.1, 0.2, 0.3), ncol = 1)) # Changed expected
38+
})
39+
40+
41+
test_that("keras_postprocess_numeric handles multi-output (named list) correctly", {
42+
results <- list(
43+
output1 = matrix(c(0.1, 0.2), ncol = 1),
44+
output2 = matrix(c(0.4, 0.5), ncol = 1)
45+
)
46+
names(results) <- c("output1", "output2")
47+
processed <- keras_postprocess_numeric(results, mock_object_multi_output)
48+
expect_s3_class(processed, "tbl_df")
49+
expect_equal(names(processed), c(".pred_output1", ".pred_output2"))
50+
# Change expected values to 1-column matrices
51+
expect_equal(processed$.pred_output1, matrix(c(0.1, 0.2), ncol = 1))
52+
expect_equal(processed$.pred_output2, matrix(c(0.4, 0.5), ncol = 1))
53+
})
54+
55+
# --- Tests for keras_postprocess_probs ---
56+
57+
test_that("keras_postprocess_probs handles single output (matrix) correctly", {
58+
results <- matrix(c(0.1, 0.9, 0.0, # Example probabilities for 3 classes
59+
0.2, 0.1, 0.7,
60+
0.3, 0.3, 0.4), ncol = 3, byrow = TRUE)
61+
processed <- keras_postprocess_probs(results, mock_object_single_output)
62+
expect_s3_class(processed, "tbl_df")
63+
expect_equal(names(processed), c("setosa", "versicolor", "virginica")) # Updated expected names
64+
expect_equal(processed$setosa, c(0.1, 0.2, 0.3)) # Access by correct column name
65+
expect_equal(processed$versicolor, c(0.9, 0.1, 0.3)) # Access by correct column name
66+
expect_equal(processed$virginica, c(0.0, 0.7, 0.4)) # Access by correct column name
67+
})
68+
69+
test_that("keras_postprocess_probs handles multi-output (named list) correctly", {
70+
results <- list(
71+
output1 = matrix(c(0.1, 0.9, 0.2, 0.8), ncol = 2, byrow = TRUE),
72+
output2 = matrix(c(0.3, 0.4, 0.3, 0.5, 0.2, 0.3), ncol = 3, byrow = TRUE)
73+
)
74+
names(results) <- c("output1", "output2")
75+
processed <- keras_postprocess_probs(results, mock_object_multi_output)
76+
expect_s3_class(processed, "tbl_df")
77+
expect_equal(names(processed), c(".pred_output1_classA", ".pred_output1_classB", ".pred_output2_typeX", ".pred_output2_typeY", ".pred_output2_typeZ"))
78+
expect_equal(processed$.pred_output1_classA, c(0.1, 0.2))
79+
expect_equal(processed$.pred_output2_typeX, c(0.3, 0.5))
80+
})
81+
82+
test_that("keras_postprocess_probs handles multi-output with NULL levels fallback", {
83+
results <- list(
84+
output1 = matrix(c(0.1, 0.9, 0.2, 0.8), ncol = 2, byrow = TRUE)
85+
)
86+
names(results) <- "output1"
87+
mock_object_null_lvl <- list(
88+
fit = list(
89+
lvl = list(output1 = NULL) # Simulate NULL levels for this output
90+
)
91+
)
92+
class(mock_object_null_lvl) <- "model_fit"
93+
processed <- keras_postprocess_probs(results, mock_object_null_lvl)
94+
expect_s3_class(processed, "tbl_df")
95+
expect_equal(names(processed), c(".pred_output1_class1", ".pred_output1_class2"))
96+
})
97+
98+
# --- Tests for keras_postprocess_classes ---
99+
100+
test_that("keras_postprocess_classes handles single output (multiclass) correctly", {
101+
results <- matrix(c(0.1, 0.8, 0.1, 0.2, 0.1, 0.7), ncol = 3, byrow = TRUE)
102+
processed <- keras_postprocess_classes(results, mock_object_single_output)
103+
expect_s3_class(processed, "tbl_df")
104+
expect_equal(names(processed), ".pred_class")
105+
expect_equal(as.character(processed$.pred_class), c("versicolor", "virginica"))
106+
expect_true(is.factor(processed$.pred_class))
107+
expect_equal(levels(processed$.pred_class), c("setosa", "versicolor", "virginica"))
108+
})
109+
110+
test_that("keras_postprocess_classes handles single output (binary) correctly", {
111+
results <- matrix(c(0.6, 0.4), ncol = 1) # Changed to single column
112+
mock_object_binary_lvl <- list(
113+
fit = list(
114+
lvl = c("negative", "positive")
115+
)
116+
)
117+
class(mock_object_binary_lvl) <- "model_fit"
118+
processed <- keras_postprocess_classes(results, mock_object_binary_lvl)
119+
expect_s3_class(processed, "tbl_df")
120+
expect_equal(names(processed), ".pred_class")
121+
expect_equal(as.character(processed$.pred_class), c("positive", "negative")) # Changed expected
122+
expect_true(is.factor(processed$.pred_class))
123+
expect_equal(levels(processed$.pred_class), c("negative", "positive"))
124+
})
125+
126+
test_that("keras_postprocess_classes handles multi-output (named list) correctly", {
127+
results <- list(
128+
output1 = matrix(c(0.1, 0.9, 0.2, 0.8), ncol = 2, byrow = TRUE), # Binary
129+
output2 = matrix(c(0.3, 0.4, 0.3, 0.5, 0.2, 0.3), ncol = 3, byrow = TRUE) # Multiclass
130+
)
131+
names(results) <- c("output1", "output2")
132+
processed <- keras_postprocess_classes(results, mock_object_multi_output)
133+
expect_s3_class(processed, "tbl_df")
134+
expect_equal(names(processed), c(".pred_class_output1", ".pred_class_output2"))
135+
expect_equal(as.character(processed$.pred_class_output1), c("classB", "classB"))
136+
expect_equal(as.character(processed$.pred_class_output2), c("typeY", "typeX"))
137+
expect_true(is.factor(processed$.pred_class_output1))
138+
expect_true(is.factor(processed$.pred_class_output2))
139+
expect_equal(levels(processed$.pred_class_output1), c("classA", "classB"))
140+
expect_equal(levels(processed$.pred_class_output2), c("typeX", "typeY", "typeZ"))
141+
})
142+
143+
test_that("keras_postprocess_classes handles multi-output with NULL levels fallback", {
144+
results <- list(
145+
output1 = matrix(c(0.6, 0.4, 0.2, 0.8), ncol = 2, byrow = TRUE) # Binary
146+
)
147+
names(results) <- "output1"
148+
mock_object_null_lvl <- list(
149+
fit = list(
150+
lvl = list(output1 = NULL) # Simulate NULL levels for this output
151+
)
152+
)
153+
class(mock_object_null_lvl) <- "model_fit"
154+
processed <- keras_postprocess_classes(results, mock_object_null_lvl)
155+
expect_s3_class(processed, "tbl_df")
156+
expect_equal(names(processed), c(".pred_class_output1"))
157+
expect_equal(as.character(processed$.pred_class_output1), c("class1", "class2")) # Changed expected
158+
expect_true(is.factor(processed$.pred_class_output1))
159+
expect_equal(levels(processed$.pred_class_output1), c("class1", "class2"))
160+
})
161+
162+
test_that("keras_postprocess_classes handles multi-output (binary, single column) correctly", {
163+
results <- list(
164+
output1 = matrix(c(0.6, 0.4, 0.2, 0.8), ncol = 1, byrow = TRUE) # Single column binary output
165+
)
166+
names(results) <- "output1"
167+
mock_object_multi_output_binary <- list(
168+
fit = list(
169+
lvl = list(output1 = c("negative", "positive")) # Levels for binary output
170+
)
171+
)
172+
class(mock_object_multi_output_binary) <- "model_fit"
173+
processed <- keras_postprocess_classes(results, mock_object_multi_output_binary)
174+
expect_s3_class(processed, "tbl_df")
175+
expect_equal(names(processed), c(".pred_class_output1"))
176+
expect_equal(as.character(processed$.pred_class_output1), c("positive", "negative", "negative", "positive")) # Expected based on 0.5 threshold
177+
expect_true(is.factor(processed$.pred_class_output1))
178+
expect_equal(levels(processed$.pred_class_output1), c("negative", "positive"))
179+
})

0 commit comments

Comments
 (0)