33using Printf
44using Logging
55using TimerOutputs
6+ using DataStructures
67
78include (" pool/utils.jl" )
89using . PoolUtils
6465const usage_limit = PerDevice {Int} () do dev
6566 if haskey (ENV , " JULIA_CUDA_MEMORY_LIMIT" )
6667 parse (Int, ENV [" JULIA_CUDA_MEMORY_LIMIT" ])
67- elseif haskey (ENV , " CUARRAYS_MEMORY_LIMIT" )
68- Base. depwarn (" The CUARRAYS_MEMORY_LIMIT environment flag is deprecated, please use JULIA_CUDA_MEMORY_LIMIT instead." , :__init_pool__ )
69- parse (Int, ENV [" CUARRAYS_MEMORY_LIMIT" ])
7068 else
7169 typemax (Int)
7270 end
@@ -116,7 +114,8 @@ function hard_limit(dev::CuDevice)
116114 usage_limit[dev]
117115end
118116
119- function actual_alloc (dev:: CuDevice , bytes:: Integer , last_resort:: Bool = false )
117+ function actual_alloc (dev:: CuDevice , bytes:: Integer , last_resort:: Bool = false ;
118+ stream_ordered:: Bool = false )
120119 buf = @device! dev begin
121120 # check the memory allocation limit
122121 if usage[dev][] + bytes > (last_resort ? hard_limit (dev) : soft_limit (dev))
@@ -127,7 +126,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
127126 try
128127 time = Base. @elapsed begin
129128 @timeit_debug alloc_to " alloc" begin
130- buf = Mem. alloc (Mem. Device, bytes; async= true )
129+ buf = Mem. alloc (Mem. Device, bytes; async= true , stream_ordered )
131130 end
132131 end
133132
@@ -146,7 +145,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
146145 return Block (buf, bytes; state= AVAILABLE)
147146end
148147
149- function actual_free (dev:: CuDevice , block:: Block )
148+ function actual_free (dev:: CuDevice , block:: Block ; stream_ordered :: Bool = false )
150149 @assert iswhole (block) " Cannot free $block : block is not whole"
151150 @assert block. off == 0
152151 @assert block. state == AVAILABLE " Cannot free $block : block is not available"
@@ -155,7 +154,7 @@ function actual_free(dev::CuDevice, block::Block)
155154 # free the memory
156155 @timeit_debug alloc_to " free" begin
157156 time = Base. @elapsed begin
158- Mem. free (block. buf; async= true )
157+ Mem. free (block. buf; async= true , stream_ordered )
159158 end
160159 block. state = INVALID
161160
@@ -181,41 +180,49 @@ Show the timings of the currently active memory pool. Assumes
function pool_timings()
    # print the timer stored in `PoolUtils.to`, sorted by name and without allocation
    # counts, then terminate the output with a newline
    show(PoolUtils.to; allocations=false, sortby=:name)
    println()
end
182181
183182# pool API:
184- # - init()
185- # - alloc(::CuDevice , sz)::Block
186- # - free(::CuDevice , ::Block)
187- # - reclaim(::CuDevice , nb::Int=typemax(Int))::Int
188- # - cached_memory()
183+ # - constructor taking a CuDevice
184+ # - alloc(::AbstractPool , sz)::Block
185+ # - free(::AbstractPool , ::Block)
186+ # - reclaim(::AbstractPool , nb::Int=typemax(Int))::Int
187+ # - cached_memory(::AbstractPool )
189188
# Namespace for the `MemoryPool` enum, so that its members are referred to with a
# qualified name (`Pool.None`, `Pool.Simple`, `Pool.Binned`, `Pool.Split`).
module Pool
@enum MemoryPool begin
    None
    Simple
    Binned
    Split
end
end
193- const active_pool = Ref {Pool.MemoryPool} ()
194- const async_alloc = Ref {Bool} ()
195-
196- macro pooled (ex)
197- @assert Meta. isexpr (ex, :call )
198- f, args... = ex. args
199- quote
200- if active_pool[] == Pool. None
201- NoPool.$ (f)($ (map (esc, args)... ))
202- elseif active_pool[] == Pool. Simple
203- SimplePool.$ (f)($ (map (esc, args)... ))
204- elseif active_pool[] == Pool. Binned
205- BinnedPool.$ (f)($ (map (esc, args)... ))
206- elseif active_pool[] == Pool. Split
207- SplitPool.$ (f)($ (map (esc, args)... ))
208- else
209- error (" unreachable" )
210- end
211- end
212- end
213192
193+ abstract type AbstractPool end
214194include (" pool/none.jl" )
215195include (" pool/simple.jl" )
216196include (" pool/binned.jl" )
217197include (" pool/split.jl" )
218198
# Memory pool object for each device, looked up as `pools[dev]`.
#
# The implementation is selected with the `JULIA_CUDA_MEMORY_POOL` environment variable,
# which accepts "none", "simple", "binned", "split" or "cuda". When the variable is not
# set, "cuda" is chosen if the CUDA version is 11.2+ and the device reports support for
# memory pools; otherwise "binned" is chosen. The "cuda" option maps to `NoPool` with
# `stream_ordered=true`; every other option disables stream-ordered allocation.
#
# Raises an error when the requested name is unknown, or when "cuda" is requested
# explicitly but the toolkit or the device does not support it.
#
# NOTE(review): assumes `PerDevice` calls this constructor with a `CuDevice`, as the
# sibling `usage_limit` definition does -- confirm in pool/utils.jl.
const pools = PerDevice{AbstractPool}() do dev
    # capability checks, computed once and shared by the default selection and the
    # validation of an explicit "cuda" request; the device attribute is only queried
    # when the toolkit is recent enough
    toolkit_ok = version() >= v"11.2"
    device_ok = toolkit_ok &&
                attribute(dev, CUDA.DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED) == 1

    pool_name = get(ENV, "JULIA_CUDA_MEMORY_POOL", device_ok ? "cuda" : "binned")

    if pool_name == "none"
        NoPool(; dev, stream_ordered=false)
    elseif pool_name == "simple"
        SimplePool(; dev, stream_ordered=false)
    elseif pool_name == "binned"
        BinnedPool(; dev, stream_ordered=false)
    elseif pool_name == "split"
        SplitPool(; dev, stream_ordered=false)
    elseif pool_name == "cuda"
        # this validates user-provided configuration, so raise unconditionally instead
        # of relying on `@assert`, which is meant for internal invariants only
        toolkit_ok ||
            error("The CUDA memory pool is only supported on CUDA 11.2+")
        device_ok ||
            error("Your device $(name(dev)) does not support the CUDA memory pool")
        NoPool(; dev, stream_ordered=true)
    else
        error("Invalid memory pool '$pool_name'")
    end
end
225+
219226
220227# # interface
221228
@@ -263,11 +270,11 @@ a [`OutOfGPUMemoryError`](@ref) if the allocation request cannot be satisfied.
263270 sz == 0 && return CU_NULL
264271
265272 dev = device ()
273+ pool = pools[dev]
266274
267275 time = Base. @elapsed begin
268- @pool_timeit " pooled alloc" block = @pooled alloc (dev , sz)
276+ @pool_timeit " pooled alloc" block = alloc (pool , sz):: Union{Nothing,Block}
269277 end
270- block:: Union{Nothing,Block}
271278 block === nothing && throw (OutOfGPUMemoryError (sz))
272279
273280 # record the memory block
@@ -328,6 +335,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
328335 ptr == CU_NULL && return
329336
330337 dev = device ()
338+ pool = pools[dev]
331339 last_use[dev] = time ()
332340
333341 if MEMDEBUG && ptr == CuPtr {Cvoid} (0xbbbbbbbbbbbbbbbb )
@@ -359,7 +367,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
359367 end
360368
361369 time = Base. @elapsed begin
362- @pool_timeit " pooled free" @pooled free (dev , block)
370+ @pool_timeit " pooled free" free (pool , block)
363371 end
364372
365373 alloc_stats. pool_time += time
@@ -382,7 +390,8 @@ actually reclaimed.
382390"""
function reclaim(sz::Int=typemax(Int))
    # forward the request to the pool registered for the device returned by `device()`
    return reclaim(pools[device()], sz)
end
387396
388397"""
@@ -403,6 +412,9 @@ macro retry_reclaim(isfailed, ex)
403412 ret = $ (esc (ex))
404413 $ (esc (isfailed))(ret) || break
405414
415+ dev = device ()
416+ pool = pools[dev]
417+
406418 # incrementally more costly reclaim of cached memory
407419 if phase == 1
408420 reclaim ()
@@ -412,11 +424,10 @@ macro retry_reclaim(isfailed, ex)
412424 elseif phase == 3
413425 GC. gc (true )
414426 reclaim ()
415- elseif phase == 4 && async_alloc[]
427+ elseif phase == 4 && pool . stream_ordered
416428 # this phase is unique to retry_reclaim, as regular allocations come from the pool
417429 # so are assumed to never need to trim its contents.
418- pool = memory_pool (device ())
419- trim (pool)
430+ trim (memory_pool (device ()))
420431 end
421432 end
422433 ret
@@ -445,7 +456,8 @@ function pool_cleanup()
445456
446457 if t1- t0 > 300
447458 # the pool hasn't been used for a while, so reclaim unused buffers
448- @pooled reclaim (dev)
459+ pool = pools[dev]
460+ reclaim (pool)
449461 end
450462 end
451463
@@ -561,7 +573,10 @@ macro timed(ex)
561573 end
562574end
563575
# Query the cached memory of the pool registered for `dev`; `dev` defaults to the device
# returned by `device()`. The result is whatever the pool-specific `cached_memory`
# method returns.
cached_memory(dev::CuDevice=device()) = cached_memory(pools[dev])
565580
566581"""
567582 memory_status([io=stdout])
@@ -584,10 +599,11 @@ function memory_status(io::IO=stdout)
584599 end
585600 println (io)
586601
587- alloc_used_bytes = used_memory ()
588- alloc_cached_bytes = cached_memory ()
602+ pool = pools[dev]
603+ alloc_used_bytes = used_memory (dev)
604+ alloc_cached_bytes = cached_memory (pool)
589605 alloc_total_bytes = alloc_used_bytes + alloc_cached_bytes
590- @printf (io, " Memory pool '%s' usage: %s (%s allocated, %s cached)\n " , string (active_pool[] ),
606+ @printf (io, " Memory pool '%s' usage: %s (%s allocated, %s cached)\n " , string (pool ),
591607 Base. format_bytes (alloc_total_bytes), Base. format_bytes (alloc_used_bytes),
592608 Base. format_bytes (alloc_cached_bytes))
593609
@@ -627,24 +643,8 @@ function __init_pool__()
627643 initialize! (allocated, ndevices ())
628644 initialize! (requested, ndevices ())
629645
630- # memory pool configuration
631- default_pool = version () >= v " 11.2" ? " cuda" : " binned"
632- pool_name = get (ENV , " JULIA_CUDA_MEMORY_POOL" , default_pool)
633- active_pool[], async_alloc[] = if pool_name == " none"
634- Pool. None, false
635- elseif pool_name == " simple"
636- Pool. Simple, false
637- elseif pool_name == " binned"
638- Pool. Binned, false
639- elseif pool_name == " split"
640- Pool. Split, false
641- elseif pool_name == " cuda"
642- @assert version () >= v " 11.2" " The CUDA memory pool is only supported on CUDA 11.2+"
643- Pool. None, true
644- else
645- error (" Invalid memory pool '$pool_name '" )
646- end
647- @pooled init ()
646+ # memory pools
647+ initialize! (pools, ndevices ())
648648
649649 TimerOutputs. reset_timer! (alloc_to)
650650 TimerOutputs. reset_timer! (PoolUtils. to)
@@ -660,6 +660,6 @@ function __init_pool__()
660660 end
661661
662662 if isinteractive ()
663- @async @pooled pool_cleanup ()
663+ @async pool_cleanup ()
664664 end
665665end
0 commit comments