You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
epochs="The total number of iterations to train the model.",
141
-
batch_size="The number of samples per gradient update.",
142
-
learn_rate="The learning rate for the default Adam optimizer. This is ignored if `compile_optimizer` is provided as a pre-built object.",
143
-
validation_split="The proportion of the training data to be used as a validation set.",
144
-
verbose="The level of verbosity for model fitting (0, 1, or 2)."
145
-
)
140
+
# Document special `learn_rate` param
146
141
param_docs<- c(
147
142
param_docs,
148
-
purrr::map_chr(global_params, function(p) {
149
-
paste0("@param ", p, "", global_param_desc[[p]])
150
-
})
143
+
"@param learn_rate The learning rate for the default Adam optimizer. This is ignored if `compile_optimizer` is provided as a pre-built Keras optimizer object."
151
144
)
152
145
153
146
# Document compile params
154
-
compile_param_desc<-list(
155
-
compile_loss="The loss function for compiling the model. Can be a string (e.g., 'mse') or a Keras loss object. Overrides the default.",
156
-
compile_optimizer="The optimizer for compiling the model. Can be a string (e.g., 'sgd') or a Keras optimizer object. Overrides the default.",
157
-
compile_metrics="A character vector of metrics to monitor during training (e.g., `c('mae', 'mse')`). Overrides the default."
158
-
)
159
-
param_docs<- c(
160
-
param_docs,
161
-
purrr::map_chr(compile_params, function(p) {
162
-
paste0("@param ", p, "", compile_param_desc[[p]])
163
-
})
164
-
)
147
+
if (length(compile_params) >0) {
148
+
param_docs<- c(
149
+
param_docs,
150
+
purrr::map_chr(compile_params, function(p) {
151
+
paste0(
152
+
"@param ",
153
+
p,
154
+
" Argument to `keras3::compile()`. See the 'Model Compilation' section."
155
+
)
156
+
})
157
+
)
158
+
}
159
+
160
+
# Document fit params
161
+
if (length(fit_params) >0) {
162
+
param_docs<- c(
163
+
param_docs,
164
+
purrr::map_chr(fit_params, function(p) {
165
+
paste0(
166
+
"@param ",
167
+
p,
168
+
" Argument to `keras3::fit()`. See the 'Model Fitting' section."
169
+
)
170
+
})
171
+
)
172
+
}
165
173
166
174
# Add ... param
167
175
param_docs<- c(
168
176
param_docs,
169
177
paste0(
170
-
"@param ... Additional arguments passed to the Keras engine. This is commonly used for arguments to `keras3::fit()` (prefixed with `fit_`).",
171
-
"See the 'Model Fitting' and 'Model Compilation' sections for details."
178
+
"@param ... Additional arguments passed to the Keras engine. Use this for arguments to `keras3::fit()` or `keras3::compile()`",
0 commit comments