Skip to content

Commit 8c5bcbc

Browse files
Copilotyebai
andauthored
Implement forward rule (frule!!) for find_alpha with integer arguments in Mooncake (#407)
* Initial plan * Initial analysis and plan for implementing find_alpha forward rule Co-authored-by: yebai <3279477+yebai@users.noreply.github.com> * Update BijectorsMooncakeExt.jl * Update Project.toml --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Hong Ge <hg344@cam.ac.uk> Co-authored-by: yebai <3279477+yebai@users.noreply.github.com>
1 parent 7cf56cf commit 8c5bcbc

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.9"
3+
version = "0.15.10"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

ext/BijectorsMooncakeExt.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,33 @@ using Bijectors: find_alpha, ChainRulesCore
1515
# unusual Integer type is encountered.
1616
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})
1717

18-
# TODO: This needs a corresponding frule!! as well for it to work on forward-mode Mooncake.
18+
function Mooncake.frule!!(
19+
::Mooncake.Dual{typeof(find_alpha)},
20+
x::Mooncake.Dual{P},
21+
y::Mooncake.Dual{P},
22+
z::Mooncake.Dual{I},
23+
) where {P<:Base.IEEEFloat,I<:Integer}
24+
# Require that the integer is non-differentiable.
25+
if tangent_type(I) != Mooncake.NoTangent
26+
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
27+
throw(ArgumentError(msg))
28+
end
29+
# Convert Mooncake.NoTangent to ChainRulesCore.NoTangent for the integer argument
30+
out, tangent_out = ChainRulesCore.frule(
31+
(
32+
ChainRulesCore.NoTangent(),
33+
Mooncake.tangent(x),
34+
Mooncake.tangent(y),
35+
ChainRulesCore.NoTangent(),
36+
),
37+
find_alpha,
38+
Mooncake.primal(x),
39+
Mooncake.primal(y),
40+
Mooncake.primal(z),
41+
)
42+
return Mooncake.Dual(out, tangent_out)
43+
end
44+
1945
function Mooncake.rrule!!(
2046
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
2147
) where {P<:Base.IEEEFloat,I<:Integer}

test/ad/chainrules.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ end
3131

3232
if @isdefined Mooncake
3333
rng = Xoshiro(123456)
34-
# TODO: Enable Mooncake.ForwardMode as well.
35-
@testset "$mode" for mode in (Mooncake.ReverseMode,)
34+
@testset "$mode" for mode in (Mooncake.ReverseMode, Mooncake.ForwardMode)
3635
Mooncake.TestUtils.test_rule(
3736
rng,
3837
Bijectors.find_alpha,

0 commit comments

Comments
 (0)