@@ -884,6 +884,114 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32),
884884 end
885885end
886886
887+ # gesv
"""
    gesv!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0, tol=0.0,
          tol_inner=0.0) -> (X, info)

Run cuSOLVER's iterative-refinement routine `cusolverDnIRSXgesv` on the square matrix `A`
with right-hand side(s) `B`, passing `X` as the solution buffer.

NOTE(review): by LAPACK `gesv` naming this presumably solves `A * X = B`; whether `A` and
`B` are modified by the library is not visible here -- confirm in the cuSOLVER docs.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver fallback via
  `cusolverDnIRSParamsEnableFallback` / `cusolverDnIRSParamsDisableFallback`.
- `residual_history`: when `true`, request the residual on `info` via
  `cusolverDnIRSInfosRequestResidual`.
- `irs_precision`: name of the lowest precision handed to
  `cusolverDnIRSParamsSetSolverLowestPrecision`. `"AUTO"` selects the name matching `T`
  (`"R_32F"`, `"R_64F"`, `"C_32F"` or `"C_64F"`). Any other value must be one of the names
  accepted for `T` (see the table in the body).
- `refinement_solver`: forwarded to `cusolverDnIRSParamsSetRefinementSolver`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: forwarded to the corresponding
  `cusolverDnIRSParamsSet*` call only when non-zero; a zero leaves the parameter object
  untouched for that setting.

# Returns
The tuple `(X, info)`, where `info` is the `CuSolverIRSInformation` object used for the
call. The iteration count written by the library is not returned.

# Throws
- `DimensionMismatch` (from `checksquare`) if `A` is not square.
- `ErrorException` if `irs_precision` is neither `"AUTO"` nor accepted for `T`.
- Whatever `chklapackerror` raises for the status flag read back from the device.
"""
function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T};
               fallback::Bool=true,
               residual_history::Bool=false,
               irs_precision::String="AUTO",
               refinement_solver::String="CLASSICAL",
               maxiters::Int=0,
               maxiters_inner::Int=0,
               tol::Float64=0.0,
               tol_inner::Float64=0.0) where T <: BlasFloat

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    n = checksquare(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    # Precision names accepted for each element type. The FIRST entry is the name that
    # matches `T` itself and is what "AUTO" resolves to, so keep it first.
    supported = if T == Float32
        ("R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == Float64
        ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == ComplexF32
        ("C_32F", "C_16F", "C_16BF", "C_TF32")
    else # T == ComplexF64, the only remaining member of BlasFloat
        ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32")
    end
    if irs_precision == "AUTO"
        irs_precision = first(supported)
    elseif irs_precision ∉ supported
        error("$irs_precision is not supported.")
    end

    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)
    # A zero value means "not specified": the setter is skipped entirely.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    if fallback
        cusolverDnIRSParamsEnableFallback(params)
    else
        cusolverDnIRSParamsDisableFallback(params)
    end
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Workspace size query, evaluated by `with_workspace`.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver status flag back from `dh.info` (a scalar read, hence
    # `@allowscalar`) and let `chklapackerror` validate it.
    flag = @allowscalar dh.info[1]
    chklapackerror(BlasInt(flag))

    return X, info
end
940+
941+ # gels
"""
    gels!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0, tol=0.0,
          tol_inner=0.0) -> (X, info)

Run cuSOLVER's iterative-refinement routine `cusolverDnIRSXgels` on the `m × n` matrix `A`
with right-hand side(s) `B`, passing `X` as the solution buffer.

NOTE(review): by LAPACK `gels` naming this presumably computes a least-squares solution of
`A * X ≈ B`; any shape restrictions (e.g. on `m` vs `n`) and whether `A`/`B` are modified
by the library are not visible here -- confirm in the cuSOLVER docs.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver fallback via
  `cusolverDnIRSParamsEnableFallback` / `cusolverDnIRSParamsDisableFallback`.
- `residual_history`: when `true`, request the residual on `info` via
  `cusolverDnIRSInfosRequestResidual`.
- `irs_precision`: name of the lowest precision handed to
  `cusolverDnIRSParamsSetSolverLowestPrecision`. `"AUTO"` selects the name matching `T`
  (`"R_32F"`, `"R_64F"`, `"C_32F"` or `"C_64F"`). Any other value must be one of the names
  accepted for `T` (see the table in the body).
- `refinement_solver`: forwarded to `cusolverDnIRSParamsSetRefinementSolver`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: forwarded to the corresponding
  `cusolverDnIRSParamsSet*` call only when non-zero; a zero leaves the parameter object
  untouched for that setting.

# Returns
The tuple `(X, info)`, where `info` is the `CuSolverIRSInformation` object used for the
call. The iteration count written by the library is not returned.

# Throws
- `ErrorException` if `irs_precision` is neither `"AUTO"` nor accepted for `T`.
- Whatever `chklapackerror` raises for the status flag read back from the device.
"""
function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T};
               fallback::Bool=true,
               residual_history::Bool=false,
               irs_precision::String="AUTO",
               refinement_solver::String="CLASSICAL",
               maxiters::Int=0,
               maxiters_inner::Int=0,
               tol::Float64=0.0,
               tol_inner::Float64=0.0) where T <: BlasFloat

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    m, n = size(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    # Precision names accepted for each element type. The FIRST entry is the name that
    # matches `T` itself and is what "AUTO" resolves to, so keep it first.
    supported = if T == Float32
        ("R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == Float64
        ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32")
    elseif T == ComplexF32
        ("C_32F", "C_16F", "C_16BF", "C_TF32")
    else # T == ComplexF64, the only remaining member of BlasFloat
        ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32")
    end
    if irs_precision == "AUTO"
        irs_precision = first(supported)
    elseif irs_precision ∉ supported
        error("$irs_precision is not supported.")
    end

    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)
    # A zero value means "not specified": the setter is skipped entirely.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    if fallback
        cusolverDnIRSParamsEnableFallback(params)
    else
        cusolverDnIRSParamsDisableFallback(params)
    end
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Workspace size query, evaluated by `with_workspace`.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgels(dh, params, info, m, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver status flag back from `dh.info` (a scalar read, hence
    # `@allowscalar`) and let `chklapackerror` validate it.
    flag = @allowscalar dh.info[1]
    chklapackerror(BlasInt(flag))

    return X, info
end
994+
887995# LAPACK
888996for elty in (:Float32 , :Float64 , :ComplexF32 , :ComplexF64 )
889997 @eval begin
0 commit comments