Skip to content

Commit 9b2686b

Browse files
committed
Update base for Update on "[Executorch] parallelize op_choose_qparams"
When doing prefill for quantized kv cache, with large prefill length, parallelizing this op helps. Differential Revision: [D84962234](https://our.internmc.facebook.com/intern/diff/D84962234/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84962234/)! [ghstack-poisoned]
2 parents 993254c + 15a0fcd commit 9b2686b

File tree

145 files changed

+2702
-1280
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

145 files changed

+2702
-1280
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e6f766c7d750d40603eee3f66c5915bac606b3ea
1+
556fc09a9f67f24ca5591ec049c5d0c347c5f62a

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ jobs:
347347
elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
348348
setup_script_args="--target-toolchain zephyr"
349349
toolchain_prefix=arm-zephyr-eabi-
350-
threshold="135240" # 132 KiB
350+
threshold="135656" # 132 KiB
351351
toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake
352352
else
353353
echo "Fail unsupport OS selection ${{ matrix.os }}"

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1414
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1515
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
16+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1617
from executorch.exir.backend.utils import WhyNoPartitionReporter
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5152

5253

54+
def get_quantization(op):
55+
"""Returns quant and dequant op of same type (per_channel/ tensor) as op if op is a dequant node, None otherwise."""
56+
if op in DQ_OPS:
57+
# Input of op can be placeholder, can't use that to get quant node directly.
58+
quant_type_index = DQ_OPS.index(op)
59+
return Q_OPS[quant_type_index], op
60+
return None
61+
62+
5363
class DecomposeMeanDimPass(ArmPass):
5464
"""
5565
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
122132

123133
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
134+
x = self._maybe_insert_q_dq_after(x, meta)
124135

125136
# Reduce (h,w) dims by avg pool if possible
126137
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
134145

135146
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
136-
147+
x = self._maybe_insert_q_dq_after(x, meta)
137148
# Reduce remaining dims by sum
138149
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
139150

@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167
full = super().call_operator(
157168
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
158169
)
170+
if (quant_ops := get_quantization(input_node.node.target)) is not None:
171+
# Insert Q and DQ nodes after full op.
172+
# Since the value of full is known (1/N), we can compute quant params such that dq(qmax) == 1/N.
173+
q_op, dq_op = quant_ops
174+
qmax = input_node.node.args[4]
175+
full_quant_args = (
176+
1 / (N * qmax), # Scale to map qmax to 1/N
177+
0, # Zero point
178+
*input_node.node.args[3:],
179+
)
180+
q_args = (full, *full_quant_args)
181+
full = super().call_operator(
182+
q_op,
183+
q_args,
184+
kwargs={},
185+
meta=meta,
186+
updated=True,
187+
)
188+
dq_args = (full, *full_quant_args)
189+
full = super().call_operator(
190+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
191+
)
192+
193+
# Insert Q and DQ nodes after sum op.
194+
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
196+
q_args = (sum, *sum_quant_args)
197+
sum = super().call_operator(
198+
q_op,
199+
q_args,
200+
kwargs={},
201+
meta=meta,
202+
updated=True,
203+
)
204+
dq_args = (sum, *sum_quant_args)
205+
sum = super().call_operator(
206+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
207+
)
208+
159209
return super().call_operator(mul_op, (sum, full), {}, meta, True)
160210

161211
def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240
)
191241

192242
if is_supported:
243+
out = super().call_operator(avgpool_op, args, {}, meta, True)
244+
out = self._maybe_insert_q_dq_after(out, meta)
193245
return (
194-
super().call_operator(avgpool_op, args, {}, meta, True),
246+
out,
195247
dims_to_reduce_by_sum,
196248
)
197249

198250
else:
199251
return input_node, dims
252+
253+
def _maybe_insert_q_dq_after(self, op, meta):
254+
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""
255+
256+
if len(op.node.all_input_nodes) > 1:
257+
raise ValueError(
258+
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
259+
)
260+
input_node = op.node.all_input_nodes[0]
261+
if (quant_ops := get_quantization(input_node.target)) is not None:
262+
q_op, dq_op = quant_ops
263+
quant_args = list(input_node.args[1:])
264+
q_args = (op, *quant_args)
265+
out = super().call_operator(
266+
q_op,
267+
q_args,
268+
kwargs={},
269+
meta=meta,
270+
updated=True,
271+
)
272+
dq_args = (out, *quant_args)
273+
return super().call_operator(
274+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
275+
)
276+
else:
277+
return op

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 172 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
326326
ET_LOG(Error, "Ethos-U invocation failed error (%d)", result);
327327
return Error::InvalidProgram;
328328
}
329-
int tensor_dim = 0, io_dim = 0;
329+
size_t tensor_bytes_total = 0;
330+
size_t io_bytes_total = 0;
330331
// Write outputs from scratch into EValue pointers
331332
for (int i = 0; i < handles.outputs->count; i++) {
332333
int tensor_count = 1, io_count = 1;
@@ -338,23 +339,39 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
338339
calculate_dimensions(
339340
tensor_out, &handles.outputs->io[i], &tensor_count, &io_count);
340341

341-
// At times the topological order of the outputs may change.
342-
// Lets instead ensure that the sum of dimensions match.
343-
tensor_dim = tensor_dim + tensor_count;
344-
io_dim = io_dim + io_count;
342+
size_t tensor_bytes = tensor_out.nbytes();
343+
size_t io_bytes = static_cast<size_t>(io_count) *
344+
static_cast<size_t>(handles.outputs->io[i].elem_size);
345+
346+
if (tensor_bytes != io_bytes) {
347+
Error status = copy_with_layout_adjustment(
348+
handles.outputs->io[i], i, output_addr, tensor_out, tensor_bytes);
349+
if (status != Error::Ok) {
350+
return status;
351+
}
352+
io_bytes_total += tensor_bytes;
353+
} else {
354+
EXECUTORCH_PROF_SCOPE(
355+
event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
345356

346-
EXECUTORCH_PROF_SCOPE(
347-
event_tracer, "+EthosUBackend::execute()handles.output.memcpy()");
357+
memcpy(
358+
tensor_out.mutable_data_ptr<char>(),
359+
static_cast<const char*>(output_addr),
360+
tensor_bytes);
361+
io_bytes_total += io_bytes;
362+
}
348363

349-
memcpy(
350-
tensor_out.mutable_data_ptr<char>(),
351-
static_cast<const char*>(output_addr),
352-
tensor_out.nbytes());
364+
// At times the topological order of the outputs may change.
365+
// Let's instead ensure that the sums of output bytes match.
366+
tensor_bytes_total += tensor_bytes;
353367
}
354-
if (tensor_dim != io_dim) {
368+
if (tensor_bytes_total != io_bytes_total) {
355369
ET_LOG(Error, "Total output tensor sizes do not match");
356370
ET_LOG(
357-
Error, "Program expects size of %d but got %d", tensor_dim, io_dim);
371+
Error,
372+
"Program expects %zu bytes but got %zu",
373+
io_bytes_total,
374+
tensor_bytes_total);
358375
return Error::InvalidProgram;
359376
}
360377
return Error::Ok;
@@ -365,6 +382,147 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
365382
}
366383

367384
private:
385+
// Copies Vela output into the ExecuTorch tensor, adjusting for padding or
386+
// packed layouts produced by the delegate.
387+
Error copy_with_layout_adjustment(
388+
const VelaIO& output_io,
389+
int output_index,
390+
const char* src,
391+
executorch::aten::Tensor& tensor_out,
392+
size_t tensor_bytes) const {
393+
const int elem_size = output_io.elem_size;
394+
if (elem_size == 0) {
395+
ET_LOG(
396+
Error, "Ethos-U output %d reports zero element size", output_index);
397+
return Error::InvalidProgram;
398+
}
399+
400+
size_t chunk_count = 1;
401+
for (int dim = 0; dim < shapeDim - 1; ++dim) {
402+
const int vela_dim = output_io.shape[dim];
403+
chunk_count *= static_cast<size_t>(vela_dim == 0 ? 1 : vela_dim);
404+
}
405+
const int last_dim = output_io.shape[shapeDim - 1];
406+
const size_t vela_chunk_elems =
407+
static_cast<size_t>(last_dim == 0 ? 1 : last_dim);
408+
const size_t vela_chunk_size =
409+
vela_chunk_elems * static_cast<size_t>(elem_size);
410+
411+
if (tensor_bytes % chunk_count != 0) {
412+
ET_LOG(
413+
Error,
414+
"Ethos-U output %d tensor bytes %zu not divisible by chunk count %zu",
415+
output_index,
416+
tensor_bytes,
417+
chunk_count);
418+
return Error::InvalidProgram;
419+
}
420+
421+
const size_t chunk_size = tensor_bytes / chunk_count;
422+
423+
// If Vela writes fewer bytes than the tensor expects we may need to
424+
// expand 4-bit data to 8-bit. Ethos-U outputs may be
425+
// packed 4-bit values but ExecuTorch tensors are at least 8-bit.
426+
if (vela_chunk_size < chunk_size) {
427+
if (chunk_size % vela_chunk_size != 0) {
428+
ET_LOG(
429+
Error,
430+
"Ethos-U output %d chunk bytes %zu not divisible by vela chunk bytes %zu",
431+
output_index,
432+
chunk_size,
433+
vela_chunk_size);
434+
return Error::InvalidProgram;
435+
}
436+
437+
const size_t expand_factor = chunk_size / vela_chunk_size;
438+
if (expand_factor == 2 && elem_size == 1 &&
439+
tensor_out.scalar_type() == ScalarType::Char) {
440+
return unpack_chunks_4bit_to_int8(
441+
reinterpret_cast<const uint8_t*>(src),
442+
tensor_out.mutable_data_ptr<int8_t>(),
443+
chunk_count,
444+
chunk_size,
445+
vela_chunk_size);
446+
}
447+
448+
ET_LOG(
449+
Error,
450+
"Ethos-U output %d expansion factor %zu with element size %d not supported",
451+
output_index,
452+
expand_factor,
453+
elem_size);
454+
return Error::InvalidProgram;
455+
}
456+
457+
return strip_delegate_padding(
458+
src,
459+
tensor_out.mutable_data_ptr<char>(),
460+
chunk_count,
461+
chunk_size,
462+
vela_chunk_size);
463+
}
464+
465+
Error unpack_chunks_4bit_to_int8(
466+
const uint8_t* src,
467+
int8_t* dest,
468+
size_t chunk_count,
469+
size_t dest_chunk_size,
470+
size_t src_chunk_size) const {
471+
const uint8_t* chunk_src = src;
472+
int8_t* chunk_dest = dest;
473+
for (size_t chunk_idx = 0; chunk_idx < chunk_count; ++chunk_idx) {
474+
unpack_single_chunk_4bit_to_int8(chunk_src, chunk_dest, src_chunk_size);
475+
chunk_src += src_chunk_size;
476+
chunk_dest += dest_chunk_size;
477+
}
478+
return Error::Ok;
479+
}
480+
481+
void unpack_single_chunk_4bit_to_int8(
482+
const uint8_t* src,
483+
int8_t* dest,
484+
size_t chunk_size) const {
485+
for (size_t byte_idx = 0; byte_idx < chunk_size; ++byte_idx) {
486+
const uint8_t packed = src[byte_idx];
487+
int8_t low = static_cast<int8_t>(packed & 0x0F);
488+
int8_t high = static_cast<int8_t>((packed >> 4) & 0x0F);
489+
if (low >= 8) {
490+
low -= 16;
491+
}
492+
if (high >= 8) {
493+
high -= 16;
494+
}
495+
dest[2 * byte_idx] = low;
496+
dest[2 * byte_idx + 1] = high;
497+
}
498+
}
499+
500+
Error strip_delegate_padding(
501+
const char* src,
502+
char* dest,
503+
size_t chunk_count,
504+
size_t dest_chunk_size,
505+
size_t src_chunk_size) const {
506+
if (dest_chunk_size > src_chunk_size) {
507+
ET_LOG(
508+
Error,
509+
"dest chunk size %zu must not exceed src chunk size %zu",
510+
dest_chunk_size,
511+
src_chunk_size);
512+
return Error::InvalidProgram;
513+
}
514+
if (src == nullptr || dest == nullptr) {
515+
ET_LOG(Error, "Ethos-U padded copy received null buffer");
516+
return Error::InvalidState;
517+
}
518+
for (size_t chunk_idx = 0; chunk_idx < chunk_count; ++chunk_idx) {
519+
memcpy(dest, src, dest_chunk_size);
520+
src += src_chunk_size;
521+
dest += dest_chunk_size;
522+
}
523+
return Error::Ok;
524+
}
525+
368526
void calculate_dimensions(
369527
const executorch::aten::Tensor tensor,
370528
VelaIO* io,
@@ -389,4 +547,4 @@ static auto registered = register_backend(backend_id);
389547

390548
} // namespace arm
391549
} // namespace backends
392-
} // namespace executorch
550+
} // namespace executorch

0 commit comments

Comments
 (0)