Skip to content

Commit 640c16b

Browse files
authored
[ET-VK] Implementation of gather (#15749)
Title says it all! This diff implements the gather op in ET-VK https://docs.pytorch.org/docs/stable/generated/torch.gather.html Differential Revision: [D86674167](https://our.internmc.facebook.com/intern/diff/D86674167/)
1 parent d300a81 commit 640c16b

File tree

8 files changed

+296
-2
lines changed

8 files changed

+296
-2
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ def register_view_ops():
711711
exir_ops.edge.aten.unsqueeze_copy.default,
712712
exir_ops.edge.aten.clone.default,
713713
exir_ops.edge.aten.permute_copy.default,
714+
exir_ops.edge.aten.gather.default,
714715
]
715716
)
716717
def register_view_ops_with_buffer_meta():
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
15+
${define_active_storage_type("buffer")}
16+
${define_required_extensions(DTYPE)}
17+
18+
#extension GL_EXT_control_flow_attributes : require
19+
20+
layout(std430) buffer;
21+
22+
#include "indexing.glslh"
23+
24+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
25+
${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")}
26+
${layout_declare_tensor(B, "r", "t_index", "int", "buffer")}
27+
28+
${layout_declare_ubo(B, "BufferMetadata", "outp")}
29+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
30+
${layout_declare_ubo(B, "BufferMetadata", "index")}
31+
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
33+
34+
layout(constant_id = 3) const int gather_dim = 0;
35+
36+
void main() {
37+
const uint out_bufi = gl_GlobalInvocationID.x;
38+
if (out_of_bounds(out_bufi, outp)) {
39+
return;
40+
}
41+
42+
TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);
43+
44+
// Load the index value at the same position in the index tensor
45+
const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx);
46+
const int gather_idx = t_index[index_bufi];
47+
48+
// Construct the input tensor index by replacing the gather dimension
49+
// with the gathered index value
50+
TensorIndex input_tidx = out_tidx;
51+
input_tidx.data[div_4(gather_dim)][mod_4(gather_dim)] = gather_idx;
52+
53+
// Load from input tensor and store to output
54+
const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx);
55+
56+
t_out[out_bufi] = t_input[input_bufi];
57+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
gather_buffer:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: buffer
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: half
14+
- VALUE: float
15+
shader_variants:
16+
- NAME: gather_buffer
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
14+
#define T ${texel_load_component_type(DTYPE, "texture3d")}
15+
16+
${define_active_storage_type("texture3d")}
17+
${define_required_extensions(DTYPE)}
18+
19+
#extension GL_EXT_control_flow_attributes : require
20+
21+
layout(std430) buffer;
22+
23+
#include "common.glslh"
24+
#include "indexing.glslh"
25+
26+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
27+
${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")}
28+
${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")}
29+
30+
${layout_declare_ubo(B, "TextureMetadata", "outp")}
31+
${layout_declare_ubo(B, "TextureMetadata", "inp")}
32+
${layout_declare_ubo(B, "TextureMetadata", "index")}
33+
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
layout(constant_id = 3) const int gather_dim = 0;
37+
38+
void main() {
39+
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
40+
41+
if (out_of_bounds(out_pos, outp)) {
42+
return;
43+
}
44+
45+
TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
46+
ivec4 idx_texel = texelFetch(t_index, out_pos, 0);
47+
48+
VEC4_T out_texel = VEC4_T(0);
49+
50+
int limit = min(
51+
4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]);
52+
for (int comp = 0; comp < 4; comp++) {
53+
TensorIndex4D input_tidx = out_tidx;
54+
int gather_idx = idx_texel[comp];
55+
input_tidx.data[gather_dim] = gather_idx;
56+
57+
TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple(
58+
inp, input_tidx);
59+
60+
VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0);
61+
out_texel[comp] = input_texel[input_elem_pos.comp];
62+
63+
out_tidx.data[outp.packed_dim]++;
64+
}
65+
66+
imageStore(t_out, out_pos, out_texel);
67+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
gather_texture:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
generate_variant_forall:
11+
DTYPE:
12+
- VALUE: half
13+
- VALUE: float
14+
shader_variants:
15+
- NAME: gather_texture3d
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
17+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
18+
19+
#include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>
20+
21+
namespace vkcompute {
22+
23+
using utils::GPUMemoryLayout;
24+
using utils::StorageType;
25+
26+
void resize_gather_node(
27+
ComputeGraph* graph,
28+
const std::vector<ArgGroup>& args,
29+
const std::vector<ValueRef>& resize_args) {
30+
(void)resize_args;
31+
const ValueRef out = args.at(0).refs.at(0);
32+
const ValueRef index = args.at(1).refs.at(1);
33+
34+
// Output shape is the same as index shape
35+
std::vector<int64_t> out_sizes = graph->sizes_of(index);
36+
graph->virtual_resize(out, out_sizes);
37+
}
38+
39+
void add_gather_node(
40+
ComputeGraph& graph,
41+
const ValueRef input,
42+
const int64_t dim,
43+
const ValueRef index,
44+
const ValueRef out) {
45+
std::string kernel_name = "gather";
46+
kernel_name.reserve(kShaderNameReserve);
47+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
48+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
49+
50+
vkapi::ParamsBindList param_ubos = {
51+
graph.meta_ubo(out), graph.meta_ubo(input), graph.meta_ubo(index)};
52+
53+
const int64_t dim_whcn = graph.dim_of(input) - dim - 1;
54+
55+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
56+
graph,
57+
VK_KERNEL_FROM_STR(kernel_name),
58+
default_pick_global_wg_size,
59+
default_pick_local_wg_size,
60+
// Inputs and Outputs
61+
{{out, vkapi::kWrite}, {{input, index}, vkapi::kRead}},
62+
// Shader params buffers
63+
param_ubos,
64+
// Push Constants
65+
{},
66+
// Specialization Constants
67+
{static_cast<int32_t>(dim_whcn)},
68+
// Resize Args
69+
{},
70+
// Resizing Logic
71+
resize_gather_node));
72+
}
73+
74+
void gather(ComputeGraph& graph, const std::vector<ValueRef>& args) {
75+
ValueRef input = args[0];
76+
ValueRef dim_ref = args[1];
77+
ValueRef index = args[2];
78+
ValueRef out = args[4];
79+
80+
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
81+
82+
add_gather_node(graph, input, dim, index, out);
83+
}
84+
85+
REGISTER_OPERATORS {
86+
VK_REGISTER_OP(aten.gather.default, gather);
87+
}
88+
89+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,55 @@ def get_embedding_inputs():
11401140
return test_suite_wpack
11411141

11421142

1143+
@register_test_suite("aten.gather.default")
1144+
def get_gather_inputs():
1145+
Test = namedtuple("GatherTest", ["input", "dim", "index"])
1146+
Test.__new__.__defaults__ = (None, None, None)
1147+
1148+
test_cases = [
1149+
# Simple 2D case
1150+
Test(input=[4, 4], dim=1, index=[[1, 2], [2, 1], [3, 3], [3, 1]]),
1151+
# # 1D cases
1152+
Test(input=[10], dim=0, index=[0, 2, 5, 7, 9]),
1153+
Test(input=[8], dim=0, index=[1, 3, 5]),
1154+
# # 2D cases with different dims
1155+
Test(input=[5, 8], dim=0, index=[[0, 1], [2, 3], [4, 0]]),
1156+
Test(
1157+
input=[5, 8],
1158+
dim=1,
1159+
index=[[0, 2, 4], [1, 3, 5], [6, 7, 0], [1, 2, 3], [4, 5, 6]],
1160+
),
1161+
# # 3D cases
1162+
Test(
1163+
input=[3, 4, 5],
1164+
dim=0,
1165+
index=[
1166+
[[0, 1, 2, 0, 1], [1, 2, 0, 1, 2], [2, 0, 1, 2, 0], [0, 1, 2, 0, 1]]
1167+
],
1168+
),
1169+
Test(
1170+
input=[3, 4, 5],
1171+
dim=1,
1172+
index=[
1173+
[[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2], [0, 1, 2, 3]]
1174+
],
1175+
),
1176+
Test(
1177+
input=[3, 4, 5], dim=2, index=[[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 0]]]
1178+
),
1179+
]
1180+
1181+
test_suite = VkTestSuite(
1182+
[tuple(tc) + (False, "false", "false") for tc in test_cases]
1183+
)
1184+
1185+
test_suite.dtypes = ["at::kFloat"]
1186+
test_suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
1187+
test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
1188+
1189+
return test_suite
1190+
1191+
11431192
@register_test_suite("aten.unsqueeze_copy.default")
11441193
def get_unsqueeze_inputs():
11451194
test_suite = VkTestSuite(

backends/vulkan/test/op_tests/utils/gen_correctness_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def generate_suite_cpp(self) -> str:
363363
static_cast<int64_t>(indices[0].size())}};
364364
365365
// Flatten indices as from_blob reads garbage otherwise.
366-
std::vector<int64_t> acc;
366+
std::vector<int32_t> acc;
367367
for (auto& vec: indices) {{
368368
acc.insert(acc.end(), vec.begin(), vec.end());
369369
}}
@@ -380,7 +380,7 @@ def generate_suite_cpp(self) -> str:
380380
static_cast<int64_t>(indices[0][0].size())}};
381381
382382
// Flatten indices as from_blob reads garbage otherwise.
383-
std::vector<int64_t> acc;
383+
std::vector<int32_t> acc;
384384
for (auto& v: indices) {{
385385
for (auto& vv: v) {{
386386
acc.insert(acc.end(), vv.begin(), vv.end());

0 commit comments

Comments
 (0)