You can install the development version of `kerasnip` from GitHub. You will also need `keras3` and a backend (like TensorFlow).
```{r}
# install.packages("pak")
pak::pak("davidrsch/kerasnip")
pak::pak("rstudio/keras3")
```
## A `kerasnip` MNIST Example
Let’s replicate the classic Keras introductory example, training a simple MLP on the MNIST dataset, but using the `kerasnip` workflow. This will demonstrate how to translate a standard Keras model into a reusable, modular `parsnip` specification.
If you’re familiar with Keras, you’ll recognize the structure; if not, this is a perfect place to start. We’ll begin by learning the basics through a simple task: recognizing handwritten digits from the MNIST dataset.
The MNIST dataset contains 28×28 pixel grayscale images of handwritten digits, like these:
[Figure: grayscale images of handwritten digits (5, 0, 4 and 1)]
Each image comes with a label indicating which digit it represents. For example, the labels for the images above might be 5, 0, 4, and 1.
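Only the tail end of the data-preparation code survives here (the creation of `test_df` below). As a rough sketch of the preprocessing that fragment assumes, using `keras3::dataset_mnist()` to load the data; the vignette's actual code, and the matching `train_df` preparation, are not shown:

```{r}
# Sketch only: flatten each 28x28 image into a 784-length vector, rescale to [0, 1],
# and store the labels as a factor so parsnip treats the task as classification.
mnist <- dataset_mnist()
x_test <- array_reshape(mnist$test$x, c(nrow(mnist$test$x), 784)) / 255
y_test_factor <- factor(mnist$test$y, levels = 0:9)
```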
```{r}
test_df <- data.frame(x = I(x_test), y = y_test_factor)
```
### The Standard Keras Approach (for comparison)
Before diving into the `kerasnip` workflow, let's quickly look at how this same model is built using standard `keras3` code. This will help highlight the different approach `kerasnip` enables.
```{r}
model <- keras_model_sequential(input_shape = 784) |>
  layer_dense(units = 256, activation = "relu") |>
  layer_dropout(rate = 0.4) |>
  layer_dense(units = 128, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 10, activation = "softmax")

summary(model)

model |>
  compile(
    loss = "categorical_crossentropy",
    optimizer = optimizer_rmsprop(),
    metrics = "accuracy"
  )

# The model would then be trained with model |> fit(...)
```
The code above is imperative: you define each layer and add it to the model step-by-step. Now, let's see how `kerasnip` approaches this by defining reusable components for a declarative, `tidymodels`-friendly workflow.
### Defining the Model with Reusable Blocks
The original Keras example interleaves `layer_dense()` and `layer_dropout()`. With `kerasnip`, we can encapsulate this pattern into a single, reusable block. This makes the overall architecture cleaner and more modular.
```{r define-blocks}
# An input block to initialize the model.
# The 'model' argument is supplied implicitly by the kerasnip backend.
mlp_input_block <- function(model, input_shape) {
  keras_model_sequential(input_shape = input_shape)
}

# A reusable "module" that combines a dense layer and a dropout layer.
```
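The dense-plus-dropout block itself is not shown in this excerpt. As a minimal sketch, assuming `kerasnip`'s convention that a block receives the model as its first argument and returns it with layers appended (the name `mlp_hidden_block` and the `units`/`rate` defaults are illustrative):

```{r}
# Sketch only: a "hidden" block pairing a dense layer with a dropout layer.
mlp_hidden_block <- function(model, units = 32, rate = 0.1) {
  model |>
    layer_dense(units = units, activation = "relu") |>
    layer_dropout(rate = rate)
}
```

The step that registers these blocks and generates the `mnist_mlp()` specification function is also not shown here.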
We can now use our new `mnist_mlp()` function. Notice how its arguments, such as `hidden_1_units` and `hidden_1_rate`, were automatically generated by `kerasnip`. The names are created by combining the name of the layer block (e.g., `hidden_1`) with the arguments of that block's function (e.g., `units`, `rate`).
To replicate the `keras3` example, we'll use both `hidden` blocks and provide their parameters.
```{r use-spec}
mlp_spec <- mnist_mlp(
  # ... (remaining arguments elided in this excerpt) ...
)

# Fit the model
mlp_fit <- fit(mlp_spec, y ~ x, data = train_df)
```
```{r model-summarize}
mlp_fit |>
  extract_keras_summary()
```
```{r model-plot}
mlp_fit |>
  extract_keras_summary() |>
  plot(show_shapes = TRUE)
```
```{r model-fit-history}
mlp_fit |>
  extract_keras_history() |>
  plot()
```
### Evaluating Model Performance
The `keras_evaluate()` function provides a straightforward way to assess the model's performance on a test set, using the underlying `keras3::evaluate()` method. It returns the loss and any other metrics that were specified during the model compilation step.
```{r model-evaluate}
mlp_fit |> keras_evaluate(x_test, y_test)
```
### Making Predictions
Once the model is trained, we can use the standard `tidymodels` `predict()` function to generate predictions on new data. By default, `predict()` on a `parsnip` classification model returns the predicted class labels.
```{r model-predict-class}
# Predict the class for the first few images in the test set
class_preds <- mlp_fit |>
  predict(new_data = head(test_df))
class_preds
```
To get the underlying probabilities for each class, we can set `type = "prob"`. This returns a tibble with a probability column for each of the 10 classes (0-9).
```{r model-predict-prob}
# Predict probabilities for the first few images
prob_preds <- mlp_fit |> predict(new_data = head(test_df), type = "prob")
prob_preds
```
We can then compare the predicted class to the actual class for these images to see how the model is performing.
```{r model-predict-compare}
# Combine predictions with actuals for comparison
comparison <- bind_cols(
  class_preds,
  prob_preds
) |>
  bind_cols(
    head(test_df[, "y", drop = FALSE])
  )
comparison
```
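The comparison above is only a spot check on a handful of images. As an additional sketch (not part of the original example), the whole test set could be scored with `yardstick`:

```{r}
# Sketch only: class predictions for the full test set, scored with yardstick
full_preds <- predict(mlp_fit, new_data = test_df) |>
  bind_cols(test_df["y"])
yardstick::accuracy(full_preds, truth = y, estimate = .pred_class)
```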
## Example 2: Tuning the Model Architecture
Using the `mnist_mlp` spec we just created, let's define a tunable model.
```{r tune-spec-mnist}
# Define a tunable specification
# We set num_hidden_2 = 0 to disable the second hidden block for this tuning example
tune_spec <- mnist_mlp(
  num_hidden_1 = tune(),
  hidden_1_units = tune(),
  hidden_1_rate = tune(),
  num_hidden_2 = 0,
  compile_loss = "categorical_crossentropy",
  compile_optimizer = optimizer_rmsprop(),
  compile_metrics = c("accuracy"),
  # ... (remaining arguments elided in this excerpt) ...
)
```
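The resampling and grid-search code that produces `tune_res` falls outside this excerpt. A rough sketch of what that step could look like, assuming the spec is bundled into a workflow (here called `mlp_wflow`), 3-fold cross-validation, and a small grid; none of these specifics come from the original vignette:

```{r}
# Sketch only: bundle the tunable spec into a workflow and run a small grid search
mlp_wflow <- workflow() |>
  add_formula(y ~ x) |>
  add_model(tune_spec)

folds <- vfold_cv(train_df, v = 3)

tune_res <- tune_grid(
  mlp_wflow,
  resamples = folds,
  grid = 5
)
```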
Finally, we can inspect the results to find which architecture performed the best.
```{r}
show_best(tune_res, metric = "accuracy")
```
Now that we've identified the best-performing hyperparameters, our final step is to create and train the final model. We use `select_best()` to get the top parameters, `finalize_workflow()` to update our workflow with them, and then `fit()` one last time on our full training dataset.
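The corresponding code is not shown in this excerpt; a minimal sketch, assuming the tuned workflow object is named `mlp_wflow` (as in the sketch above):

```{r}
# Sketch only: finalize the workflow with the best parameters and refit on the full training data
best_params <- select_best(tune_res, metric = "accuracy")
final_fit <- mlp_wflow |>
  finalize_workflow(best_params) |>
  fit(data = train_df)
```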
```{r inspect-final-model}
# Print the model summary
final_fit |>
  extract_fit_parsnip() |>
  extract_keras_summary()

# Plot the training history
final_fit |>
  extract_fit_parsnip() |>
  extract_keras_history() |>
  plot()
```
This result shows that `tune` has tested various network depths, widths, and dropout rates, successfully finding the best-performing combination within the search space. By using `kerasnip`, we were able to integrate this complex architectural tuning directly into a standard `tidymodels` workflow.