@@ -28,25 +28,25 @@ function matmul_impl(a, b, c, d,
2828 # (1) Cooperatively load a BLOCK_SHAPE.M x BLOCK_SHAPE.N tile of C from global to shared memory within one threadblock
2929 shmem_c = @cuDynamicSharedMem (Layout. eltype (SHARED_C_LAYOUT), Layout. size (SHARED_C_LAYOUT, block_tile. MN. size))
3030
31- @unroll for warp_tile = parallellise (block_tile. MN, MEM_CD_WARP, warpId, WARPS_PER_BLOCK)
32- @unroll for thread_tile = parallellise (warp_tile, MEM_CD_THREAD, laneId, 32 )
33- x = Layout. load (GLOBAL_C_LAYOUT, c, translate (thread_tile, (M = block_i, N = block_j)), gemm_sz . MN . size )
31+ @unroll for warp_tile = parallellise (block_tile. MN, Tile ( MEM_CD_WARP) , warpId, WARPS_PER_BLOCK)
32+ @unroll for thread_tile = parallellise (warp_tile, Tile ( MEM_CD_THREAD) , laneId, 32 )
33+ x = Layout. load (GLOBAL_C_LAYOUT, c, translate (thread_tile, (M = block_i, N = block_j)))
3434 x = transf_gl2sh_c (x, thread_tile)
35- Layout. store! (SHARED_C_LAYOUT, shmem_c, x, thread_tile, block_tile . MN . size )
35+ Layout. store! (SHARED_C_LAYOUT, shmem_c, x, thread_tile)
3636 end
3737 end
3838
3939 sync_threads ()
4040
4141 # (2) Load a COMPUTE_WARP.M x COMPUTE_WARP.N tile of C from shared memory into registers
42- warp_tile = subdivide (block_tile. MN, (M = COMPUTE_WARP. M, N = COMPUTE_WARP . N) , warpId, WARPS_PER_BLOCK)
42+ warp_tile = subdivide (block_tile. MN, Tile ( COMPUTE_WARP) . MN , warpId, WARPS_PER_BLOCK)
4343
4444 c_frags = MArray {Tuple{NUM_FRAGMENTS_M, NUM_FRAGMENTS_N}, Operator.fragtype_accum(OPERATOR, SHARED_C_LAYOUT)} (undef)
4545
4646 @unroll for i = 1 : NUM_FRAGMENTS_M
4747 @unroll for j = 1 : NUM_FRAGMENTS_N
4848 tile = translate (warp_tile, (M = (i- 1 )* COMPUTE_OP_SHAPE. M, N = (j- 1 )* COMPUTE_OP_SHAPE. N))
49- @inbounds c_frags[i, j] = transf_sh2rf_c (Operator. load_c (OPERATOR, SHARED_C_LAYOUT, shmem_c, tile, block_tile . MN . size ), tile)
49+ @inbounds c_frags[i, j] = transf_sh2rf_c (Operator. load_c (OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
5050 end
5151 end
5252
@@ -59,41 +59,41 @@ function matmul_impl(a, b, c, d,
5959
6060 @unroll for block_k = 0 : block_tile. size. K : gemm_sz. size. K - 1
6161 # (3.1) Cooperatively load a BLOCK_SHAPE.M x BLOCK_SHAPE.K tile of A from global to shared memory within one threadblock
62- @unroll for warp_tile = parallellise (block_tile. MK, MEM_A_WARP, warpId, WARPS_PER_BLOCK)
63- @unroll for thread_tile = parallellise (warp_tile, MEM_A_THREAD, laneId, 32 )
64- x = Layout. load (GLOBAL_A_LAYOUT, a, translate (thread_tile, (M = block_i, K = block_k)), gemm_sz . MK . size )
62+ @unroll for warp_tile = parallellise (block_tile. MK, Tile ( MEM_A_WARP) , warpId, WARPS_PER_BLOCK)
63+ @unroll for thread_tile = parallellise (warp_tile, Tile ( MEM_A_THREAD) , laneId, 32 )
64+ x = Layout. load (GLOBAL_A_LAYOUT, a, translate (thread_tile, (M = block_i, K = block_k)))
6565 x = transf_gl2sh_a (x, thread_tile)
66- Layout. store! (SHARED_A_LAYOUT, shmem_a, x, thread_tile, block_tile . MK . size )
66+ Layout. store! (SHARED_A_LAYOUT, shmem_a, x, thread_tile)
6767 end
6868 end
6969
7070 # (3.2) Cooperatively load a BLOCK_SHAPE.K x BLOCK_SHAPE.N tile of B from global to shared memory within one threadblock
71- @unroll for warp_tile = parallellise (block_tile. KN, MEM_B_WARP, warpId, WARPS_PER_BLOCK)
72- @unroll for thread_tile = parallellise (warp_tile, MEM_B_THREAD, laneId, 32 )
73- x = Layout. load (GLOBAL_B_LAYOUT, b, translate (thread_tile, (K = block_k, N = block_j)), gemm_sz . KN . size )
71+ @unroll for warp_tile = parallellise (block_tile. KN, Tile ( MEM_B_WARP) , warpId, WARPS_PER_BLOCK)
72+ @unroll for thread_tile = parallellise (warp_tile, Tile ( MEM_B_THREAD) , laneId, 32 )
73+ x = Layout. load (GLOBAL_B_LAYOUT, b, translate (thread_tile, (K = block_k, N = block_j)))
7474 x = transf_gl2sh_b (x, thread_tile)
75- Layout. store! (SHARED_B_LAYOUT, shmem_b, x, thread_tile, block_tile . KN . size )
75+ Layout. store! (SHARED_B_LAYOUT, shmem_b, x, thread_tile)
7676 end
7777 end
7878
7979 sync_threads ()
8080
8181 # (3.3) Calculate a COMPUTE_WARP.M x COMPUTE_WARP.N tile of D, using a COMPUTE_WARP.M x COMPUTE_WARP.N x COMPUTE_WARP.K operation
82- @unroll for warp_tile = parallellise (block_tile, COMPUTE_WARP, warpId, WARPS_PER_BLOCK)
82+ @unroll for warp_tile = parallellise (block_tile, Tile ( COMPUTE_WARP) , warpId, WARPS_PER_BLOCK)
8383 # (3.3.1) Load a COMPUTE_WARP.M x COMPUTE_WARP.K tile of A from shared memory into registers
8484 a_frags = MArray {Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)} (undef)
8585
8686 @unroll for i = 1 : NUM_FRAGMENTS_M
8787 a_tile = translate (warp_tile. MK, (M = (i- 1 )* COMPUTE_OP_SHAPE. M, K = 0 ))
88- @inbounds a_frags[i] = transf_sh2rf_a (Operator. load_a (OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile, block_tile . MK . size ), a_tile)
88+ @inbounds a_frags[i] = transf_sh2rf_a (Operator. load_a (OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
8989 end
9090
9191 # (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
9292 b_frags = MArray {Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)} (undef)
9393
9494 @unroll for j = 1 : NUM_FRAGMENTS_N
9595 b_tile = translate (warp_tile. KN, (K = 0 , N = (j- 1 )* COMPUTE_OP_SHAPE. N))
96- @inbounds b_frags[j] = transf_sh2rf_b (Operator. load_b (OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile, block_tile . KN . size ), b_tile)
96+ @inbounds b_frags[j] = transf_sh2rf_b (Operator. load_b (OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
9797 end
9898
9999 # (3.3.3) Compute a COMPUTE_WARP.M x COMPUTE_WARP.N x COMPUTE_WARP.K matrix product within one warp
@@ -110,12 +110,12 @@ function matmul_impl(a, b, c, d,
110110 # (4) Store the COMPUTE_WARP.M x COMPUTE_WARP.N tile of D from registers to shared memory
111111 shmem_d = @cuDynamicSharedMem (Layout. eltype (SHARED_D_LAYOUT), Layout. size (SHARED_D_LAYOUT, block_tile. MN. size))
112112
113- warp_tile = subdivide (block_tile. MN, (M = COMPUTE_WARP. M, N = COMPUTE_WARP . N) , warpId, WARPS_PER_BLOCK)
113+ warp_tile = subdivide (block_tile. MN, Tile ( COMPUTE_WARP) . MN , warpId, WARPS_PER_BLOCK)
114114
115115 @unroll for i = 1 : NUM_FRAGMENTS_M
116116 @unroll for j = 1 : NUM_FRAGMENTS_N
117117 tile = translate (warp_tile, (M = (i- 1 )* COMPUTE_OP_SHAPE. M, N = (j- 1 )* COMPUTE_OP_SHAPE. N))
118- Operator. store_d (OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d (c_frags[i, j], tile), tile, block_tile . MN . size )
118+ Operator. store_d (OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d (c_frags[i, j], tile), tile)
119119 end
120120 end
121121
0 commit comments