This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 4fe49e0

Cleanup
1 parent 6561ad7

7 files changed, 118 insertions(+), 92 deletions(-)

src/device/matmul_kernels/epilogue.jl
Lines changed: 4 additions & 4 deletions

@@ -23,11 +23,11 @@ struct Default end
     block_tile = Tile(BLOCK_SHAPE)

     # Cooperatively store a BLOCK_SHAPE.M x BLOCK_SHAPE.N tile of D from shared to global memory within one threadblock
-    @unroll for warp_tile = parallellise(block_tile.MN, MEM_CD_WARP, warpId, WARPS_PER_BLOCK)
-        @unroll for thread_tile = parallellise(warp_tile, MEM_CD_THREAD, laneId, 32)
-            x = Layout.load(SHARED_D_LAYOUT, shmem_d, thread_tile, block_tile.MN.size)
+    @unroll for warp_tile = parallellise(block_tile.MN, Tile(MEM_CD_WARP), warpId, WARPS_PER_BLOCK)
+        @unroll for thread_tile = parallellise(warp_tile, Tile(MEM_CD_THREAD), laneId, 32)
+            x = Layout.load(SHARED_D_LAYOUT, shmem_d, thread_tile)
             x = transform(x, thread_tile)
-            Layout.store!(GLOBAL_D_LAYOUT, d, x, translate(thread_tile, (M = block_i, N = block_j)), gemm_sz.MN.size)
+            Layout.store!(GLOBAL_D_LAYOUT, d, x, translate(thread_tile, (M = block_i, N = block_j)))
         end
     end
 end
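The pattern in this epilogue hunk repeats throughout the commit: shape named tuples such as MEM_CD_WARP are wrapped in Tile(...) before being passed to parallellise, and Layout.load/Layout.store! drop their trailing logical_size argument. That argument becomes redundant once the tile size lives in the type domain (see Tile{size} in layout.jl below), because callees can recover the size at compile time. A minimal stand-alone sketch of that mechanism, using a hypothetical MyTile rather than the real Tiling type:

    # Hypothetical MyTile: lifting the size into a type parameter replaces
    # a runtime logical_size argument.
    struct MyTile{size} end
    MyTile(size::Tuple) = MyTile{size}()

    # The callee reads the size from the type instead of from an argument.
    tilesize(::MyTile{size}) where {size} = size

    t = MyTile((16, 8))
    @assert tilesize(t) == (16, 8)   # recovered statically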

src/device/matmul_kernels/kernel.jl
Lines changed: 19 additions & 19 deletions

@@ -28,25 +28,25 @@ function matmul_impl(a, b, c, d,
     # (1) Cooperatively load a BLOCK_SHAPE.M x BLOCK_SHAPE.N tile of C from global to shared memory within one threadblock
     shmem_c = @cuDynamicSharedMem(Layout.eltype(SHARED_C_LAYOUT), Layout.size(SHARED_C_LAYOUT, block_tile.MN.size))

-    @unroll for warp_tile = parallellise(block_tile.MN, MEM_CD_WARP, warpId, WARPS_PER_BLOCK)
-        @unroll for thread_tile = parallellise(warp_tile, MEM_CD_THREAD, laneId, 32)
-            x = Layout.load(GLOBAL_C_LAYOUT, c, translate(thread_tile, (M = block_i, N = block_j)), gemm_sz.MN.size)
+    @unroll for warp_tile = parallellise(block_tile.MN, Tile(MEM_CD_WARP), warpId, WARPS_PER_BLOCK)
+        @unroll for thread_tile = parallellise(warp_tile, Tile(MEM_CD_THREAD), laneId, 32)
+            x = Layout.load(GLOBAL_C_LAYOUT, c, translate(thread_tile, (M = block_i, N = block_j)))
             x = transf_gl2sh_c(x, thread_tile)
-            Layout.store!(SHARED_C_LAYOUT, shmem_c, x, thread_tile, block_tile.MN.size)
+            Layout.store!(SHARED_C_LAYOUT, shmem_c, x, thread_tile)
         end
     end

     sync_threads()

     # (2) Load a COMPUTE_WARP.M x COMPUTE_WARP.N tile of C from shared memory into registers
-    warp_tile = subdivide(block_tile.MN, (M = COMPUTE_WARP.M, N = COMPUTE_WARP.N), warpId, WARPS_PER_BLOCK)
+    warp_tile = subdivide(block_tile.MN, Tile(COMPUTE_WARP).MN, warpId, WARPS_PER_BLOCK)

     c_frags = MArray{Tuple{NUM_FRAGMENTS_M, NUM_FRAGMENTS_N}, Operator.fragtype_accum(OPERATOR, SHARED_C_LAYOUT)}(undef)

     @unroll for i = 1 : NUM_FRAGMENTS_M
         @unroll for j = 1 : NUM_FRAGMENTS_N
             tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
-            @inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile, block_tile.MN.size), tile)
+            @inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
         end
     end

@@ -59,41 +59,41 @@ function matmul_impl(a, b, c, d,

     @unroll for block_k = 0 : block_tile.size.K : gemm_sz.size.K - 1
         # (3.1) Cooperatively load a BLOCK_SHAPE.M x BLOCK_SHAPE.K tile of A from global to shared memory within one threadblock
-        @unroll for warp_tile = parallellise(block_tile.MK, MEM_A_WARP, warpId, WARPS_PER_BLOCK)
-            @unroll for thread_tile = parallellise(warp_tile, MEM_A_THREAD, laneId, 32)
-                x = Layout.load(GLOBAL_A_LAYOUT, a, translate(thread_tile, (M = block_i, K = block_k)), gemm_sz.MK.size)
+        @unroll for warp_tile = parallellise(block_tile.MK, Tile(MEM_A_WARP), warpId, WARPS_PER_BLOCK)
+            @unroll for thread_tile = parallellise(warp_tile, Tile(MEM_A_THREAD), laneId, 32)
+                x = Layout.load(GLOBAL_A_LAYOUT, a, translate(thread_tile, (M = block_i, K = block_k)))
                 x = transf_gl2sh_a(x, thread_tile)
-                Layout.store!(SHARED_A_LAYOUT, shmem_a, x, thread_tile, block_tile.MK.size)
+                Layout.store!(SHARED_A_LAYOUT, shmem_a, x, thread_tile)
             end
         end

         # (3.2) Cooperatively load a BLOCK_SHAPE.K x BLOCK_SHAPE.N tile of B from global to shared memory within one threadblock
-        @unroll for warp_tile = parallellise(block_tile.KN, MEM_B_WARP, warpId, WARPS_PER_BLOCK)
-            @unroll for thread_tile = parallellise(warp_tile, MEM_B_THREAD, laneId, 32)
-                x = Layout.load(GLOBAL_B_LAYOUT, b, translate(thread_tile, (K = block_k, N = block_j)), gemm_sz.KN.size)
+        @unroll for warp_tile = parallellise(block_tile.KN, Tile(MEM_B_WARP), warpId, WARPS_PER_BLOCK)
+            @unroll for thread_tile = parallellise(warp_tile, Tile(MEM_B_THREAD), laneId, 32)
+                x = Layout.load(GLOBAL_B_LAYOUT, b, translate(thread_tile, (K = block_k, N = block_j)))
                 x = transf_gl2sh_b(x, thread_tile)
-                Layout.store!(SHARED_B_LAYOUT, shmem_b, x, thread_tile, block_tile.KN.size)
+                Layout.store!(SHARED_B_LAYOUT, shmem_b, x, thread_tile)
             end
         end

         sync_threads()

         # (3.3) Calculate a COMPUTE_WARP.M x COMPUTE_WARP.N tile of D, using a COMPUTE_WARP.M x COMPUTE_WARP.N x COMPUTE_WARP.K operation
-        @unroll for warp_tile = parallellise(block_tile, COMPUTE_WARP, warpId, WARPS_PER_BLOCK)
+        @unroll for warp_tile = parallellise(block_tile, Tile(COMPUTE_WARP), warpId, WARPS_PER_BLOCK)
            # (3.3.1) Load a COMPUTE_WARP.M x COMPUTE_WARP.K tile of A from shared memory into registers
            a_frags = MArray{Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)}(undef)

            @unroll for i = 1 : NUM_FRAGMENTS_M
                a_tile = translate(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
-               @inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile, block_tile.MK.size), a_tile)
+               @inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
            end

            # (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
            b_frags = MArray{Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)}(undef)

            @unroll for j = 1 : NUM_FRAGMENTS_N
                b_tile = translate(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
-               @inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile, block_tile.KN.size), b_tile)
+               @inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
            end

            # (3.3.3) Compute a COMPUTE_WARP.M x COMPUTE_WARP.N x COMPUTE_WARP.K matrix product within one warp
@@ -110,12 +110,12 @@ function matmul_impl(a, b, c, d,
     # (4) Store the COMPUTE_WARP.M x COMPUTE_WARP.N tile of D from registers to shared memory
     shmem_d = @cuDynamicSharedMem(Layout.eltype(SHARED_D_LAYOUT), Layout.size(SHARED_D_LAYOUT, block_tile.MN.size))

-    warp_tile = subdivide(block_tile.MN, (M = COMPUTE_WARP.M, N = COMPUTE_WARP.N), warpId, WARPS_PER_BLOCK)
+    warp_tile = subdivide(block_tile.MN, Tile(COMPUTE_WARP).MN, warpId, WARPS_PER_BLOCK)

     @unroll for i = 1 : NUM_FRAGMENTS_M
         @unroll for j = 1 : NUM_FRAGMENTS_N
             tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
-            Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile, block_tile.MN.size)
+            Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
         end
     end
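Beyond parallellise, the subdivide call sites now build their warp shape as Tile(COMPUTE_WARP).MN instead of hand-assembling (M = COMPUTE_WARP.M, N = COMPUTE_WARP.N): the named projection .MN slices the M and N dimensions out of the full M x N x K tile. A pure-Julia sketch of that projection idiom, with a hypothetical MiniTile standing in for the real Tiling API:

    # Hypothetical MiniTile: getproperty turns a field access like .MN into
    # a projection onto the named dimensions M and N.
    struct MiniTile{names, size} end
    MiniTile(size::NamedTuple) = MiniTile{keys(size), values(size)}()

    function Base.getproperty(t::MiniTile{names, size}, s::Symbol) where {names, size}
        nt = NamedTuple{names}(size)               # rebuild the named shape
        sel = Tuple(Symbol(c) for c in String(s))  # e.g. :MN -> (:M, :N)
        return MiniTile(NamedTuple{sel}(nt))       # keep only those dimensions
    end

    warp = MiniTile((M = 32, N = 64, K = 16))      # illustrative shape
    warp.MN                                        # MiniTile{(:M, :N), (32, 64)}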

src/device/matmul_kernels/layout.jl
Lines changed: 31 additions & 10 deletions

@@ -3,6 +3,8 @@ module Layout

 using CUDAnative
 using CUDAnative.Tiling
+using GPUifyLoops
+using StaticArrays

 # -----------
 # Layout base
@@ -26,25 +28,44 @@ end

 @inline eltype(::Type{Padded{L, P}}) where {L, P} = eltype(L)
 @inline size(::Type{Padded{L, P}}, logical_size::NamedTuple) where {L, P} = size(L, pad_logical_coord(Padded{L, P}, logical_size))
-@inline load(::Type{Padded{L, P}}, workspace, tile::Tile, logical_size::NamedTuple) where {L, P} = load(L, workspace, tile, pad_logical_coord(Padded{L, P}, logical_size))
-@inline store!(::Type{Padded{L, P}}, workspace, value, tile::Tile, logical_size::NamedTuple) where {L, P} = store!(L, workspace, value, tile::Tile, pad_logical_coord(Padded{L, P}, logical_size))
+@inline load(::Type{Padded{L, P}}, workspace, tile::Tile, logical_size::NamedTuple) where {L, P} = load(L, workspace, tile)
+@inline store!(::Type{Padded{L, P}}, workspace, value, tile::Tile) where {L, P} = store!(L, workspace, value, tile::Tile)

 # ---------------
 # AlignedColMajor
 # ---------------

 struct AlignedColMajor{T} <: LayoutBase{T} end

-@inline function load(::Type{AlignedColMajor{T}}, workspace, tile::Tile, logical_size::NamedTuple) where {T}
-    N = 16 ÷ sizeof(T)
-    ptr = pointer(workspace, linearise(tile.base, logical_size))
-    return vloada(Vec{N, T}, ptr, linearise(tile.offset, logical_size))
+# TODO: cleanup vectorisation
+@inline function load(::Type{AlignedColMajor{T}}, workspace, tile::Tile{size}) where {T, size}
+    vec_len = 16 ÷ sizeof(T)
+    N = (sizeof(T) * vec_len) ÷ sizeof(Float32)
+    res = MArray{Tuple{size[1] ÷ vec_len, size[2]}, NTuple{N, VecElement{Float32}}}(undef)
+
+    @unroll for j = 1 : size[2]
+        @unroll for i = 1 : vec_len : size[1]
+            t = translate(tile, (i - 1, j - 1))
+            ind = Tuple(t.index) .+ 1
+            @inbounds linear_index = LinearIndices(Base.size(workspace))[ind...]
+            @inbounds res[i, j] = vloada(Vec{vec_len, T}, pointer(workspace), linear_index)
+        end
+    end
+
+    return res
 end

-@inline function store!(::Type{AlignedColMajor{T}}, workspace, value, tile::Tile, logical_size::NamedTuple) where {T}
-    N = 16 ÷ sizeof(T)
-    ptr = pointer(workspace, linearise(tile.base, logical_size))
-    return vstorea!(Vec{N, T}, ptr, value, linearise(tile.offset, logical_size))
+@inline function store!(::Type{AlignedColMajor{T}}, workspace, value, tile::Tile{size}) where {T, size}
+    vec_len = 16 ÷ sizeof(T)
+
+    @unroll for j = 1 : size[2]
+        @unroll for i = 1 : vec_len : size[1]
+            t = translate(tile, (i - 1, j - 1))
+            ind = Tuple(t.index) .+ 1
+            @inbounds linear_index = LinearIndices(Base.size(workspace))[ind...]
+            vstorea!(Vec{vec_len, T}, pointer(workspace), value[i, j], linear_index)
+        end
+    end
 end

 end
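The rewritten AlignedColMajor load and store replace the old linearise helper with ordinary Julia index arithmetic: a tile's zero-based index is shifted to a one-based Cartesian index and converted to a linear offset through LinearIndices, so the physical size of the workspace (including any padding) determines the stride. A CPU-only analogue of that arithmetic, with no CUDA dependency and illustrative values:

    # Column-major workspace; the same arithmetic the device code uses.
    workspace = reshape(collect(Float32, 1:64), 8, 8)

    tile_index = (4, 2)                        # zero-based, like Tile.index
    ind = tile_index .+ 1                      # one-based Cartesian index
    linear_index = LinearIndices(size(workspace))[ind...]

    @assert workspace[linear_index] == workspace[5, 3]

A vectorised load then reads vec_len contiguous elements starting at that linear offset, which is why the element count is 16 ÷ sizeof(T): one 16-byte aligned vector per step of the inner loop.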

src/device/matmul_kernels/operator.jl
Lines changed: 21 additions & 22 deletions

@@ -9,19 +9,10 @@ using CUDAnative.Tiling
 # Default definition for padded layouts
 # -------------------------------------

-# Fragment types
-for f in (:fragtype_a, :fragtype_b, :fragtype_accum)
+for f in (:fragtype_a, :fragtype_b, :fragtype_accum, :load_a, :load_b, :load_c, :store_d)
     @eval @inline $f(op, ::Type{Layout.Padded{L, P}}, args...) where {L, P} = $f(op, L, args...)
 end

-# Load fragments
-for f in (:load_a, :load_b, :load_c)
-    @eval @inline $f(op, ::Type{Layout.Padded{L, P}}, workspace, tile::Tile, logical_size::NamedTuple) where {L, P} = $f(op, L, workspace, tile, Layout.pad_logical_coord(Layout.Padded{L, P}, logical_size))
-end
-
-# Store fragments
-@inline store_d(op, ::Type{Layout.Padded{L, P}}, workspace, frag, tile::Tile, logical_size::NamedTuple) where {L, P} = store_d(op, L, workspace, frag, tile, Layout.pad_logical_coord(Layout.Padded{L, P}, logical_size))
-
 # ----
 # WMMA
 # ----
@@ -34,28 +25,36 @@ struct WMMAOp{M, N, K} end
 @inline fragtype_b(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}
 @inline fragtype_accum(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float32}}) = WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}

-function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile, logical_size::NamedTuple) where {M, N, K}
+function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
     conf = WMMA.Config{M, N, K, Float32}
-    ptr = pointer(workspace, linearise(tile.index, logical_size))
-    return WMMA.load_a(ptr, logical_size.M, WMMA.ColMajor, conf)
+    ind = Tuple(tile.index) .+ 1
+    @inbounds linear_index = LinearIndices(size(workspace))[ind...]
+    ptr = pointer(workspace, linear_index)
+    return WMMA.load_a(ptr, size(workspace, 1), WMMA.ColMajor, conf)
 end

-function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile, logical_size::NamedTuple) where {M, N, K}
+function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
     conf = WMMA.Config{M, N, K, Float32}
-    ptr = pointer(workspace, linearise(tile.index, logical_size))
-    return WMMA.load_b(ptr, logical_size.K, WMMA.ColMajor, conf)
+    ind = Tuple(tile.index) .+ 1
+    @inbounds linear_index = LinearIndices(size(workspace))[ind...]
+    ptr = pointer(workspace, linear_index)
+    return WMMA.load_b(ptr, size(workspace, 1), WMMA.ColMajor, conf)
 end

-function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile, logical_size::NamedTuple) where {M, N, K}
+function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
     conf = WMMA.Config{M, N, K, Float32}
-    ptr = pointer(workspace, linearise(tile.index, logical_size))
-    return WMMA.load_c(ptr, logical_size.M, WMMA.ColMajor, conf)
+    ind = Tuple(tile.index) .+ 1
+    @inbounds linear_index = LinearIndices(size(workspace))[ind...]
+    ptr = pointer(workspace, linear_index)
+    return WMMA.load_c(ptr, size(workspace, 1), WMMA.ColMajor, conf)
 end

-function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile, logical_size::NamedTuple) where {M, N, K}
+function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
     conf = WMMA.Config{M, N, K, Float32}
-    ptr = pointer(workspace, linearise(tile.index, logical_size))
-    WMMA.store_d(ptr, frag, logical_size.M, WMMA.ColMajor, conf)
+    ind = Tuple(tile.index) .+ 1
+    @inbounds linear_index = LinearIndices(size(workspace))[ind...]
+    ptr = pointer(workspace, linear_index)
+    WMMA.store_d(ptr, frag, size(workspace, 1), WMMA.ColMajor, conf)
 end

 function mma(::Type{WMMAOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K}