Commit f5386f1

clean losses
1 parent d317492 commit f5386f1

File tree: 6 files changed, +144 −145 lines changed

docs/make.jl

Lines changed: 5 additions & 0 deletions

````diff
@@ -1,6 +1,11 @@
 using Documenter, Flux, NNlib, Functors, MLUtils
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
+DocMeta.setdocmeta!(Flux.Losses, :DocTestSetup, :(using Flux.Losses); recursive = true)
+
+# In the Losses module, doctests which differ in the printed Float32 values won't fail
+DocMeta.setdocmeta!(Flux.Losses, :DocTestFilters, :(r"[0-9\.]+f0"); recursive = true)
+
 makedocs(modules = [Flux, NNlib, Functors, MLUtils],
          doctest = false,
          sitename = "Flux",
````

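The `DocTestFilters` regex added above strips printed `Float32` literals (e.g. `0.3026f0`) from doctest output before comparison, so value-only differences don't fail the build. A small sketch of what that pattern matches (illustrative, not part of the commit):

```julia
# The filter pattern from the commit
pat = r"[0-9\.]+f0"

occursin(pat, "0.3026f0")              # true: a printed Float32 value
occursin(pat, "0.3026")                # false: no `f0` suffix, left untouched
replace("loss = 0.3026f0", pat => "")  # "loss = "
```
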
docs/src/models/advanced.md

Lines changed: 2 additions & 1 deletion

````diff
@@ -213,13 +213,14 @@ model(gpu(rand(10)))
 A custom loss function for the multiple outputs may look like this:
 ```julia
 using Statistics
+using Flux.Losses: mse
 
 # assuming model returns the output of a Split
 # x is a single input
 # ys is a tuple of outputs
 function loss(x, ys, model)
   # rms over all the mse
   ŷs = model(x)
-  return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
+  return sqrt(mean(mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
 end
 ```
````

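A hypothetical usage sketch of the multi-output loss from the docs page above, with a plain closure standing in for a real `Split` model (the `toy_model` name and shapes are illustrative, not from the commit):

```julia
using Statistics
using Flux.Losses: mse

# Stand-in for a model ending in a Split: returns a tuple of outputs.
toy_model = x -> (2 .* x, x .- 1)

# rms over the per-output mse values, as in the docs example
loss(x, ys, model) = sqrt(mean(mse(y, ŷ) for (y, ŷ) in zip(ys, model(x))))

x  = rand(Float32, 10)
ys = (2 .* x, x .- 1)    # targets equal to the outputs, so the loss is zero
loss(x, ys, toy_model)   # ≈ 0.0f0
```
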
docs/src/models/losses.md

Lines changed: 18 additions & 26 deletions

````diff
@@ -3,8 +3,16 @@
 Flux provides a large number of common loss functions used for training machine learning models.
 They are grouped together in the `Flux.Losses` module.
 
-Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ`.
-In Flux's convention, the order of the arguments is the following
+As an example, the crossentropy function for multi-class classification that takes logit predictions (i.e. not [`softmax`](@ref)ed)
+can be imported with
+
+```julia
+using Flux.Losses: logitcrossentropy
+```
+
+Loss functions for supervised learning typically expect as inputs a true target `y` and a prediction `ŷ`,
+typically passed as arrays of size `num_target_features x num_examples_in_batch`.
+In Flux's convention, the order of the arguments is the following:
 
 ```julia
 loss(ŷ, y)
@@ -14,32 +22,16 @@ Most loss functions in Flux have an optional argument `agg`, denoting the type of
 batch:
 
 ```julia
-loss(ŷ, y)                        # defaults to `mean`
-loss(ŷ, y, agg=sum)               # use `sum` for reduction
-loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction
-loss(ŷ, y, agg=x->mean(w .* x))   # weighted mean
-loss(ŷ, y, agg=identity)          # no aggregation.
+loss(ŷ, y)                             # defaults to `mean`
+loss(ŷ, y, agg = sum)                  # use `sum` for reduction
+loss(ŷ, y, agg = x -> sum(x, dims=2))  # partial reduction
+loss(ŷ, y, agg = x -> mean(w .* x))    # weighted mean
+loss(ŷ, y, agg = identity)             # no aggregation.
 ```
 
 ## Losses Reference
 
-```@docs
-Flux.Losses.mae
-Flux.Losses.mse
-Flux.Losses.msle
-Flux.Losses.huber_loss
-Flux.Losses.label_smoothing
-Flux.Losses.crossentropy
-Flux.Losses.logitcrossentropy
-Flux.Losses.binarycrossentropy
-Flux.Losses.logitbinarycrossentropy
-Flux.Losses.kldivergence
-Flux.Losses.poisson_loss
-Flux.Losses.hinge_loss
-Flux.Losses.squared_hinge_loss
-Flux.Losses.dice_coeff_loss
-Flux.Losses.tversky_loss
-Flux.Losses.binary_focal_loss
-Flux.Losses.focal_loss
-Flux.Losses.siamese_contrastive_loss
+```@autodocs
+Modules = [Flux.Losses]
+Pages = ["functions.jl"]
 ```
````
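As a quick sketch of how the `agg` keyword documented above behaves with a concrete loss such as `mse` (shapes here are illustrative): `agg` is applied to the array of elementwise losses, so it controls both the reduction and the shape of the result.

```julia
using Flux.Losses: mse

ŷ = rand(Float32, 3, 4)   # 3 target features, batch of 4
y = rand(Float32, 3, 4)

mse(ŷ, y)                               # scalar: mean of (ŷ .- y).^2
mse(ŷ, y, agg = sum)                    # scalar: sum instead of mean
mse(ŷ, y, agg = x -> sum(x, dims=2))    # 3×1 array: reduce over the batch only
mse(ŷ, y, agg = identity)               # 3×4 array of per-element losses
```
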

src/Flux.jl

Lines changed: 1 addition & 2 deletions

````diff
@@ -51,9 +51,8 @@ include("outputsize.jl")
 include("data/Data.jl")
 using .Data
 
-
 include("losses/Losses.jl")
-using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
+using .Losses # TODO: stop importing Losses in Flux's namespace in v0.14?
 
 include("deprecations.jl")
 
````