@@ -41,7 +41,27 @@ function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(siz
4141 Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
4242 m, n, A, lda, ipiv, info)
4343 info[] < 0 && throw (ArgumentError (" Invalid arguments sent to LAPACK dgetrf_" ))
44- A, Vector {BlasInt} (ipiv), BlasInt (info[]) # Error code is stored in LU factorization type
44+ A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
45+ end
46+
47+ function aa_getrs! (trans:: AbstractChar , A:: AbstractMatrix{<:Float64} , ipiv:: AbstractVector{Cint} , B:: AbstractVecOrMat{<:Float64} ; info = Ref {Cint} ())
48+ require_one_based_indexing (A, ipiv, B)
49+ LinearAlgebra. LAPACK. chktrans (trans)
50+ chkstride1 (A, B, ipiv)
51+ n = LinearAlgebra. checksquare (A)
52+ if n != size (B, 1 )
53+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
54+ end
55+ if n != length (ipiv)
56+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
57+ end
58+ nrhs = size (B, 2 )
59+ ccall ((" dgetrs_" , libacc), Cvoid,
60+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
61+ Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
62+ trans, n, size (B,2 ), A, max (1 ,stride (A,2 )), ipiv, B, max (1 ,stride (B,2 )), info, 1 )
63+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
64+ B
4565end
4666
4767default_alias_A (:: AppleAccelerateLUFactorization , :: Any , :: Any ) = false
@@ -50,7 +70,8 @@ default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
5070function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A, b, u, Pl, Pr,
5171 maxiters:: Int , abstol, reltol, verbose:: Bool ,
5272 assumptions:: OperatorAssumptions )
53- ArrayInterface. lu_instance (convert (AbstractMatrix, A))
73+ luinst = ArrayInterface. lu_instance (convert (AbstractMatrix, A))
74+ LU (luinst. factors,similar (A, Cint, 0 ), luinst. info), Ref {Cint} ()
5475end
5576
5677function SciMLBase. solve! (cache:: LinearCache , alg:: AppleAccelerateLUFactorization ;
@@ -59,10 +80,23 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorizatio
5980 A = convert (AbstractMatrix, A)
6081 if cache. isfresh
6182 cacheval = @get_cacheval (cache, :AppleAccelerateLUFactorization )
62- fact = LU (aa_getrf! (A; ipiv = cacheval. ipiv)... )
83+ res = aa_getrf! (A; ipiv = cacheval[1 ]. ipiv, info = cacheval[2 ])
84+ fact = LU (res[1 : 3 ]. .. ), res[4 ]
6385 cache. cacheval = fact
6486 cache. isfresh = false
6587 end
66- y = ldiv! (cache. u, @get_cacheval (cache, :AppleAccelerateLUFactorization ), cache. b)
67- SciMLBase. build_linear_solution (alg, y, nothing , cache)
88+
89+ A, info = @get_cacheval (cache, :AppleAccelerateLUFactorization )
90+ LinearAlgebra. require_one_based_indexing (cache. u, cache. b)
91+ m, n = size (A, 1 ), size (A, 2 )
92+ if m > n
93+ Bc = copy (cache. b)
94+ aa_getrs! (' N' , A. factors, A. ipiv, Bc; info)
95+ return copyto! (cache. u, 1 , Bc, 1 , n)
96+ else
97+ copyto! (cache. u, cache. b)
98+ aa_getrs! (' N' , A. factors, A. ipiv, cache. u; info)
99+ end
100+
101+ SciMLBase. build_linear_solution (alg, cache. u, nothing , cache)
68102end
0 commit comments