@@ -315,6 +315,38 @@ def createRrBench(variant_name: str, **kwargs):
315315 ),
316316 ]
317317
318+ for runtime in RUNTIMES :
319+ if runtime != RUNTIMES .UR :
320+
321+ def createTorchMultiQueueBench (variant_name : str , ** kwargs ):
322+ return TorchMultiQueue (
323+ self ,
324+ runtime ,
325+ variant_name ,
326+ PROFILERS .TIMER ,
327+ ** kwargs ,
328+ )
329+
330+ benches += [
331+ createTorchMultiQueueBench (
332+ "large" ,
333+ workgroupCount = 4096 ,
334+ workgroupSize = 512 ,
335+ kernelsPerQueue = 20 ,
336+ ),
337+ createTorchMultiQueueBench (
338+ "medium" ,
339+ workgroupCount = 512 ,
340+ workgroupSize = 256 ,
341+ kernelsPerQueue = 10 ,
342+ ),
343+ createTorchMultiQueueBench (
344+ "small" ,
345+ workgroupCount = 256 ,
346+ workgroupSize = 124 ,
347+ kernelsPerQueue = 4 ,
348+ ),
349+ ]
318350 # Add UR-specific benchmarks
319351 benches += [
320352 # TODO: multithread_benchmark_ur fails with segfault
@@ -735,6 +767,48 @@ def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
735767 return [f"--{ k } ={ v } " for k , v in self ._rr_params .items ()]
736768
737769
770+ class TorchMultiQueue (ComputeBenchmark ):
771+ def __init__ (
772+ self , suite , runtime : RUNTIMES , variant_name : str , profiler_type , ** kwargs
773+ ):
774+ self ._variant_name = variant_name
775+ self ._smq_params = kwargs
776+ self ._iterations_regular = 1000
777+ self ._iterations_trace = 10
778+ super ().__init__ (
779+ suite ,
780+ f"torch_benchmark_{ runtime .value } " ,
781+ "KernelSubmitMultiQueue" ,
782+ runtime ,
783+ profiler_type ,
784+ )
785+
786+ def explicit_group (self ):
787+ return f"{ self ._test } { self ._variant_name } "
788+
789+ def display_name (self ) -> str :
790+ return f"{ self .explicit_group ()} _{ self ._runtime .value } "
791+
792+ def get_tags (self ):
793+ return [runtime_to_tag_name (self ._runtime )]
794+
795+ def name (self ):
796+ ret = []
797+ for k , v in self ._smq_params .items ():
798+ ret .append (f"{ k } { v } " )
799+ ret .sort ()
800+ return self ._bench_name + " " + ", " .join (ret )
801+
802+ def _supported_runtimes (self ) -> list [RUNTIMES ]:
803+ return super ()._supported_runtimes () + [RUNTIMES .SYCL_PREVIEW ]
804+
805+ def _bin_args (self , run_trace : TracingType = TracingType .NONE ) -> list [str ]:
806+ iters = self ._get_iters (run_trace )
807+ return [f"--iterations={ iters } " ] + [
808+ f"--{ k } ={ v } " for k , v in self ._smq_params .items ()
809+ ]
810+
811+
738812class QueueInOrderMemcpy (ComputeBenchmark ):
739813 def __init__ (self , bench , isCopyOnly , source , destination , size , profiler_type ):
740814 self ._is_copy_only = isCopyOnly
0 commit comments