@@ -13,38 +13,41 @@ struct MultiLevel{S, Pre, Post, TA, TP, TR, TW}
1313 workspace:: TW
1414end
1515
16- struct MultiLevelWorkspace{T , bs}
17- coarse_xs:: Vector{Vector{Vector{T}} }
18- coarse_bs:: Vector{Vector{Vector{T}} }
19- res_vecs:: Vector{Vector{Vector{T}} }
16+ struct MultiLevelWorkspace{TX , bs}
17+ coarse_xs:: Vector{TX }
18+ coarse_bs:: Vector{TX }
19+ res_vecs:: Vector{TX }
2020end
2121function MultiLevelWorkspace (:: Type{Val{bs}} , :: Type{T} ) where {bs, T<: Number }
22- MultiLevelWorkspace {T, bs} ( Vector{Vector{Vector{T}}}[],
23- Vector{Vector{Vector{T}}}[],
24- Vector{Vector{Vector{T}}}[])
22+ if bs === 1
23+ TX = Vector{T}
24+ else
25+ TX = Matrix{T}
26+ end
27+ MultiLevelWorkspace {TX, bs} (TX[], TX[], TX[])
2528end
26- Base. eltype (w:: MultiLevelWorkspace{T } ) where T = T
27- blocksize (w:: MultiLevelWorkspace{T , bs} ) where {T , bs} = bs
29+ Base. eltype (w:: MultiLevelWorkspace{TX } ) where TX = eltype (TX)
30+ blocksize (w:: MultiLevelWorkspace{TX , bs} ) where {TX , bs} = bs
2831
29- function residual! (m:: MultiLevelWorkspace{T , bs} , n) where {T , bs}
32+ function residual! (m:: MultiLevelWorkspace{TX , bs} , n) where {TX , bs}
3033 if bs === 1
31- push! (m. res_vecs, [ Vector {T} (undef, n) for _ in 1 : nthreads ()] )
34+ push! (m. res_vecs, TX (undef, n))
3235 else
33- push! (m. res_vecs, [ Vector {T} (undef, n, bs) for _ in 1 : nthreads ()] )
36+ push! (m. res_vecs, TX (undef, n, bs))
3437 end
3538end
36- function coarse_x! (m:: MultiLevelWorkspace{T , bs} , n) where {T , bs}
39+ function coarse_x! (m:: MultiLevelWorkspace{TX , bs} , n) where {TX , bs}
3740 if bs === 1
38- push! (m. coarse_xs, [ Vector {T} (undef, n) for _ in 1 : nthreads ()] )
41+ push! (m. coarse_xs, TX (undef, n))
3942 else
40- push! (m. coarse_xs, [ Vector {T} (undef, n, bs) for _ in 1 : nthreads ()] )
43+ push! (m. coarse_xs, TX (undef, n, bs))
4144 end
4245end
43- function coarse_b! (m:: MultiLevelWorkspace{T , bs} , n) where {T , bs}
46+ function coarse_b! (m:: MultiLevelWorkspace{TX , bs} , n) where {TX , bs}
4447 if bs === 1
45- push! (m. coarse_bs, [ Vector {T} (undef, n) for _ in 1 : nthreads ()] )
48+ push! (m. coarse_bs, TX (undef, n))
4649 else
47- push! (m. coarse_bs, [ Vector {T} (undef, n, bs) for _ in 1 : nthreads ()] )
50+ push! (m. coarse_bs, TX (undef, n, bs))
4851 end
4952end
5053
@@ -147,7 +150,7 @@ function solve!(x, ml::MultiLevel, b::AbstractArray{T},
147150 tol:: Float64 = 1e-5 ,
148151 verbose:: Bool = false ,
149152 log:: Bool = false ,
150- calculate_residual = false ) where {T}
153+ calculate_residual = true ) where {T}
151154
152155 A = length (ml) == 1 ? ml. final_A : ml. levels[1 ]. A
153156 V = promote_type (eltype (A), eltype (b))
@@ -184,14 +187,14 @@ function __solve!(x, ml, v::V, b, lvl)
184187 A = ml. levels[lvl]. A
185188 ml. presmoother (A, x, b)
186189
187- res = ml. workspace. res_vecs[lvl][ threadid ()]
190+ res = ml. workspace. res_vecs[lvl]
188191 mul! (res, A, x)
189192 reshape (res, size (b)) .= b .- reshape (res, size (b))
190193
191- coarse_b = ml. workspace. coarse_bs[lvl][ threadid ()]
194+ coarse_b = ml. workspace. coarse_bs[lvl]
192195 mul! (coarse_b, ml. levels[lvl]. R, res)
193196
194- coarse_x = ml. workspace. coarse_xs[lvl][ threadid ()]
197+ coarse_x = ml. workspace. coarse_xs[lvl]
195198 coarse_x .= 0
196199 if lvl == length (ml. levels)
197200 ml. coarse_solver (coarse_x, coarse_b)
0 commit comments