File tree Expand file tree Collapse file tree 5 files changed +6
-5
lines changed Expand file tree Collapse file tree 5 files changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -8,6 +8,7 @@ blastoff = { path = "../../../crates/blastoff" }
88cuda_std = { path = " ../../../crates/cuda_std" }
99cust = { path = " ../../../crates/cust" }
1010cust_raw = { path = " ../../../crates/cust_raw" , features = [" driver" ] }
11+ gemm-kernels = { path = " kernels" }
1112ndarray = { version = " 0.16" , features = [" approx" ] }
1213ndarray-rand = " 0.15.0"
1314rand = " 0.9"
Original file line number Diff line number Diff line change @@ -3,6 +3,8 @@ use cuda_std::address_space;
33use cuda_std:: kernel;
44use cuda_std:: thread;
55
6+ pub const TILE_SIZE : usize = 16 ;
7+
68#[ kernel]
79#[ allow( improper_ctypes_definitions) ]
810/// Tiled GEMM kernel for C = alpha * A * B + beta * C.
@@ -38,7 +40,6 @@ pub unsafe fn gemm_tiled(
3840 alpha : f32 ,
3941 beta : f32 ,
4042) {
41- const TILE_SIZE : usize = 16 ;
4243 const TILE_SIZE_2D : usize = TILE_SIZE * TILE_SIZE ;
4344
4445 // Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike normal
Original file line number Diff line number Diff line change @@ -2,4 +2,4 @@ mod gemm_naive;
22mod gemm_tiled;
33
44pub use crate :: gemm_naive:: gemm_naive;
5- pub use crate :: gemm_tiled:: gemm_tiled;
5+ pub use crate :: gemm_tiled:: { TILE_SIZE , gemm_tiled} ;
Original file line number Diff line number Diff line change @@ -13,6 +13,7 @@ use cust::memory::CopyDestination as _;
1313use cust:: module;
1414use cust:: stream;
1515use cust:: util:: SliceExt as _;
16+ use gemm_kernels:: TILE_SIZE ;
1617use ndarray:: Array ;
1718use ndarray_rand:: RandomExt as _;
1819use ndarray_rand:: rand_distr:: Uniform ;
@@ -430,9 +431,6 @@ pub fn gemm_tiled(
430431 assert_eq ! ( mat_b. len( ) , k * n) ;
431432 assert_eq ! ( mat_c. len( ) , m * n) ;
432433
433- // These values must be aligned with the kernel code.
434- const TILE_SIZE : usize = 16 ;
435-
436434 let kernel_cell = cell:: LazyCell :: new ( || {
437435 module
438436 . get_function ( "gemm_tiled" )
You can’t perform that action at this time.
0 commit comments