@@ -35,13 +35,13 @@ const active_xt_handles = Vector{Union{Nothing,cublasXtHandle_t}}()
 function handle()
     tid = Threads.threadid()
     if @inbounds active_handles[tid] === nothing
-        context = CuCurrentContext()
-        active_handles[tid] = get!(created_handles, context) do
+        ctx = context()
+        active_handles[tid] = get!(created_handles, ctx) do
             handle = cublasCreate_v2()
-            atexit(()->CUDAdrv.isvalid(context) && cublasDestroy_v2(handle))
+            atexit(()->CUDAdrv.isvalid(ctx) && cublasDestroy_v2(handle))

             # enable tensor math mode if our device supports it, and fast math is enabled
-            dev = CUDAdrv.device(context)
+            dev = CUDAdrv.device()
             if Base.JLOptions().fast_math == 1 && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
                 cublasSetMathMode(CUBLAS_TENSOR_OP_MATH, handle)
             end
@@ -55,11 +55,10 @@
 function xt_handle()
     tid = Threads.threadid()
     if @inbounds active_xt_handles[tid] === nothing
-        CUDAnative.maybe_initialize("cublasXtGetHandle")
-        context = CuCurrentContext()
-        active_xt_handles[tid] = get!(created_xt_handles, context) do
+        ctx = context()
+        active_xt_handles[tid] = get!(created_xt_handles, ctx) do
             handle = cublasXtCreate()
-            atexit(()->CUDAdrv.isvalid(context) && cublasXtDestroy(handle))
+            atexit(()->CUDAdrv.isvalid(ctx) && cublasXtDestroy(handle))

             # select the devices
             # TODO: this is weird, since we typically use a single device per thread/context
@@ -79,7 +78,7 @@ function __init__()
     resize!(active_xt_handles, Threads.nthreads())
     fill!(active_xt_handles, nothing)

-    CUDAnative.atcontextswitch() do tid, ctx, dev
+    CUDAnative.atcontextswitch() do tid, ctx
         # we don't eagerly initialize handles, but do so lazily when requested
         active_handles[tid] = nothing
         active_xt_handles[tid] = nothing
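
For reference, a minimal, self-contained sketch of the lazy per-thread/per-context handle caching pattern these functions implement. The `FakeContext`, `fake_create`, `fake_destroy`, `fake_isvalid`, `sketch_handle`, and `sketch_init` names are placeholders standing in for the real CUDA context and cuBLAS calls (`context()`, `cublasCreate_v2`, `cublasDestroy_v2`, `CUDAdrv.isvalid`); this illustrates the pattern, not the package's API.

```julia
# Sketch: one slot per thread, one handle per context, destroyed at exit.
struct FakeContext
    id::Int
end

fake_create()        = rand(UInt)   # stands in for cublasCreate_v2()
fake_destroy(handle) = nothing      # stands in for cublasDestroy_v2(handle)
fake_isvalid(ctx)    = true         # stands in for CUDAdrv.isvalid(ctx)

const created_handles = Dict{FakeContext,UInt}()
const active_handles  = Vector{Union{Nothing,UInt}}()

function sketch_handle(ctx::FakeContext)
    tid = Threads.threadid()
    if active_handles[tid] === nothing
        # create the handle for this context, or reuse an existing one
        active_handles[tid] = get!(created_handles, ctx) do
            handle = fake_create()
            atexit(() -> fake_isvalid(ctx) && fake_destroy(handle))
            handle
        end
    end
    return active_handles[tid]
end

function sketch_init()
    resize!(active_handles, Threads.nthreads())
    fill!(active_handles, nothing)   # handles are created lazily on first use
end

# usage
sketch_init()
h = sketch_handle(FakeContext(1))
@assert h === sketch_handle(FakeContext(1))   # cached per thread/context
```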