Skip to content

Commit 04becee

Browse files
authored
Merge pull request #13 from davidrsch/getting-starte-updt
Getting started update
2 parents 8a0351d + 420aa6c commit 04becee

19 files changed

+435
-257
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ jobs:
6262
if: runner.os == 'Linux'
6363
run: |
6464
sudo apt-get update
65-
sudo apt-get install -y qpdf ghostscript
65+
sudo apt-get install -y qpdf ghostscript graphviz
66+
if: runner.os == 'macOS'
67+
run: brew install graphviz
68+
if: runner.os == 'Windows'
69+
run: choco install graphviz -y
6670

6771
- uses: r-lib/actions/setup-r-dependencies@v2
6872
with:

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ docs
77
.httr-oauth
88
.DS_Store
99
.quarto
10+
vignettes/*_cache

R/generic_functional_fit.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,15 @@ generic_functional_fit <- function(
128128

129129
# --- Get Repetition Count ---
130130
num_repeats_arg <- paste0("num_", block_name)
131-
num_repeats <- all_args[[num_repeats_arg]] %||% 1
131+
num_repeats_val <- all_args[[num_repeats_arg]]
132+
133+
# If num_repeats_val is NULL or zapped, default to 1.
134+
# Otherwise, use the value provided by the user.
135+
if (is.null(num_repeats_val) || inherits(num_repeats_val, "rlang_zap")) {
136+
num_repeats <- 1
137+
} else {
138+
num_repeats <- as.integer(num_repeats_val)
139+
}
132140

133141
# --- Get Hyperparameters for this block ---
134142
# Hyperparameters are formals that are NOT other block names (graph connections)

R/generic_sequential_fit.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,14 @@ generic_sequential_fit <- function(
118118

119119
num_repeats_arg <- paste0("num_", block_name)
120120
num_repeats_val <- all_args[[num_repeats_arg]]
121-
num_repeats <- num_repeats_val %||% 1
121+
122+
# If num_repeats_val is NULL or zapped, default to 1.
123+
# Otherwise, use the value provided by the user.
124+
if (is.null(num_repeats_val) || inherits(num_repeats_val, "rlang_zap")) {
125+
num_repeats <- 1
126+
} else {
127+
num_repeats <- as.integer(num_repeats_val)
128+
}
122129

123130
# Get the arguments for this specific block from `...`
124131
block_arg_names <- names(block_fmls)[-1] # Exclude 'model'

R/register_fit_predict.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
5757
func = c(fun = "predict"),
5858
args = list(
5959
object = rlang::expr(object$fit$fit),
60-
x = rlang::expr(as.matrix(new_data))
60+
x = rlang::expr(process_x(new_data)$x_proc)
6161
)
6262
)
6363
)
@@ -74,7 +74,7 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
7474
func = c(fun = "predict"),
7575
args = list(
7676
object = rlang::expr(object$fit$fit),
77-
x = rlang::expr(as.matrix(new_data))
77+
x = rlang::expr(process_x(new_data)$x_proc)
7878
)
7979
)
8080
)
@@ -89,7 +89,7 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
8989
func = c(fun = "predict"),
9090
args = list(
9191
object = rlang::expr(object$fit$fit),
92-
x = rlang::expr(as.matrix(new_data))
92+
x = rlang::expr(process_x(new_data)$x_proc)
9393
)
9494
)
9595
)

R/utils.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ process_x <- function(x) {
220220
#' @importFrom keras3 to_categorical
221221
#' @noRd
222222
process_y <- function(y, is_classification = NULL, class_levels = NULL) {
223+
# If y is a data frame/tibble, extract the first column
224+
if (is.data.frame(y)) {
225+
y <- y[[1]]
226+
}
227+
223228
if (is.null(is_classification)) {
224229
is_classification <- is.factor(y)
225230
}

_pkgdown.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ guides:
1717
- title: "Getting Started"
1818
navbar: ~
1919
contents:
20-
- getting-started
21-
- functional-api
20+
- getting_started
21+
- functional_api
2222

2323
# examples:
2424

@@ -63,7 +63,7 @@ navbar:
6363
components:
6464
intro:
6565
text: "Getting started"
66-
href: guides/getting-started.html
66+
href: guides/getting_started.html
6767
github:
6868
icon: fa-github
6969
href: https://github.com/davidrsch/kerasnip

tests/testthat/helper-keras.R renamed to tests/testthat/helper_keras.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ library(modeldata)
66
library(rsample)
77
library(dials)
88
library(tune)
9+
library(purrr)
910

1011
skip_if_no_keras <- function() {
1112
testthat::skip_if_not_installed("keras3")

tests/testthat/test-e2e-features.R renamed to tests/testthat/test_e2e_features.R

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,19 @@ test_that("E2E: Customizing fit arguments works", {
102102
expect_lt(length(fit_obj$fit$history$metrics$loss), 5)
103103
})
104104

105-
test_that("E2E: Setting num_blocks = 0 works", {
105+
test_that("E2E: Setting num_blocks = 0 works for sequential models", {
106106
skip_if_no_keras()
107107

108108
input_block_zero <- function(model, input_shape) {
109109
keras3::keras_model_sequential(input_shape = input_shape)
110110
}
111111
dense_block_zero <- function(model, units = 16) {
112-
model |> keras3::layer_dense(units = units, activation = "relu")
112+
model |>
113+
keras3::layer_dense(
114+
units = units,
115+
activation = "relu",
116+
name = "i_should_not_exist"
117+
)
113118
}
114119
output_block_zero <- function(model) {
115120
model |> keras3::layer_dense(units = 1)
@@ -128,10 +133,18 @@ test_that("E2E: Setting num_blocks = 0 works", {
128133
mode = "regression"
129134
)
130135

131-
spec <- e2e_mlp_zero(num_dense = 0, fit_epochs = 2) |>
136+
spec <- e2e_mlp_zero(num_dense = 0, fit_epochs = 1) |>
132137
parsnip::set_engine("keras")
133-
# This should fit a model with only an input and output layer
134-
expect_no_error(parsnip::fit(spec, mpg ~ ., data = mtcars))
138+
139+
fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars)
140+
141+
# Check that the dense layer is NOT in the model
142+
keras_model <- fit_obj |> extract_keras_summary()
143+
expect_equal(length(keras_model$layers), 1) # Output layers only
144+
145+
# Check layer names explicitly
146+
layer_names <- sapply(keras_model$layers, function(l) l$name)
147+
expect_false("i_should_not_exist" %in% layer_names)
135148
})
136149

137150
test_that("E2E: Error handling for reserved names works", {

0 commit comments

Comments
 (0)