Commit 4e04b27
Completing the adaptation of the getting started vignette to mirror the one at keras3
1 parent 10d52af commit 4e04b27

File tree

2 files changed: +105 -104 lines changed

vignettes/getting-started.Rmd

Lines changed: 105 additions & 104 deletions
@@ -29,7 +29,7 @@ This new function behaves just like any other `parsnip` model (e.g., `rand_fores

You can install the development version of `kerasnip` from GitHub. You will also need `keras3` and a backend (like TensorFlow).

-``` r
+```{r}
# install.packages("pak")
pak::pak("davidrsch/kerasnip")
pak::pak("rstudio/keras3")
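Beyond the two packages above, the TensorFlow backend itself still needs to be installed once. That step is not part of this diff; a typical way to do it with `keras3` is shown below.

```r
# Assumed extra step, not shown in this commit: install the TensorFlow backend
# once after installing keras3.
keras3::install_keras()
```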
@@ -48,7 +48,15 @@ library(keras3)

## A `kerasnip` MNIST Example

-Let's replicate the standard Keras introductory example, an MLP on the MNIST dataset, but using the `kerasnip` workflow. This will show how to translate a standard Keras model into a reusable, modular `parsnip` specification.
+Let’s replicate the classic Keras introductory example, training a simple MLP on the MNIST dataset, but using the `kerasnip` workflow. This will demonstrate how to translate a standard Keras model into a reusable, modular `parsnip` specification.
+
+If you’re familiar with Keras, you’ll recognize the structure; if not, this is a perfect place to start. We’ll begin by learning the basics through a simple task: recognizing handwritten digits from the MNIST dataset.
+
+The MNIST dataset contains 28×28 pixel grayscale images of handwritten digits, like these:
+
+![MNIST](images/MNIST.png){fig-alt="A picture showing grayscale images of handwritten digits (5, 0, 4 and 1)"}
+
+Each image comes with a label indicating which digit it represents. For example, the labels for the images above might be 5, 0, 4, and 1.

### Preparing the Data

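The data-preparation code itself is collapsed in this view (the next hunk resumes at the `train_df` line). For orientation, here is a sketch of what that step usually looks like with `keras3`, assuming the standard MNIST recipe; the object names follow the context lines of the next hunk, but the exact code is not part of this commit.

```r
# Illustrative sketch, not part of the commit: the standard keras3 MNIST
# preparation that the collapsed lines broadly correspond to.
library(keras3)

mnist <- dataset_mnist()

# Flatten each 28x28 image into a 784-length vector and rescale to [0, 1].
x_train <- array_reshape(mnist$train$x, c(nrow(mnist$train$x), 784)) / 255
x_test <- array_reshape(mnist$test$x, c(nrow(mnist$test$x), 784)) / 255

# parsnip expects a data frame, so the outcome becomes a factor and the
# predictor matrix is wrapped in I() to keep it as a single column.
y_train_factor <- factor(mnist$train$y)
y_test_factor <- factor(mnist$test$y)

train_df <- data.frame(x = I(x_train), y = y_train_factor)
test_df <- data.frame(x = I(x_test), y = y_test_factor)
```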
@@ -79,20 +87,47 @@ train_df <- data.frame(x = I(x_train), y = y_train_factor)
test_df <- data.frame(x = I(x_test), y = y_test_factor)
```

+### The Standard Keras Approach (for comparison)
+
+Before diving into the `kerasnip` workflow, let's quickly look at how this same model is built using standard `keras3` code. This will help highlight the different approach `kerasnip` enables.
+
+```{r keras-standard, eval=FALSE, echo=TRUE, results='hide'}
+# The standard Keras3 approach
+model <- keras_model_sequential(input_shape = 784) |>
+  layer_dense(units = 256, activation = "relu") |>
+  layer_dropout(rate = 0.4) |>
+  layer_dense(units = 128, activation = "relu") |>
+  layer_dropout(rate = 0.3) |>
+  layer_dense(units = 10, activation = "softmax")
+
+summary(model)
+
+model |>
+  compile(
+    loss = "categorical_crossentropy",
+    optimizer = optimizer_rmsprop(),
+    metrics = "accuracy"
+  )
+
+# The model would then be trained with model |> fit(...)
+```
+
+The code above is imperative: you define each layer and add it to the model step-by-step. Now, let's see how `kerasnip` approaches this by defining reusable components for a declarative, `tidymodels`-friendly workflow.
+
### Defining the Model with Reusable Blocks

The original Keras example interleaves `layer_dense()` and `layer_dropout()`. With `kerasnip`, we can encapsulate this pattern into a single, reusable block. This makes the overall architecture cleaner and more modular.

```{r define-blocks}
# An input block to initialize the model.
+# The 'model' argument is supplied implicitly by the kerasnip backend.
mlp_input_block <- function(model, input_shape) {
  keras_model_sequential(input_shape = input_shape)
}

# A reusable "module" that combines a dense layer and a dropout layer.
-# This pattern can now be repeated easily.
-# default values for parameters must be set
-dense_dropout_block <- function(model, units=128, rate=0.1) {
+# All arguments that should be tunable need a default value.
+dense_dropout_block <- function(model, units = 128, rate = 0.1) {
  model |>
    layer_dense(units = units, activation = "relu") |>
    layer_dropout(rate = rate)
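Although it is not part of the commit, it can help to see that these blocks are plain functions: chaining the input block and two dense-dropout blocks by hand reproduces the imperative `keras3` model from the comparison section. The `NULL` passed for the unused `model` argument and the unit and rate values below are assumptions chosen to mirror that example.

```r
# Illustration only, not part of the commit: composing the blocks by hand.
# mlp_input_block() ignores its `model` argument, so NULL is passed here;
# the units and rates mirror the imperative keras3 example above.
model <- mlp_input_block(NULL, input_shape = 784) |>
  dense_dropout_block(units = 256, rate = 0.4) |>
  dense_dropout_block(units = 128, rate = 0.3) |>
  layer_dense(units = 10, activation = "softmax")

summary(model)
```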
@@ -121,7 +156,9 @@ create_keras_sequential_spec(

### Building and Fitting the Model

-We can now use our new `mnist_mlp()` function. To replicate the `keras3` example, we want to repeat our `hidden` block twice with different parameters. `kerasnip` makes this easy: we set `num_hidden = 2` and pass vectors for the `hidden_units` and `hidden_rate` arguments. `kerasnip` will supply the first value to the first instance of the block, the second value to the second instance, and so on.
+We can now use our new `mnist_mlp()` function. Notice how its arguments, such as `hidden_1_units` and `hidden_1_rate`, were automatically generated by `kerasnip`. The names are created by combining the name of the layer block (e.g., `hidden_1`) with the arguments of that block's function (e.g., `units`, `rate`).
+
+To replicate the `keras3` example, we'll use both `hidden` blocks and provide their parameters.

```{r use-spec}
mlp_spec <- mnist_mlp(
@@ -140,36 +177,64 @@ mlp_spec <- mnist_mlp(

# Fit the model
mlp_fit <- fit(mlp_spec, y ~ x, data = train_df)
-keras_model <- mlp_fit$fit$fit
-training_history <- mlp_fit$fit$history
```

```{r model-summarize}
-summary(keras_model)
+mlp_fit |>
+  extract_keras_summary()
```

```{r model-plot}
-plot(keras_model, show_shapes = TRUE)
+mlp_fit |>
+  extract_keras_summary() |>
+  plot(show_shapes = TRUE)
```

```{r model-fit-history}
-plot(training_history)
+mlp_fit |>
+  extract_keras_history() |>
+  plot()
```

-Evaluate the model’s performance on the test data: Evaluate method missing
+### Evaluating Model Performance
+
+The `keras_evaluate()` function provides a straightforward way to assess the model's performance on a test set, using the underlying `keras3::evaluate()` method. It returns the loss and any other metrics that were specified during the model compilation step.

```{r model-evaluate}
-# keras_model |> evaluate(x_test, y_test)
+mlp_fit |> keras_evaluate(x_test, y_test)
+```
+
+### Making Predictions
+
+Once the model is trained, we can use the standard `tidymodels` `predict()` function to generate predictions on new data. By default, `predict()` on a `parsnip` classification model returns the predicted class labels.
+
+```{r model-predict-class}
+# Predict the class for the first 5 images in the test set
+class_preds <- mlp_fit |>
+  predict(new_data = head(test_df))
+class_preds
```

-Generate predictions on new data:
+To get the underlying probabilities for each class, we can set `type = "prob"`. This returns a tibble with a probability column for each of the 10 classes (0-9).

-```{r model-predict}
-probs <- keras_model |> predict(x_test)
+```{r model-predict-prob}
+# Predict probabilities for the first 5 images
+prob_preds <- mlp_fit |> predict(new_data = head(test_df), type = "prob")
+prob_preds
```

-```{r show-predictions}
-max.col(probs) - 1L
+We can then compare the predicted class to the actual class for these images to see how the model is performing.
+
+```{r model-predict-compare}
+# Combine predictions with actuals for comparison
+comparison <- bind_cols(
+  class_preds,
+  prob_preds
+) |>
+  bind_cols(
+    head(test_df[, "y", drop = FALSE])
+  )
+comparison
```

## Example 2: Tuning the Model Architecture
@@ -180,11 +245,12 @@ Using the `mnist_mlp` spec we just created, let's define a tunable model.

```{r tune-spec-mnist}
# Define a tunable specification
+# We set num_hidden_2 = 0 to disable the second hidden block for this tuning example
tune_spec <- mnist_mlp(
  num_hidden_1 = tune(),
  hidden_1_units = tune(),
  hidden_1_rate = tune(),
-  num_hidden2 = 0,
+  num_hidden_2 = 0,
  compile_loss = "categorical_crossentropy",
  compile_optimizer = optimizer_rmsprop(),
  compile_metrics = c("accuracy"),
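The workflow, tuning grid, and `tune_grid()` call that consume this spec sit in the collapsed lines between this hunk and the next. As a rough sketch of what that setup usually looks like in `tidymodels`, the block below reuses the object names the diff refers to (`tune_wf`, `grid`, `tune_res`); the formula preprocessor, grid values, and resampling scheme are assumptions rather than the vignette's actual code.

```r
# Illustration only, not part of the commit: a typical tidymodels tuning setup
# around the spec above. Object names mirror those referenced in the diff;
# the formula, grid values, and resampling scheme are assumptions.
library(tidymodels)

tune_wf <- workflow() |>
  add_model(tune_spec) |>
  add_formula(y ~ x)

# A small assumed grid over depth, width, and dropout rate.
grid <- tidyr::crossing(
  num_hidden_1 = 1:2,
  hidden_1_units = c(64, 128),
  hidden_1_rate = c(0.2, 0.4)
)

folds <- vfold_cv(train_df, v = 3)

tune_res <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = grid,
  metrics = metric_set(accuracy)
)
```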
@@ -231,97 +297,32 @@ Finally, we can inspect the results to find which architecture performed the bes
show_best(tune_res, metric = "accuracy")
```

-Now, let's visualize the top 5 models from the tuning results in detail.
-
-```{r extract-top-models}
-# Get the top 5 results to iterate through
-top_5_results <- show_best(tune_res, metric = "accuracy") |>
-  select(all_of(names(grid)), .config)
-
-finalize_fit_tops <- function(parameters, workflow) {
-  finalize_workflow(x = workflow, parameters = parameters) |>
-    fit(train_df)
-}
-
-fited_tops <- 1:5 |>
-  map(\(x) finalize_fit_tops(parameters = top_5_results[x,], tune_wf))
+Now that we've identified the best-performing hyperparameters, our final step is to create and train the final model. We use `select_best()` to get the top parameters, `finalize_workflow()` to update our workflow with them, and then `fit()` one last time on our full training dataset.

-get_models <- function(fited_model){
-  fited_model$fit$fit$fit$fit
-}
+```{r finalize-best-model}
+# Select the best hyperparameters
+best_hps <- select_best(tune_res, metric = "accuracy")

-models <- fited_tops |> map(get_models)
+# Finalize the workflow with the best hyperparameters
+final_wf <- finalize_workflow(tune_wf, best_hps)

-get_fit_histories <- function(fited_model){
-  fited_model$fit$fit$fit$history
-}
-
-fit_histories <- fited_tops |> map(get_fit_histories)
-
-summary(models[[1]])
-plot(models[[1]], show_shapes = TRUE)
-plot(fit_histories[[1]])
+# Fit the final model on the full training data
+final_fit <- fit(final_wf, data = train_df)
```

-### Top 5 Model Summaries
-
-```{r tops-summary, results='asis'}
-# Loop through each model and print its summary
-for (i in 1:length(models)) {
-  if (i == 1) {
-    cat("::: {.grid}")
-  } else if (i%%2 == 0) {
-    cat("::: {.grid}")
-    cat("::: {.g-col-6}")
-  } else {
-    cat("::: {.g-col-6}")
-  }
-  cat(paste0("\n\n#### Rank ", i, " Model Summary\n\n"))
-  capture.output(summary(models[[i]])) |> cat(sep="\n")
-  if (i == 1 || i%%2 != 0) {
-    cat(":::")
-  } else {
-    cat(":::")
-    cat(":::")
-  }
-}
-```
-
-### Top 5 Model Architectures
-
-```{r tops-models-plot, fig.height=20}
-# Use par(mfrow) to create a grid for the base plots
-mat <- matrix(
-  c(1, 1, 2, 3, 4, 5),
-  nrow = 3,
-  ncol = 2,
-  byrow = TRUE
-)
-
-layout(mat = mat)
-
-for (i in 1:length(models)) {
-  plot(models[[i]], show_shapes = TRUE)
-  title(paste0("Rank ", i, " Model"))
-}
-```
-
-### Top 5 Training Histories
-
-```{r tops-models-fit-history, fig.height=12}
-# The history plots are ggplots, so we use patchwork to combine them
-library(patchwork)
-
-design <- "A#
-BC
-DE"
+We can now inspect our final, tuned model.

-plot_list <- purrr::map(1:length(fit_histories), \(i) {
-  plot(fit_histories[[i]]) + labs(title = paste("Rank", i, "History"))
-})
+```{r inspect-final-model}
+# Print the model summary
+final_fit |>
+  extract_fit_parsnip() |>
+  extract_keras_summary()

-# Combine all plots into a single image
-wrap_plots(plot_list, design = design)
+# Plot the training history
+final_fit |>
+  extract_fit_parsnip() |>
+  extract_keras_history() |>
+  plot()
```

-This result shows that `tune` has tested various network depths (`num_hidden`), widths (`hidden_units`), and dropout rates, successfully finding the best-performing combination within the search space. This demonstrates how `kerasnip` integrates complex architectural tuning directly into the standard `tidymodels` framework.
+This result shows that `tune` has tested various network depths, widths, and dropout rates, successfully finding the best-performing combination within the search space. By using `kerasnip`, we were able to integrate this complex architectural tuning directly into a standard `tidymodels` workflow.
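A natural follow-up, not shown in this excerpt, is to score the finalized workflow on the held-out test set. The sketch below relies only on standard `tidymodels` conventions (`predict()` on a fitted workflow, `.pred_class`, `yardstick::accuracy()`) and is illustrative rather than part of the commit.

```r
# Illustration only, not part of the commit: score the finalized workflow on
# the held-out test set using standard tidymodels tooling.
final_preds <- predict(final_fit, new_data = test_df) |>
  dplyr::bind_cols(test_df["y"])

yardstick::accuracy(final_preds, truth = y, estimate = .pred_class)
```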

vignettes/images/MNIST.png

16 KB