|
41 | 41 | ex = Expr(:if, ex.args...) |
42 | 42 | end |
43 | 43 |
|
| 44 | +# Handle special case of Column-pivoted QR fallback for LU |
| 45 | +function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted) |
| 46 | + setfield!(cache, :QRFactorizationPivoted, v) |
| 47 | +end |
| 48 | + |
44 | 49 | # Legacy fallback |
45 | 50 | # For SciML algorithms already using `defaultalg`, all assume square matrix. |
46 | 51 | defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true)) |
@@ -352,11 +357,32 @@ end |
352 | 357 | kwargs...) |
353 | 358 | ex = :() |
354 | 359 | for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) |
355 | | - newex = quote |
356 | | - sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) |
357 | | - SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; |
358 | | - retcode = sol.retcode, |
359 | | - iters = sol.iters, stats = sol.stats) |
| 360 | + if alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, |
| 361 | + DefaultAlgorithmChoice.RFLUFactorization, |
| 362 | + DefaultAlgorithmChoice.MKLLUFactorization, |
| 363 | + DefaultAlgorithmChoice.AppleAccelerateLUFactorization, |
| 364 | + DefaultAlgorithmChoice.GenericLUFactorization)) |
| 365 | + newex = quote |
| 366 | + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) |
| 367 | + if sol.retcode === ReturnCode.Failure && alg.safetyfallback |
| 368 | + ## TODO: Add verbosity logging here about using the fallback |
| 369 | + sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...) |
| 370 | + SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; |
| 371 | + retcode = sol.retcode, |
| 372 | + iters = sol.iters, stats = sol.stats) |
| 373 | + else |
| 374 | + SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; |
| 375 | + retcode = sol.retcode, |
| 376 | + iters = sol.iters, stats = sol.stats) |
| 377 | + end |
| 378 | + end |
| 379 | + else |
| 380 | + newex = quote |
| 381 | + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) |
| 382 | + SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; |
| 383 | + retcode = sol.retcode, |
| 384 | + iters = sol.iters, stats = sol.stats) |
| 385 | + end |
360 | 386 | end |
361 | 387 | alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg) |
362 | 388 | ex = if ex == :() |
|
0 commit comments