Skip to content

Commit 2623e21

Browse files
nnethercote authored and LegNeato committed
Fix CUDA 13 API issues.
Four functions we use changed their type signatures in CUDA 13.0. This commit adds a cfg for each one and uses conditional compilation at each call site so that the call compiles against both CUDA 12.x and 13.0.
1 parent b0c4c71 commit 2623e21

File tree

5 files changed

+87
-8
lines changed

5 files changed

+87
-8
lines changed

crates/cust/build.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,34 @@ fn main() {
1010
if driver_version >= 12030 {
1111
println!("cargo::rustc-cfg=conditional_node");
1212
}
13+
// In CUDA 13.0 several pairs/trios of functions were merged:
14+
// ```
15+
// CUresult cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUdevice device);
16+
// CUresult cuMemAdvise_v2(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUmemLocation location);
17+
//
18+
// CUresult cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, CUdevice dstDevice, CUstream hStream);
19+
// CUresult cuMemPrefetchAsync_v2(CUdeviceptr devPtr, size_t count, CUmemLocation location, unsigned int flags, CUstream hStream);
20+
//
21+
// CUresult cuGraphGetEdges(CUgraph hGraph, CUgraphNode* from, CUgraphNode* to, size_t* numEdges);
22+
// CUresult cuGraphGetEdges_v2(CUgraph hGraph, CUgraphNode* from, CUgraphNode* to, CUgraphEdgeData* edgeData, size_t* numEdges);
23+
//
24+
// CUresult cuCtxCreate(CUcontext* pctx, unsigned int flags, CUdevice dev);
25+
// CUresult cuCtxCreate_v3(CUcontext* pctx, CUexecAffinityParam* paramsArray, int numParams, unsigned int flags, CUdevice dev);
26+
// CUresult cuCtxCreate_v4(CUcontext* pctx, CUctxCreateParams* ctxCreateParams, unsigned int flags, CUdevice dev);
27+
// ```
28+
// In each case, the resulting single function has the name of the first function and the type
29+
// signature of the last.
30+
//
31+
// These cfgs let you call these functions and make it work for both pre CUDA-13.0 and CUDA
32+
// 13.0. When support for CUDA 12.x is dropped, these cfgs can be removed.
33+
println!("cargo::rustc-check-cfg=cfg(cuMemAdvise_v2)");
34+
println!("cargo::rustc-check-cfg=cfg(cuMemPrefetchAsync_v2)");
35+
println!("cargo::rustc-check-cfg=cfg(cuGraphGetEdges_v2)");
36+
println!("cargo::rustc-check-cfg=cfg(cuCtxCreate_v4)");
37+
if driver_version >= 13000 {
38+
println!("cargo::rustc-cfg=cuMemAdvise_v2");
39+
println!("cargo::rustc-cfg=cuMemPrefetchAsync_v2");
40+
println!("cargo::rustc-cfg=cuGraphGetEdges_v2");
41+
println!("cargo::rustc-cfg=cuCtxCreate_v4");
42+
}
1343
}

crates/cust/src/context/legacy.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,14 @@ impl Context {
262262
// lifetime guarantees so we create-and-push, then pop, then the programmer has to
263263
// push again.
264264
let mut ctx: CUcontext = ptr::null_mut();
265-
driver_sys::cuCtxCreate(&mut ctx as *mut CUcontext, flags.bits(), device.as_raw())
266-
.to_result()?;
265+
driver_sys::cuCtxCreate(
266+
&mut ctx as *mut CUcontext,
267+
#[cfg(cuCtxCreate_v4)]
268+
&mut driver_sys::CUctxCreateParams::default(),
269+
flags.bits(),
270+
device.as_raw(),
271+
)
272+
.to_result()?;
267273
Ok(Context { inner: ctx })
268274
}
269275
}

crates/cust/src/graph.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,8 @@ impl Graph {
415415
self.raw,
416416
ptr::null_mut(),
417417
ptr::null_mut(),
418+
#[cfg(cuGraphGetEdges_v2)]
419+
ptr::null_mut(),
418420
size.as_mut_ptr(),
419421
)
420422
.to_result()?;
@@ -439,6 +441,8 @@ impl Graph {
439441
self.raw,
440442
from.as_mut_ptr(),
441443
to.as_mut_ptr(),
444+
#[cfg(cuGraphGetEdges_v2)]
445+
ptr::null_mut(),
442446
&num_edges as *const _ as *mut usize,
443447
)
444448
.to_result()?;

crates/cust/src/memory/unified.rs

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -640,10 +640,19 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
640640
let mem_size = std::mem::size_of_val(slice);
641641

642642
unsafe {
643+
let id = -1; // -1 is CU_DEVICE_CPU
643644
driver_sys::cuMemPrefetchAsync(
644645
slice.as_ptr() as driver_sys::CUdeviceptr,
645646
mem_size,
646-
-1, // CU_DEVICE_CPU #define
647+
#[cfg(cuMemPrefetchAsync_v2)]
648+
driver_sys::CUmemLocation {
649+
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
650+
id,
651+
},
652+
#[cfg(not(cuMemPrefetchAsync_v2))]
653+
id,
654+
#[cfg(cuMemPrefetchAsync_v2)]
655+
0, // flags for future use, must be 0 as of CUDA 13.0
647656
stream.as_inner(),
648657
)
649658
.to_result()?;
@@ -677,10 +686,19 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
677686
let mem_size = std::mem::size_of_val(slice);
678687

679688
unsafe {
689+
let id = device.as_raw();
680690
driver_sys::cuMemPrefetchAsync(
681691
slice.as_ptr() as driver_sys::CUdeviceptr,
682692
mem_size,
683-
device.as_raw(),
693+
#[cfg(cuMemPrefetchAsync_v2)]
694+
driver_sys::CUmemLocation {
695+
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
696+
id,
697+
},
698+
#[cfg(not(cuMemPrefetchAsync_v2))]
699+
id,
700+
#[cfg(cuMemPrefetchAsync_v2)]
701+
0, // flags for future use, must be 0 as of CUDA 13.0
684702
stream.as_inner(),
685703
)
686704
.to_result()?;
@@ -709,11 +727,18 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
709727
};
710728

711729
unsafe {
730+
let id = 0;
712731
driver_sys::cuMemAdvise(
713732
slice.as_ptr() as driver_sys::CUdeviceptr,
714733
mem_size,
715734
advice,
716-
0,
735+
#[cfg(cuMemAdvise_v2)]
736+
driver_sys::CUmemLocation {
737+
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
738+
id,
739+
},
740+
#[cfg(not(cuMemAdvise_v2))]
741+
id,
717742
)
718743
.to_result()?;
719744
}
@@ -744,11 +769,18 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
744769
let mem_size = std::mem::size_of_val(slice);
745770

746771
unsafe {
772+
let id = preferred_location.map(|d| d.as_raw()).unwrap_or(-1); // -1 is CU_DEVICE_CPU
747773
driver_sys::cuMemAdvise(
748774
slice.as_ptr() as driver_sys::CUdeviceptr,
749775
mem_size,
750776
driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_PREFERRED_LOCATION,
751-
preferred_location.map(|d| d.as_raw()).unwrap_or(-1),
777+
#[cfg(cuMemAdvise_v2)]
778+
driver_sys::CUmemLocation {
779+
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
780+
id,
781+
},
782+
#[cfg(not(cuMemAdvise_v2))]
783+
id,
752784
)
753785
.to_result()?;
754786
}
@@ -761,11 +793,18 @@ pub trait MemoryAdvise<T: DeviceCopy>: private::Sealed {
761793
let mem_size = std::mem::size_of_val(slice);
762794

763795
unsafe {
796+
let id = 0;
764797
driver_sys::cuMemAdvise(
765798
slice.as_ptr() as driver_sys::CUdeviceptr,
766799
mem_size,
767800
driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
768-
0,
801+
#[cfg(cuMemAdvise_v2)]
802+
driver_sys::CUmemLocation {
803+
type_: driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
804+
id,
805+
},
806+
#[cfg(not(cuMemAdvise_v2))]
807+
id,
769808
)
770809
.to_result()?;
771810
}

crates/cust_raw/build/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//! The build script for the cust_raw crate generates bindings for libraries in the
33
//! CUDA SDK. The build script searches for the CUDA SDK by reading the
44
//! `CUDA_PATH`, `CUDA_ROOT`, or `CUDA_TOOLKIT_ROOT_DIR` environment variables
5-
//! in that order. If none of these variables are set to a vaild CUDA Toolkit
5+
//! in that order. If none of these variables are set to a valid CUDA Toolkit
66
//! SDK path, the build script will attempt to search for any SDK in the
77
//! default installation locations for the current platform.
88
//!

0 commit comments

Comments
 (0)