Skip to content

Commit 9fe3015

Browse files
authored
Merge pull request #745 from JuliaGPU/tb/alloc_opt
Use the default memory pool.
2 parents a8f3043 + a89988e commit 9fe3015

File tree

5 files changed

+66
-42
lines changed

5 files changed

+66
-42
lines changed

lib/cudadrv/pool.jl

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
# Stream-orderdered memory allocator
22

3-
export CuMemoryPool, trim
3+
export CuMemoryPool, default_memory_pool, memory_pool, memory_pool!, trim
44

55
mutable struct CuMemoryPool
66
handle::CUmemoryPool
77
ctx::CuContext
8+
end
89

9-
function CuMemoryPool(dev::CuDevice)
10-
props = Ref(CUmemPoolProps(
11-
CU_MEM_ALLOCATION_TYPE_PINNED,
12-
CU_MEM_HANDLE_TYPE_NONE,
13-
CUmemLocation(
14-
CU_MEM_LOCATION_TYPE_DEVICE,
15-
deviceid(dev)
16-
),
17-
C_NULL,
18-
ntuple(i->Cuchar(0), 64)
19-
))
20-
handle_ref = Ref{CUmemoryPool}()
21-
cuMemPoolCreate(handle_ref, props)
22-
23-
ctx = CuCurrentContext()
24-
obj = new(handle_ref[], ctx)
25-
finalizer(unsafe_destroy!, obj)
26-
return obj
27-
end
10+
function CuMemoryPool(dev::CuDevice)
11+
props = Ref(CUmemPoolProps(
12+
CU_MEM_ALLOCATION_TYPE_PINNED,
13+
CU_MEM_HANDLE_TYPE_NONE,
14+
CUmemLocation(
15+
CU_MEM_LOCATION_TYPE_DEVICE,
16+
deviceid(dev)
17+
),
18+
C_NULL,
19+
ntuple(i->Cuchar(0), 64)
20+
))
21+
handle_ref = Ref{CUmemoryPool}()
22+
cuMemPoolCreate(handle_ref, props)
23+
24+
ctx = CuCurrentContext()
25+
obj = CuMemoryPool(handle_ref[], ctx)
26+
finalizer(unsafe_destroy!, obj)
27+
return obj
2828
end
2929

3030
function unsafe_destroy!(pool::CuMemoryPool)
@@ -38,4 +38,22 @@ Base.unsafe_convert(::Type{CUmemoryPool}, pool::CuMemoryPool) = pool.handle
3838
Base.:(==)(a::CuMemoryPool, b::CuMemoryPool) = a.handle == b.handle
3939
Base.hash(pool::CuMemoryPool, h::UInt) = hash(pool.handle, h)
4040

41+
function default_memory_pool(dev::CuDevice)
42+
handle_ref = Ref{CUmemoryPool}()
43+
cuDeviceGetDefaultMemPool(handle_ref, dev)
44+
45+
ctx = CuCurrentContext()
46+
CuMemoryPool(handle_ref[], ctx)
47+
end
48+
49+
function memory_pool(dev::CuDevice)
50+
handle_ref = Ref{CUmemoryPool}()
51+
cuDeviceGetMemPool(handle_ref, dev)
52+
53+
ctx = CuCurrentContext()
54+
CuMemoryPool(handle_ref[], ctx)
55+
end
56+
57+
memory_pool!(dev::CuDevice, pool::CuMemoryPool) = cuDeviceSetMemPool(dev, pool)
58+
4159
trim(pool::CuMemoryPool, bytes_to_keep::Integer=0) = cuMemPoolTrimTo(pool, bytes_to_keep)

src/pool.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
127127
try
128128
time = Base.@elapsed begin
129129
@timeit_debug alloc_to "alloc" begin
130-
buf = Mem.alloc(Mem.Device, bytes; async=true,
131-
pool = (async_alloc[] ? pool() : nothing))
132-
# we only need a memory pool when we'll be using the async allocator.
133-
# this avoids a needless warning when running under cuda-memcheck,
134-
# which doesn't support the stream-ordered memory allocator.
130+
buf = Mem.alloc(Mem.Device, bytes; async=true)
135131
end
136132
end
137133

@@ -419,7 +415,8 @@ macro retry_reclaim(isfailed, ex)
419415
elseif phase == 4 && async_alloc[]
420416
# this phase is unique to retry_reclaim, as regular allocations come from the pool
421417
# so are assumed to never need to trim its contents.
422-
trim(pool())
418+
pool = memory_pool(device())
419+
trim(pool)
423420
end
424421
end
425422
ret
@@ -652,6 +649,16 @@ function __init_pool__()
652649
TimerOutputs.reset_timer!(alloc_to)
653650
TimerOutputs.reset_timer!(PoolUtils.to)
654651

652+
if isdebug(:init, CUDA)
653+
TimerOutputs.enable_debug_timings(CUDA)
654+
atexit() do
655+
println("Memory pool timings:")
656+
pool_timings()
657+
println("Allocator timings:")
658+
alloc_timings()
659+
end
660+
end
661+
655662
if isinteractive()
656663
@async @pooled pool_cleanup()
657664
end

src/pool/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ export @pool_timeit
103103
const to = TimerOutput()
104104

105105
macro pool_timeit(args...)
106-
TimerOutputs.timer_expr(CUDA, true, :($CUDA.to), args...)
106+
TimerOutputs.timer_expr(CUDA, true, :($PoolUtils.to), args...)
107107
end
108108

109109
end

src/state.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -561,18 +561,3 @@ function stream!(f::Function, s::CuStream)
561561
set_library_streams(old_s)
562562
end
563563
end
564-
565-
566-
## memory pools
567-
568-
const memory_pools = Dict{CuContext,CuMemoryPool}()
569-
570-
function pool()
571-
if CUDA.version() < v"11.2"
572-
return nothing
573-
end
574-
575-
return get!(memory_pools, context()) do
576-
CuMemoryPool(device())
577-
end
578-
end

test/cudadrv/pool.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
let
2+
dev = device()
3+
4+
pool = memory_pool(dev)
5+
6+
pool2 = CuMemoryPool(dev)
7+
@test pool2 != pool
8+
memory_pool!(dev, pool2)
9+
@test pool2 == memory_pool(dev)
10+
@test pool2 != default_memory_pool(dev)
11+
12+
memory_pool!(dev, pool)
13+
@test pool == memory_pool(dev)
14+
end

0 commit comments

Comments
 (0)