33Flux provides a large number of common loss functions used for training machine learning models.
44They are grouped together in the ` Flux.Losses ` module.
55
6- Loss functions for supervised learning typically expect as inputs a target ` y ` , and a prediction ` ŷ ` .
7- In Flux's convention, the order of the arguments is the following
6+ As an example, the crossentropy function for multi-class classification that takes logit predictions (i.e. not [ ` softmax ` ] ( @ref ) ed)
7+ can be imported with
8+
9+ ``` julia
10+ using Flux. Losses: logitcrossentropy
11+ ```
12+
13+ Loss functions for supervised learning typically expect as inputs a true target ` y ` and a prediction ` ŷ ` ,
14+ typically passed as arrays of size ` num_target_features x num_examples_in_batch ` .
15+ In Flux's convention, the order of the arguments is the following:
816
917``` julia
1018loss (ŷ, y)
@@ -14,32 +22,16 @@ Most loss functions in Flux have an optional argument `agg`, denoting the type o
1422batch:
1523
1624``` julia
17- loss (ŷ, y) # defaults to `mean`
18- loss (ŷ, y, agg= sum) # use `sum` for reduction
19- loss (ŷ, y, agg= x -> sum (x, dims= 2 )) # partial reduction
20- loss (ŷ, y, agg= x -> mean (w .* x)) # weighted mean
21- loss (ŷ, y, agg= identity) # no aggregation.
25+ loss (ŷ, y) # defaults to `mean`
26+ loss (ŷ, y, agg = sum) # use `sum` for reduction
27+ loss (ŷ, y, agg = x -> sum (x, dims= 2 )) # partial reduction
28+ loss (ŷ, y, agg = x -> mean (w .* x)) # weighted mean
29+ loss (ŷ, y, agg = identity) # no aggregation.
2230```
2331
2432## Losses Reference
2533
26- ``` @docs
27- Flux.Losses.mae
28- Flux.Losses.mse
29- Flux.Losses.msle
30- Flux.Losses.huber_loss
31- Flux.Losses.label_smoothing
32- Flux.Losses.crossentropy
33- Flux.Losses.logitcrossentropy
34- Flux.Losses.binarycrossentropy
35- Flux.Losses.logitbinarycrossentropy
36- Flux.Losses.kldivergence
37- Flux.Losses.poisson_loss
38- Flux.Losses.hinge_loss
39- Flux.Losses.squared_hinge_loss
40- Flux.Losses.dice_coeff_loss
41- Flux.Losses.tversky_loss
42- Flux.Losses.binary_focal_loss
43- Flux.Losses.focal_loss
44- Flux.Losses.siamese_contrastive_loss
34+ ``` @autodocs
35+ Modules = [Flux.Losses]
36+ Pages = ["functions.jl"]
4537```
0 commit comments