Skip to content

Commit 71bfaec

Browse files
authored
Support r2r (#1)
* Support r2r * overload mul! * Use matrix transforms * add matrix tests * add tests * Update runtests.jl * Update Project.toml
1 parent e3bcf26 commit 71bfaec

File tree

4 files changed

+74
-58
lines changed

4 files changed

+74
-58
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.0.1"
66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
910

1011
[compat]
1112
AbstractFFTs = "1"
@@ -14,8 +15,8 @@ ForwardDiff = "0.10"
1415
julia = "1.6"
1516

1617
[extras]
18+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1719
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18-
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1920

2021
[targets]
21-
test = ["Test", "FFTW"]
22+
test = ["Test", "LinearAlgebra"]

src/FastTransformsForwardDiff.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
module FastTransformsForwardDiff
2-
using ForwardDiff
3-
import AbstractFFTs
2+
using ForwardDiff, FFTW
3+
using AbstractFFTs
44
import ForwardDiff: value, partials, npartials, Dual, tagtype, derivative, jacobian, gradient
5+
import AbstractFFTs: plan_fft, plan_ifft, plan_bfft, plan_rfft, plan_brfft, plan_irfft
6+
import FFTW: r2r, r2r!, plan_r2r, mul!, Plan
57

68
@inline tagtype(::Complex{T}) where T = tagtype(T)
79
@inline tagtype(::Type{Complex{T}}) where T = tagtype(T)

src/fft.jl

Lines changed: 19 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x)
2+
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x)))
3+
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x))
4+
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))
5+
16
value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value)
27

38
partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n))
@@ -12,70 +17,32 @@ AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V
1217
AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x)
1318
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d)
1419

15-
for plan in [:plan_fft, :plan_ifft, :plan_bfft]
20+
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
1621
@eval begin
17-
18-
AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
19-
AbstractFFTs.$plan(value.(x), region)
20-
21-
AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) =
22-
AbstractFFTs.$plan(value.(x), region)
23-
22+
$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
23+
$plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
2424
end
2525
end
2626

27-
# rfft only accepts real arrays
28-
AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
29-
AbstractFFTs.plan_rfft(value.(x), region)
27+
plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)
28+
plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)
3029

31-
for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex?
30+
for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
3231
@eval begin
33-
34-
AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
35-
AbstractFFTs.$plan(value.(x), region)
36-
37-
AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) =
38-
AbstractFFTs.$plan(value.(x), d, region)
39-
32+
$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
33+
$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims)
4034
end
4135
end
4236

43-
# for f in (:dct, :idct)
44-
# pf = Symbol("plan_", f)
45-
# @eval begin
46-
# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x
47-
# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x
48-
# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...)
49-
# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...)
50-
# end
51-
# end
37+
r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x
38+
r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x
5239

5340

54-
for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities
41+
for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities
5542
@eval begin
56-
57-
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) =
58-
_apply_plan(p, x)
59-
60-
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) =
61-
_apply_plan(p, x)
62-
43+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x))
44+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x))
6345
end
6446
end
6547

66-
function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
67-
xtil = p * value.(x)
68-
dxtils = ntuple(npartials(eltype(x))) do n
69-
p * partials.(x, n)
70-
end
71-
__apply_plan(tagtype(eltype(x)), xtil, dxtils)
72-
end
73-
74-
function __apply_plan(T, xtil, dxtils)
75-
map(xtil, dxtils...) do val, parts...
76-
Complex(
77-
Dual{T}(real(val), map(real, parts)),
78-
Dual{T}(imag(val), map(imag, parts)),
79-
)
80-
end
81-
end
48+
mul!(y::AbstractArray{<:Dual}, p::Plan, x::AbstractArray{<:Dual}) = copyto!(y, p*x)

test/runtests.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
1-
using FastTransformsForwardDiff, FFTW, Test
1+
using FastTransformsForwardDiff, FFTW, LinearAlgebra, Test
22
using ForwardDiff: Dual, valtype, value, partials, derivative
33
using AbstractFFTs: complexfloat, realfloat
44

5+
@testset "complex dual" begin
6+
x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.)
7+
@test value(x) == 1 + 4im
8+
@test partials(x,1) == 2 + 5im
9+
@test partials(x,2) == 3 + 6im
10+
end
511

612
@testset "fft and rfft" begin
713
x1 = Dual.(1:4.0, 2:5, 3:6)
814

915
@test value.(x1) == 1:4
1016
@test partials.(x1, 1) == 2:5
17+
@test partials.(x1, 2) == 3:6
1118

1219
@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im
1320
@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0)
1421

1522
@test fft(x1, 1)[1] isa Complex{<:Dual}
1623

17-
@testset "$f" for f in [fft, ifft, rfft, bfft]
24+
@testset "$f" for f in (fft, ifft, rfft, bfft)
1825
@test value.(f(x1)) == f(value.(x1))
1926
@test partials.(f(x1), 1) == f(partials.(x1, 1))
27+
@test partials.(f(x1), 2) == f(partials.(x1, 2))
2028
end
2129

30+
@test ifft(fft(x1)) == x1
31+
@test irfft(rfft(x1), length(x1)) x1
32+
@test brfft(rfft(x1), length(x1)) 4x1
33+
2234
f = x -> real(fft([x; 0; 0])[1])
2335
@test derivative(f,0.1) 1
2436

@@ -33,4 +45,38 @@ using AbstractFFTs: complexfloat, realfloat
3345

3446
# c = x -> dct([x; 0; 0])[1]
3547
# @test derivative(c,0.1) ≈ 1
48+
49+
@testset "matrix" begin
50+
A = x1 * (1:10)'
51+
@test value.(fft(A)) == fft(value.(A))
52+
@test partials.(fft(A), 1) == fft(partials.(A, 1))
53+
@test partials.(fft(A), 2) == fft(partials.(A, 2))
54+
55+
@test value.(fft(A, 1)) == fft(value.(A), 1)
56+
@test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1)
57+
@test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1)
58+
59+
@test value.(fft(A, 2)) == fft(value.(A), 2)
60+
@test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2)
61+
@test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
62+
end
63+
end
64+
65+
@testset "r2r" begin
66+
x1 = Dual.(1:4.0, 2:5, 3:6)
67+
t = FFTW.r2r(x1, FFTW.R2HC)
68+
69+
@test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC)
70+
@test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC)
71+
@test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC)
72+
73+
t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC)
74+
@test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC)
75+
@test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC)
76+
@test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC)
77+
78+
f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1]
79+
@test derivative(f, 0.1) 1.0
80+
81+
@test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC)
3682
end

0 commit comments

Comments
 (0)