Skip to content

Commit 0eeb716

Browse files
wsmosesavik-palgdalle
authored
Add AutoReactant{AutoEnzyme} (#127)
* Add AutoReactant{AutoEnzyme} * Update index.md with new mode documentation Added documentation for forward, reverse, or sparse mode. * Update src/dense.jl Co-authored-by: Avik Pal <avik.pal.2017@gmail.com> * Fix documentation for AutoReactant mode parameter * Update dense.jl * Update parameterization description in AutoReactant Clarified the description of the 'mode' field in AutoReactant. * Apply suggestion from @gdalle --------- Co-authored-by: Avik Pal <avik.pal.2017@gmail.com> Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 1f78b26 commit 0eeb716

File tree

7 files changed

+78
-3
lines changed

7 files changed

+78
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors"]
4-
version = "1.18.0"
4+
version = "1.19.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ AutoChainRules
5555
AutoDiffractor
5656
```
5757

58+
### Forward, reverse, or sparse mode
59+
60+
```@docs
61+
AutoReactant{<:AutoEnzyme}
62+
```
63+
5864
### Symbolic mode
5965

6066
```@docs

src/ADTypes.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ export AutoChainRules,
4747
AutoTracker,
4848
AutoZygote,
4949
NoAutoDiff,
50-
NoAutoDiffSelectedError
50+
NoAutoDiffSelectedError,
51+
AutoReactant
5152
@public AbstractMode
5253
@public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode
5354
@public mode

src/dense.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,45 @@ function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A}
8282
print(io, ")")
8383
end
8484

85+
86+
"""
87+
AutoReactant{M<:AutoEnzyme}
88+
89+
Struct used to select the [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) compilation atop Enzyme for automatic differentiation.
90+
91+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
92+
93+
# Constructors
94+
95+
AutoReactant(; mode::Union{AutoEnzyme,Nothing}=nothing)
96+
97+
# Fields
98+
99+
- `mode::M` specifies the parameterization of differentiation. It can be:
100+
101+
+ an [`AutoEnzyme`](@ref) object if a specific mode is required
102+
+ `nothing` to choose the best mode automatically
103+
"""
104+
struct AutoReactant{M<:AutoEnzyme} <: AbstractADType
105+
mode::M
106+
end
107+
108+
function AutoReactant(;
109+
mode::Union{AutoEnzyme,Nothing} = nothing)
110+
if mode === nothing
111+
mode = AutoEnzyme()
112+
end
113+
return AutoReactant(mode)
114+
end
115+
116+
mode(r::AutoReactant) = mode(r.mode)
117+
118+
function Base.show(io::IO, backend::AutoReactant)
119+
print(io, AutoReactant, "(")
120+
print(io, "mode=", repr(backend.mode; context = io))
121+
print(io, ")")
122+
end
123+
85124
"""
86125
AutoFastDifferentiation
87126

src/symbols.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ ADTypes.AutoZygote()
2222
"""
2323
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)
2424

25-
for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
25+
for backend in (:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
2626
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :Mooncake, :PolyesterForwardDiff,
2727
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
2828
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(

test/dense.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,34 @@ end
5252
@test ad.mode == EnzymeCore.Reverse
5353
end
5454

55+
@testset "AutoReactant" begin
56+
ad = AutoReactant()
57+
@test ad isa AbstractADType
58+
@test ad isa AutoReactant{<:AutoEnzyme}
59+
@test ad.mode isa AutoEnzyme
60+
@test ad.mode.mode === nothing
61+
@test mode(ad) isa ForwardOrReverseMode
62+
63+
ad = AutoReactant(; mode=AutoEnzyme(; mode = EnzymeCore.Forward))
64+
@test ad isa AbstractADType
65+
@test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Forward), Nothing}}
66+
@test mode(ad) isa ForwardMode
67+
@test ad.mode.mode == EnzymeCore.Forward
68+
69+
ad = AutoReactant(; mode=AutoEnzyme(; function_annotation = EnzymeCore.Const))
70+
@test ad isa AbstractADType
71+
@test ad isa AutoReactant{<:AutoEnzyme{Nothing, EnzymeCore.Const}}
72+
@test mode(ad) isa ForwardOrReverseMode
73+
@test ad.mode.mode === nothing
74+
75+
ad = AutoReactant(; mode=AutoEnzyme(;
76+
mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated))
77+
@test ad isa AbstractADType
78+
@test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated}}
79+
@test mode(ad) isa ReverseMode
80+
@test ad.mode.mode == EnzymeCore.Reverse
81+
end
82+
5583
@testset "AutoFastDifferentiation" begin
5684
ad = AutoFastDifferentiation()
5785
@test ad isa AbstractADType

test/symbols.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using Test
1111
@test ADTypes.Auto(:Mooncake) isa AutoMooncake
1212
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
1313
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff
14+
@test ADTypes.Auto(:Reactant) isa AutoReactant
1415
@test ADTypes.Auto(:Symbolics) isa AutoSymbolics
1516
@test ADTypes.Auto(:Tapir) isa AutoTapir
1617
@test ADTypes.Auto(:Tracker) isa AutoTracker

0 commit comments

Comments
 (0)