@@ -112,22 +112,22 @@ function B_mul_X!(b::Blocks{false}, B, n = 0)
112112 return
113113end
114114
# Diff hunk for `struct Constraint` ('-' lines are the pre-change text, '+' lines replace them).
# The struct bundles the arrays Y and BY, a `gram_chol` object, and two work arrays
# (gramYBV, tmp) that the Constraint functor slices with `view`.
# Change shown: the single array parameter TA is split in two, so that Y/BY may be either a
# vector or a matrix (TVorM) while the work arrays gramYBV/tmp are always matrices (TM).
# NOTE(review): the extra type parameter changes every explicit `Constraint{...}` spelling;
# the two constructors visible in this diff are updated to pass four parameters.
115- struct Constraint{T, TA <: AbstractArray {T} , TC}
116- Y:: TA
117- BY:: TA
115+ struct Constraint{T, TVorM <: Union{AbstractVector{T}, AbstractMatrix{T}} , TM <: AbstractMatrix {T} , TC}
116+ Y:: TVorM
117+ BY:: TVorM
118118 gram_chol:: TC
119- gramYBV:: TA # to be used in view
120- tmp:: TA # to be used in view
119+ gramYBV:: TM # to be used in view
120+ tmp:: TM # to be used in view
121121end
# Constructor method selected when the first argument is of type Void. B and X are ignored:
# every array field is an empty 0x0 Matrix{Void} and gram_chol is `nothing`.
# The '+' line adds the second Matrix{Void} type parameter so the explicit parameter list
# matches the four-parameter struct (T, TVorM, TM, TC).
122122function Constraint (:: Void , B, X)
123- return Constraint {Void, Matrix{Void}, Void} (Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ), nothing , Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ))
123+ return Constraint {Void, Matrix{Void}, Matrix{Void}, Void} (Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ), nothing , Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ))
124124end
# General constructor: builds BY (Y itself when B is Void, otherwise the product written by
# A_mul_B!(BY, B, Y)), the Gram product Ac_mul_B(Y, BY), and two work matrices of size
# (size(Y, 2), size(X, 2)) with element type eltype(X).
125125function Constraint (Y, B, X)
126126 T = eltype (X)
127127 if B isa Void
# The removed line assigned to B instead of BY, leaving BY undefined on this branch even
# though Ac_mul_B(Y, BY) reads it below; the '+' line assigns BY.
128- B = Y
128+ BY = Y
129129 else
# The removed line allocated BY with the shape/type of B; the '+' line allocates it like Y,
# which is the right-hand operand of the A_mul_B!(BY, B, Y) call that fills it.
130- BY = similar (B )
130+ BY = similar (Y )
131131 A_mul_B! (BY, B, Y)
132132 end
133133 gramYBY = Ac_mul_B (Y, BY)
# Hunk boundary: the lines between the two hunks (where gramYBY_chol is presumably computed
# from gramYBY) are not part of this diff.
@@ -136,21 +136,22 @@ function Constraint(Y, B, X)
136136 gramYBV = zeros (T, size (Y, 2 ), size (X, 2 ))
137137 tmp = similar (gramYBV)
138138
# The '+' line spells out all four type parameters instead of relying on inference.
# NOTE(review): the first parameter is eltype(Y) while gramYBV is allocated with
# T = eltype(X); since the struct requires TM <: AbstractMatrix{T}, this presumably needs
# eltype(X) == eltype(Y) -- confirm callers guarantee it.
# NOTE(review): TVorM is typeof(Y) for both Y and BY; confirm similar(Y) always has that type.
139- return Constraint (Y, BY, gramYBY_chol, gramYBV, tmp)
139+ return Constraint {eltype(Y), typeof(Y), typeof(gramYBV), typeof(gramYBY_chol)} (Y, BY, gramYBY_chol, gramYBV, tmp)
140140end
141141
# Functor method for a Constraint whose first type parameter is Void: does nothing and
# returns `nothing`, leaving both arguments untouched.
# The '+' line adds the X_temp argument so this method accepts the same call shape as the
# general Constraint functor.
142- function (constr!:: Constraint{Void} )(X)
142+ function (constr!:: Constraint{Void} )(X, X_temp )
143143 nothing
144144end
# Functor method for a general Constraint. With the '+' lines it updates X in place to
# X - Y * (gram_chol \ (BY' * X)), using X_temp as scratch for the subtracted term.
# Returns `nothing`.
146- function (constr!:: Constraint )(X)
146+ function (constr!:: Constraint )(X, X_temp )
147147 sizeX = size (X, 2 )
148148 sizeY = size (constr!. Y, 2 )
# Slice the preallocated work arrays down to sizeY x sizeX.
# NOTE(review): assumes size(X, 2) does not exceed the column count the buffers were
# allocated with -- confirm against callers.
149149 gramYBV_view = view (constr!. gramYBV, 1 : sizeY, 1 : sizeX)
150150 Ac_mul_B! (gramYBV_view, constr!. BY, X)
151151 tmp_view = view (constr!. tmp, 1 : sizeY, 1 : sizeX)
# The removed lines referenced a bare `gram_chol` (not the struct field) and overwrote X
# with Y * tmp_view. The '+' lines read constr!.gram_chol, write Y * tmp_view into X_temp,
# and then subtract X_temp from X.
152- A_ldiv_B! (tmp_view, gram_chol, gramYBV_view)
153- A_mul_B! (X, constr!. Y, tmp_view)
152+ A_ldiv_B! (tmp_view, constr!. gram_chol, gramYBV_view)
153+ A_mul_B! (X_temp, constr!. Y, tmp_view)
# NOTE(review): X and X_temp must have identical sizes for this broadcast -- confirm.
154+ @inbounds X .= X .- X_temp
154155
155156 nothing
156157end
@@ -200,9 +201,9 @@ PAP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlo
# One-line helpers: each calls Ac_mul_B! to write (second argument)' * (third argument) into
# a view of a BlockGram field (XAP, XAR, RAP), restricted to the first n columns as indexed.
200201XBP! (BlockGram, XBlocks, PBlocks, n) = Ac_mul_B! (view (BlockGram. XAP, :, 1 : n), XBlocks. block, view (PBlocks. B_block, :, 1 : n))
201202XBR! (BlockGram, XBlocks, RBlocks, n) = Ac_mul_B! (view (BlockGram. XAR, :, 1 : n), XBlocks. block, view (RBlocks. B_block, :, 1 : n))
202203RBP! (BlockGram, RBlocks, PBlocks, n) = Ac_mul_B! (view (BlockGram. RAP, 1 : n, 1 : n), view (RBlocks. B_block, :, 1 : n), view (PBlocks. block, :, 1 : n))
# The change below disables XBX!, RBR! and PBP! by commenting them out.
# NOTE(review): if these helpers are no longer called anywhere, prefer deleting them over
# leaving commented-out code -- confirm there are no remaining callers.
203- XBX! (BlockGram, XBlocks) = Ac_mul_B! (BlockGram. XAX, XBlocks. block, XBlocks. B_block)
204- RBR! (BlockGram, RBlocks, n) = Ac_mul_B! (view (BlockGram. RAR, 1 : n, 1 : n), view (RBlocks. block, :, 1 : n), view (RBlocks. B_block, :, 1 : n))
205- PBP! (BlockGram, PBlocks, n) = Ac_mul_B! (view (BlockGram. PAP, 1 : n, 1 : n), view (PBlocks. block, :, 1 : n), view (PBlocks. B_block, :, 1 : n))
204+ # XBX!(BlockGram, XBlocks) = Ac_mul_B!(BlockGram.XAX, XBlocks.block, XBlocks.B_block)
205+ # RBR!(BlockGram, RBlocks, n) = Ac_mul_B!(view(BlockGram.RAR, 1:n, 1:n), view(RBlocks.block, :, 1:n), view(RBlocks.B_block, :, 1:n))
206+ # PBP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlocks.block, :, 1:n), view(PBlocks.B_block, :, 1:n))
206207
207208function I! (G, xr)
208209 @inbounds for j in xr, i in xr
@@ -242,24 +243,24 @@ function (g::BlockGram)(gram, n1::Int, n2::Int, n3::Int, normalized::Bool=true)
242243 if n1 > 0
243244 if normalized
244245 I! (gram, xr)
245- else
246- @inbounds gram[xr, xr] .= view (g. XAX, 1 : n1, 1 : n1)
246+ # else
247+ # @inbounds gram[xr, xr] .= view(g.XAX, 1:n1, 1:n1)
247248 end
248249 end
249250 if n2 > 0
250251 if normalized
251252 I! (gram, rr)
252- else
253- @inbounds gram[rr, rr] .= view (g. RAR, 1 : n2, 1 : n2)
253+ # else
254+ # @inbounds gram[rr, rr] .= view(g.RAR, 1:n2, 1:n2)
254255 end
255256 @inbounds gram[xr, rr] .= view (g. XAR, 1 : n1, 1 : n2)
256257 @inbounds conj! (transpose! (view (gram, rr, xr), view (g. XAR, 1 : n1, 1 : n2)))
257258 end
258259 if n3 > 0
259260 if normalized
260261 I! (gram, pr)
261- else
262- @inbounds gram[pr, pr] .= view (g. PAP, 1 : n3, 1 : n3)
262+ # else
263+ # @inbounds gram[pr, pr] .= view(g.PAP, 1:n3, 1:n3)
263264 end
264265 @inbounds gram[rr, pr] .= view (g. RAP, 1 : n2, 1 : n3)
265266 @inbounds gram[xr, pr] .= view (g. XAP, 1 : n1, 1 : n3)
@@ -463,10 +464,10 @@ function update_active!(mask, bs::Int, blockPairs...)
463464 return
464465end
465466
# Applies precond! and then constr! to the first bs columns of `block`, in that order.
# The '+' lines add a `temp_block` argument whose first bs columns are handed to constr!
# as its second argument, matching the two-argument Constraint functor.
# NOTE(review): temp_block is inserted as the second positional parameter, so every caller
# must be updated; the two call sites visible in this diff pass iterator.tempXBlocks.block.
466- function precond_constr! (block, bs, precond!, constr!)
467+ function precond_constr! (block, temp_block, bs, precond!, constr!)
467468 precond! (view (block, :, 1 : bs))
468469 # Constrain the active residual vectors to be B-orthogonal to Y
469- constr! (view (block, :, 1 : bs))
470+ constr! (view (block, :, 1 : bs), view (temp_block, :, 1 : bs) )
470471 return
471472end
472473function block_grams_1x1! (iterator)
@@ -525,7 +526,6 @@ function sub_problem!(iterator, sizeX, bs1, bs2)
525526 selectperm! (view (iterator. λperm, 1 : subdim), eigf. values, 1 : subdim, rev= iterator. largest)
526527 @inbounds iterator. ritz_values[1 : sizeX] .= view (eigf. values, view (iterator. λperm, 1 : sizeX))
527528 @inbounds iterator. V[1 : subdim, 1 : sizeX] .= view (eigf. vectors, :, view (iterator. λperm, 1 : sizeX))
528-
529529 return
530530end
531531
@@ -594,7 +594,6 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
594594 sizeX = size (iterator. XBlocks. block, 2 )
595595 iteration = iterator. iteration[]
596596 if iteration == 1
597- iterator. constr! (iterator. XBlocks. block)
598597 ortho_AB_mul_X! (iterator. XBlocks, iterator. ortho!, iterator. A, iterator. B)
599598 # Finds gram matrix X'AX
600599 block_grams_1x1! (iterator)
@@ -608,7 +607,7 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
608607 # Update active R blocks
609608 update_active! (iterator. activeMask, bs, (iterator. activeRBlocks. block, iterator. RBlocks. block))
610609 # Precondition and constrain the active residual vectors
611- precond_constr! (iterator. activeRBlocks. block, bs, iterator. precond!, iterator. constr!)
610+ precond_constr! (iterator. activeRBlocks. block, iterator . tempXBlocks . block, bs, iterator. precond!, iterator. constr!)
612611 # Orthonormalizes R[:,1:bs] and finds AR[:,1:bs] and BR[:,1:bs]
613612 ortho_AB_mul_X! (iterator. activeRBlocks, iterator. ortho!, iterator. A, iterator. B, bs)
613612 # Find [X R]' A [X R] and [X R]' B [X R]
@@ -628,7 +627,7 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
628627 (iterator. activePBlocks. A_block, iterator. PBlocks. A_block),
629628 (iterator. activePBlocks. B_block, iterator. PBlocks. B_block))
630629 # Precondition and constrain the active residual vectors
631- precond_constr! (iterator. activeRBlocks. block, bs, iterator. precond!, iterator. constr!)
630+ precond_constr! (iterator. activeRBlocks. block, iterator . tempXBlocks . block, bs, iterator. precond!, iterator. constr!)
632631 # Orthonormalizes R[:,1:bs] and finds AR[:,1:bs] and BR[:,1:bs]
633632 ortho_AB_mul_X! (iterator. activeRBlocks, iterator. ortho!, iterator. A, iterator. B, bs)
634633 # Orthonormalizes P and updates AP
@@ -766,16 +765,18 @@ end
766765function lobpcg! (iterator:: LOBPCGIterator ; log= false , tol= nothing , maxiter= 200 , not_zeros= false )
767766 T = eltype (iterator. XBlocks. block)
768767 X = iterator. XBlocks. block
768+ iterator. constr! (iterator. XBlocks. block, iterator. tempXBlocks. block)
769769 if ! not_zeros
770770 for j in 1 : size (X,2 )
771771 if all (x -> x== 0 , view (X, :, j))
772772 @inbounds X[:,j] .= rand .()
773773 end
774774 end
775+ iterator. constr! (iterator. XBlocks. block, iterator. tempXBlocks. block)
775776 end
776777 n = size (X, 1 )
777778 sizeX = size (X, 2 )
778- residualTolerance = (tol isa Void) ? sqrt (eps (real (T))) : tol
779+ residualTolerance = (tol isa Void) ? (eps (real (T))) ^ ( real (T)( 4 ) / 10 ) : tol
779780 iterator. iteration[] = 1
780781 while iterator. iteration[] <= maxiter
781782 state = iterator (residualTolerance, log)
0 commit comments