@@ -5,10 +5,10 @@ using CUDAapi
55using CUDAdrv
66using CUDAdrv: CUstream
77
8- import CUDAnative
8+ using CUDAnative
99
1010using .. CuArrays
11- using .. CuArrays: active_context, unsafe_free!
11+ using .. CuArrays: unsafe_free!
1212using LinearAlgebra
1313
1414using CEnum
@@ -27,45 +27,62 @@ include("wrappers.jl")
2727# high-level integrations
2828include (" linalg.jl" )
2929
# Handles that have been created so far, keyed by the context they were created under.
const created_handles = IdDict{CuContext,cublasHandle_t}()
const created_xt_handles = IdDict{CuContext,cublasXtHandle_t}()

# One slot per thread, indexed by `Threads.threadid()`. A slot holding `nothing` means
# the handle has to be looked up (or created) on the next request. Both vectors start
# empty and are sized in `__init__`.
const active_handles = Union{Nothing,cublasHandle_t}[]
const active_xt_handles = Union{Nothing,cublasXtHandle_t}[]
3434
"""
    handle()

Return the CUBLAS handle for the calling thread.

The handle is looked up in the per-thread slot of `active_handles`. When that slot is
empty, the handle belonging to the current context is fetched from `created_handles`
(creating and configuring it on first use) and stored in the slot.
"""
function handle()
    tid = Threads.threadid()

    # fast path: the slot for this thread is already populated
    cached = @inbounds active_handles[tid]
    cached === nothing || return cached

    ctx = context()
    blas = get!(created_handles, ctx) do
        fresh = cublasCreate_v2()
        # destroy the handle at exit, but only if its context is still valid by then
        atexit(() -> CUDAdrv.isvalid(ctx) && cublasDestroy_v2(fresh))

        # enable tensor math mode if our device supports it, and fast math is enabled
        dev = CUDAdrv.device()
        fast_math = Base.JLOptions().fast_math == 1
        if fast_math && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
            cublasSetMathMode(CUBLAS_TENSOR_OP_MATH, fresh)
        end

        fresh
    end

    active_handles[tid] = blas
    return blas
end
5554
"""
    xt_handle()

Return the CUBLAS-XT handle for the calling thread.

The handle is looked up in the per-thread slot of `active_xt_handles`. When that slot is
empty, the handle belonging to the current context is fetched from `created_xt_handles`
(creating it and selecting devices on first use) and stored in the slot.
"""
function xt_handle()
    tid = Threads.threadid()

    # fast path: the slot for this thread is already populated
    cached = @inbounds active_xt_handles[tid]
    cached === nothing || return cached

    ctx = context()
    xt = get!(created_xt_handles, ctx) do
        fresh = cublasXtCreate()
        # destroy the handle at exit, but only if its context is still valid by then
        atexit(() -> CUDAdrv.isvalid(ctx) && cublasXtDestroy(fresh))

        # select the devices
        # TODO: this is weird, since we typically use a single device per thread/context
        device_ids = convert.(Cint, CUDAdrv.devices())
        cublasXtDeviceSelect(fresh, length(device_ids), device_ids)

        fresh
    end

    active_xt_handles[tid] = xt
    return xt
end
73+
# Module initializer: give every thread an empty handle slot and register a hook that
# clears a thread's slots whenever its context is switched.
function __init__()
    nslots = Threads.nthreads()
    for slots in (active_handles, active_xt_handles)
        resize!(slots, nslots)
        fill!(slots, nothing)
    end

    CUDAnative.atcontextswitch() do tid, ctx
        # we don't eagerly initialize handles, but do so lazily when requested
        active_handles[tid] = nothing
        active_xt_handles[tid] = nothing
    end
end
7087
7188end
0 commit comments