@@ -95,6 +95,13 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
9595#ifdef SOLUTION
9696#include < nccl.h>
9797#endif
// Compile-time detection of NCCL user buffer (UB) registration support.
// NCCL_UB_SUPPORT is non-zero only when the NCCL version macros are visible at
// this point (NCCL_VERSION is defined) and NCCL_VERSION_CODE is at least
// 2.19.1; otherwise it is 0.
// Both macro bodies are parenthesized so that they expand safely inside larger
// preprocessor or C++ expressions, e.g. `!NCCL_UB_SUPPORT` or
// `NCCL_UB_SUPPORT == 0`.
#ifdef NCCL_VERSION
#define NCCL_VERSION_UB (NCCL_VERSION(2, 19, 1))
#define NCCL_UB_SUPPORT (NCCL_VERSION_CODE >= NCCL_VERSION_UB)
#else
#define NCCL_UB_SUPPORT 0
#endif
104+
98105
99106#define NCCL_CALL (call ) \
100107 { \
@@ -172,6 +179,13 @@ int main(int argc, char* argv[]) {
172179 const int nx = get_argval<int >(argv, argv + argc, " -nx" , 16384 );
173180 const int ny = get_argval<int >(argv, argv + argc, " -ny" , 16384 );
174181 const bool csv = get_arg (argv, argv + argc, " -csv" );
182+ bool user_buffer_reg = get_arg (argv, argv + argc, " -user_buffer_reg" );
183+ #if NCCL_UB_SUPPORT == 0
184+ if (user_buffer_reg) {
185+ fprintf (stderr," WARNING: Ignoring -user_buffer_reg, required NCCL APIs are provided by NCCL 2.19.1 or later.\n " );
186+ user_buffer_reg = false ;
187+ }
188+ #endif // NCCL_UB_SUPPORT == 0
175189
176190 int local_rank = -1 ;
177191 {
@@ -226,10 +240,30 @@ int main(int argc, char* argv[]) {
226240 chunk_size = chunk_size_high;
227241
228242 real* a;
229- CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
230243 real* a_new;
231- CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
244+ #if NCCL_UB_SUPPORT
245+ void * a_reg_handle;
246+ void * a_new_reg_handle;
247+ if (user_buffer_reg) {
248+ // TODO: Allocate the memory with ncclMemAlloc and register it for the communicator
249+ #ifdef SOLUTION
250+
251+ NCCL_CALL (ncclMemAlloc ( (void **) &a , nx * (chunk_size + 2 ) * sizeof (real)));
252+ NCCL_CALL (ncclMemAlloc ( (void **) &a_new, nx * (chunk_size + 2 ) * sizeof (real)));
253+ NCCL_CALL (ncclCommRegister (nccl_comm, a , nx * (chunk_size + 2 ) * sizeof (real), &a_reg_handle));
254+ NCCL_CALL (ncclCommRegister (nccl_comm, a_new, nx * (chunk_size + 2 ) * sizeof (real), &a_new_reg_handle));
255+ #endif
256+ if ( nccl_version < 22304 ) {
257+ fprintf (stderr," WARNING: -user_buffer_reg available, but Jacobi communication pattern needs NCCL 2.23.4 or later.\n " );
258+ }
259+ }
260+ else
261+ #endif // NCCL_UB_SUPPORT
232262
263+ {
264+ CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
265+ CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
266+ }
233267 CUDA_RT_CALL (cudaMemset (a, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
234268 CUDA_RT_CALL (cudaMemset (a_new, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
235269
@@ -434,10 +468,22 @@ int main(int argc, char* argv[]) {
434468
435469 CUDA_RT_CALL (cudaFreeHost (l2_norm_h));
436470 CUDA_RT_CALL (cudaFree (l2_norm_d));
437-
471+ #if NCCL_UB_SUPPORT
472+ if (user_buffer_reg) {
473+ // TODO: Deregister and Free the Buffer
474+ #ifdef SOLUTION
475+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_new_reg_handle));
476+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_reg_handle));
477+ NCCL_CALL (ncclMemFree (a_new));
478+ NCCL_CALL (ncclMemFree (a));
479+ #endif
480+ }
481+ else
482+ #endif // NCCL_UB_SUPPORT
483+ {
438484 CUDA_RT_CALL (cudaFree (a_new));
439485 CUDA_RT_CALL (cudaFree (a));
440-
486+ }
441487 CUDA_RT_CALL (cudaFreeHost (a_h));
442488 CUDA_RT_CALL (cudaFreeHost (a_ref_h));
443489
0 commit comments