Skip to content

Commit 995eaed

Browse files
nnethercoteLegNeato
authored andcommitted
Fix gemm example on CUDA 13.0.
`#[address_space(shared)] static mut` is used to model GPU shared memory. It's a bit weird. In particular, GPU shared memory is uninitialized, but `static mut` requires an initializer in Rust. `gemm` uses a zero initialize, but this initializer is ignored by NVVM. At least, it was in CUDA 12.x, but in CUDA 13.0 the `gemm` example fails with this error: ``` thread 'rustc' panicked at crates/rustc_codegen_nvvm/src/nvvm.rs:120:9: Malformed NVVM IR program rejected by libnvvm, dumping verifier log: error: Error: : Global Variable `_ZN12gemm_kernels10gemm_tiled10gemm_tiled6TILE_A17hc9c66e758c373a7eE': context: @_ZN12gemm_kernels10gemm_tiled10gemm_tiled6TILE_A17hc9c66e758c373a7eE = internal unnamed_addr addrspace(3) global <{ [1024 x i8] }> zeroinitializer, align 4 Shared variables can't be initialized ``` This memory looks like it's initialized to zero but isn't, and then is written and read normally. This is incredibly dodgy and very likely UB. The proper way to deal with uninitialized memory in Rust is with `MaybeUninit`, and there are strict rules around its used, e.g. writes must be done with `write` and `assume_init` must be used values after they are written. This commit changes `gemm` to use `MaybeUninit` for the shared memory. This fixes the error on CUDA 13.0 and the example runs correctly. (This is the only executed use of GPU shared memory in rust-cuda. There is a `shared_array!` macro defined but it's only used in a compiletest where it is compiled but not run. That macro is extremely dubious but I will deal with it in a separate PR because it's not necessary to get CUDA 13.0 working.)
1 parent 2623e21 commit 995eaed

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/cuda/gemm/kernels/src/gemm_tiled.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::mem::MaybeUninit;
12
use cuda_std::address_space;
23
use cuda_std::kernel;
34
use cuda_std::thread;
@@ -38,11 +39,17 @@ pub unsafe fn gemm_tiled(
3839
beta: f32,
3940
) {
4041
const TILE_SIZE: usize = 16;
42+
const TILE_SIZE_2D: usize = TILE_SIZE * TILE_SIZE;
4143

44+
// Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike normal
45+
// `static mut`, it is not initialized, and only exists for the duration of the kernel's
46+
// (multi-)execution. Because it is not initialized, it must be marked with `MaybeUninit`,
47+
// written with `write` (in unsafe blocks because writing a `static mut` is unsafe), and
48+
// subsequently read with `assume_init`.
4249
#[address_space(shared)]
43-
static mut TILE_A: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
50+
static mut TILE_A: [MaybeUninit<f32>; TILE_SIZE_2D] = [MaybeUninit::uninit(); TILE_SIZE_2D];
4451
#[address_space(shared)]
45-
static mut TILE_B: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
52+
static mut TILE_B: [MaybeUninit<f32>; TILE_SIZE_2D] = [MaybeUninit::uninit(); TILE_SIZE_2D];
4653

4754
// Thread indices within the block.
4855
let tx = thread::thread_idx_x() as usize;
@@ -57,20 +64,30 @@ pub unsafe fn gemm_tiled(
5764
for kk in (0..k).step_by(TILE_SIZE) {
5865
// Collaborative loading of tiles into shared memory.
5966
if row < m && (kk + tx) < k {
60-
unsafe { TILE_A[ty * TILE_SIZE + tx] = mat_a[row * k + (kk + tx)] };
67+
unsafe {
68+
TILE_A[ty * TILE_SIZE + tx].write(mat_a[row * k + (kk + tx)]);
69+
}
6170
} else {
62-
unsafe { TILE_A[ty * TILE_SIZE + tx] = 0.0f32 };
71+
unsafe {
72+
TILE_A[ty * TILE_SIZE + tx].write(0.0f32);
73+
}
6374
}
6475
if col < n && (kk + ty) < k {
65-
unsafe { TILE_B[ty * TILE_SIZE + tx] = mat_b[(kk + ty) * n + col] };
76+
unsafe {
77+
TILE_B[ty * TILE_SIZE + tx].write(mat_b[(kk + ty) * n + col]);
78+
}
6679
} else {
67-
unsafe { TILE_B[ty * TILE_SIZE + tx] = 0.0f32 };
80+
unsafe {
81+
TILE_B[ty * TILE_SIZE + tx].write(0.0f32);
82+
}
6883
}
6984
thread::sync_threads();
7085

7186
// Perform the computation on the tile.
7287
for i in 0..TILE_SIZE {
73-
sum += unsafe { TILE_A[ty * TILE_SIZE + i] * TILE_B[i * TILE_SIZE + tx] };
88+
sum += unsafe {
89+
TILE_A[ty * TILE_SIZE + i].assume_init() * TILE_B[i * TILE_SIZE + tx].assume_init()
90+
};
7491
}
7592
thread::sync_threads();
7693
}

0 commit comments

Comments
 (0)