33export @roc , rocconvert, rocfunction
44
55struct Kernel{F,TT}
6- agent :: HSAAgent
6+ device :: RuntimeDevice
77 mod:: ROCModule
88 fun:: ROCFunction
99end
1010
1111# `split_kwargs()` segregates keyword arguments passed to `@roc` into those
1212# affecting the compiler, kernel execution, or both.
1313function split_kwargs (kwargs)
14- compiler_kws = [:agent , :queue , :name ]
15- call_kws = [:groupsize , :gridsize , :agent , :queue ]
14+ # TODO : Alias groupsize and gridsize as threads and blocks, respectively
15+ compiler_kws = [:device , :agent , :queue , :name ]
16+ call_kws = [:groupsize , :gridsize , :device , :agent , :queue ]
1617 compiler_kwargs = []
1718 call_kwargs = []
1819 for kwarg in kwargs
@@ -60,6 +61,23 @@ function assign_args!(code, args)
6061 return vars, var_exprs
6162end
6263
64+ function extract_device (;device= nothing , agent= nothing , kwargs... )
65+ if device != = nothing
66+ return device
67+ elseif agent != = nothing
68+ return agent
69+ else
70+ return default_device ()
71+ end
72+ end
73+ function extract_queue (device; queue= nothing , kwargs... )
74+ if queue != = nothing
75+ return queue
76+ else
77+ return default_queue (device)
78+ end
79+ end
80+
6381# fast lookup of global world age
6482world_age () = ccall (:jl_get_tls_world_age , UInt, ())
6583
@@ -125,10 +143,10 @@ macro roc(ex...)
125143 GC. @preserve $ (vars... ) begin
126144 local kernel_args = map (rocconvert, ($ (var_exprs... ),))
127145 local kernel_tt = Tuple{Core. Typeof .(kernel_args)... }
128- local agent = get_default_agent ( )
129- local kernel = rocfunction (agent , $ (esc (f)), kernel_tt;
146+ local device = extract_device (; $ ( esc (call_kwargs) ... ) )
147+ local kernel = rocfunction (device , $ (esc (f)), kernel_tt;
130148 $ (map (esc, compiler_kwargs)... ))
131- local queue = get_default_queue (agent )
149+ local queue = extract_queue (device; $ ( esc (call_kwargs) ... ) )
132150 local signal = HSASignal ()
133151 kernel (queue, signal, kernel_args... ; $ (map (esc, call_kwargs)... ))
134152 wait (signal)
@@ -188,7 +206,7 @@ The output of this function is automatically cached, i.e. you can simply call
188206generated automatically, when the function changes, or when different types or
189207keyword arguments are provided.
190208"""
191- @generated function rocfunction (agent :: HSAAgent , f:: Core.Function , tt:: Type = Tuple{}; name= nothing , kwargs... )
209+ @generated function rocfunction (device :: RuntimeDevice , f:: Core.Function , tt:: Type = Tuple{}; name= nothing , kwargs... )
192210 tt = Base. to_tuple_type (tt. parameters[1 ])
193211 sig = Base. signature_type (f, tt)
194212 t = Tuple (tt. parameters)
@@ -217,8 +235,8 @@ keyword arguments are provided.
217235
218236 # compile the function
219237 if ! haskey (compilecache, key)
220- fun, mod = compile (:roc , agent , f, tt; name= name, kwargs... )
221- kernel = Kernel {f,tt} (agent , mod, fun)
238+ fun, mod = compile (:roc , device , f, tt; name= name, kwargs... )
239+ kernel = Kernel {f,tt} (device , mod, fun)
222240 compilecache[key] = kernel
223241 end
224242
@@ -227,9 +245,9 @@ keyword arguments are provided.
227245end
228246
229247rocfunction (f:: Core.Function , tt:: Type = Tuple{}; kwargs... ) =
230- rocfunction (get_default_agent (), f, tt; kwargs... )
248+ rocfunction (default_device (), f, tt; kwargs... )
231249
232- @generated function call (kernel:: Kernel{F,TT} , queue:: HSAQueue ,
250+ @generated function call (kernel:: Kernel{F,TT} , queue:: RuntimeQueue ,
233251 signal:: HSASignal , args... ; call_kwargs... ) where {F,TT}
234252
235253 sig = Base. signature_type (F, TT)
0 commit comments