@@ -27,28 +27,32 @@ struct WMMAOp{M, N, K} end
2727
"""
    load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile)

Load the A operand for `tile` from `workspace` through `WMMA.load_a` and return the
value that call produces.

# Arguments
- `workspace`: array the operand is read from; its size determines the linear index and
  `size(workspace, 1)` is forwarded as the second argument of `WMMA.load_a`.
- `tile`: tile whose `index` field selects the starting element.
"""
function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
    conf = WMMA.Config{M, N, K, Float32}

    # Shift every component of the tile index by one before indexing.
    # NOTE(review): presumably `tile.index` is zero-based -- confirm against `Tile`.
    ind = Tuple(tile.index) .+ 1

    # Map the multi-dimensional index to a linear one for an array of this size;
    # `@inbounds` suppresses the bounds check on the lookup.
    @inbounds linear_index = LinearIndices(size(workspace))[ind...]

    ptr = pointer(workspace, linear_index)
    return WMMA.load_a(ptr, size(workspace, 1), WMMA.ColMajor, conf)
end
3435
"""
    load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile)

Load the B operand for `tile` from `workspace` through `WMMA.load_b` and return the
value that call produces.

# Arguments
- `workspace`: array the operand is read from; its size determines the linear index and
  `size(workspace, 1)` is forwarded as the second argument of `WMMA.load_b`.
- `tile`: tile whose `index` field selects the starting element.
"""
function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
    conf = WMMA.Config{M, N, K, Float32}

    # Shift every component of the tile index by one before indexing.
    # NOTE(review): presumably `tile.index` is zero-based -- confirm against `Tile`.
    ind = Tuple(tile.index) .+ 1

    # Map the multi-dimensional index to a linear one for an array of this size;
    # `@inbounds` suppresses the bounds check on the lookup.
    @inbounds linear_index = LinearIndices(size(workspace))[ind...]

    ptr = pointer(workspace, linear_index)
    return WMMA.load_b(ptr, size(workspace, 1), WMMA.ColMajor, conf)
end
4143
"""
    load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile)

Load the C operand for `tile` from `workspace` through `WMMA.load_c` and return the
value that call produces.

# Arguments
- `workspace`: array the operand is read from; its size determines the linear index and
  `size(workspace, 1)` is forwarded as the second argument of `WMMA.load_c`.
- `tile`: tile whose `index` field selects the starting element.
"""
function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
    conf = WMMA.Config{M, N, K, Float32}

    # Shift every component of the tile index by one before indexing.
    # NOTE(review): presumably `tile.index` is zero-based -- confirm against `Tile`.
    ind = Tuple(tile.index) .+ 1

    # Map the multi-dimensional index to a linear one for an array of this size;
    # `@inbounds` suppresses the bounds check on the lookup.
    @inbounds linear_index = LinearIndices(size(workspace))[ind...]

    ptr = pointer(workspace, linear_index)
    return WMMA.load_c(ptr, size(workspace, 1), WMMA.ColMajor, conf)
end
4851
"""
    store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile)

Write `frag` into `workspace` at the position selected by `tile` through `WMMA.store_d`,
returning whatever that call returns.

# Arguments
- `workspace`: destination array; its size determines the linear index and
  `size(workspace, 1)` is forwarded as the third argument of `WMMA.store_d`.
- `frag`: value handed to `WMMA.store_d` as the data to store.
- `tile`: tile whose `index` field selects the starting element.
"""
function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
    conf = WMMA.Config{M, N, K, Float32}

    # Shift every component of the tile index by one before indexing.
    # NOTE(review): presumably `tile.index` is zero-based -- confirm against `Tile`.
    ind = Tuple(tile.index) .+ 1

    # Map the multi-dimensional index to a linear one for an array of this size;
    # `@inbounds` suppresses the bounds check on the lookup.
    @inbounds linear_index = LinearIndices(size(workspace))[ind...]

    ptr = pointer(workspace, linear_index)
    # Explicit `return`: the call was already the last expression, so the value is unchanged.
    return WMMA.store_d(ptr, frag, size(workspace, 1), WMMA.ColMajor, conf)
end
0 commit comments