Skip to content

Commit 2891f7d

Browse files
nnethercote authored and LegNeato committed
Share TILE_SIZE.
One of the nice things about using Rust for both CPU and GPU code is the ability to share things between them. So let's do that in `gemm`.
1 parent 995eaed commit 2891f7d

File tree

5 files changed

+6
-5
lines changed

5 files changed

+6
-5
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/cuda/gemm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ blastoff = { path = "../../../crates/blastoff" }
88
cuda_std = { path = "../../../crates/cuda_std" }
99
cust = { path = "../../../crates/cust" }
1010
cust_raw = { path = "../../../crates/cust_raw", features = ["driver"] }
11+
gemm-kernels = { path = "kernels" }
1112
ndarray = { version = "0.16", features = ["approx"] }
1213
ndarray-rand = "0.15.0"
1314
rand = "0.9"

examples/cuda/gemm/kernels/src/gemm_tiled.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@ use cuda_std::address_space;
33
use cuda_std::kernel;
44
use cuda_std::thread;
55

6+
pub const TILE_SIZE: usize = 16;
7+
68
#[kernel]
79
#[allow(improper_ctypes_definitions)]
810
/// Tiled GEMM kernel for C = alpha * A * B + beta * C.
@@ -38,7 +40,6 @@ pub unsafe fn gemm_tiled(
3840
alpha: f32,
3941
beta: f32,
4042
) {
41-
const TILE_SIZE: usize = 16;
4243
const TILE_SIZE_2D: usize = TILE_SIZE * TILE_SIZE;
4344

4445
// Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike normal

examples/cuda/gemm/kernels/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,4 +2,4 @@ mod gemm_naive;
22
mod gemm_tiled;
33

44
pub use crate::gemm_naive::gemm_naive;
5-
pub use crate::gemm_tiled::gemm_tiled;
5+
pub use crate::gemm_tiled::{TILE_SIZE, gemm_tiled};

examples/cuda/gemm/src/main.rs

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ use cust::memory::CopyDestination as _;
1313
use cust::module;
1414
use cust::stream;
1515
use cust::util::SliceExt as _;
16+
use gemm_kernels::TILE_SIZE;
1617
use ndarray::Array;
1718
use ndarray_rand::RandomExt as _;
1819
use ndarray_rand::rand_distr::Uniform;
@@ -430,9 +431,6 @@ pub fn gemm_tiled(
430431
assert_eq!(mat_b.len(), k * n);
431432
assert_eq!(mat_c.len(), m * n);
432433

433-
// These values must be aligned with the kernel code.
434-
const TILE_SIZE: usize = 16;
435-
436434
let kernel_cell = cell::LazyCell::new(|| {
437435
module
438436
.get_function("gemm_tiled")

0 commit comments

Comments
 (0)