diff --git a/src/ATen/native/transformers/Attention.cpp b/src/ATen/native/transformers/Attention.cpp
index df0a2c9bc0..186b8f0d28 100644
--- a/src/ATen/native/transformers/Attention.cpp
+++ b/src/ATen/native/transformers/Attention.cpp
@@ -3,6 +3,11 @@
 #include <...>
 #include <...>
 #include <...>
+#include <...>
+#include <...>
+#include <...>
+#include <...>
+#include <...>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <...>
@@ -294,5 +299,59 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
+
+/**
+ * Get the mask for dropout. Only used for testing; not much
+ * attention is paid to performance.
+ */
+at::Tensor& _fill_mem_eff_dropout_mask_(
+    Tensor& self,
+    double dropout_p,
+    const int64_t seed,
+    const int64_t offset) {
+  auto mask = std::get<1>(xpu::dropout_kernel(self, dropout_p, true));
+  self.copy_(mask);
+  return self;
+}
+
+/**
+ * Fallback implementation of efficient attention.
+ */
+std::tuple<Tensor, Tensor, Tensor, Tensor>
+_scaled_dot_product_efficient_attention_xpu(
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const std::optional<Tensor>& attn_bias,
+    bool compute_log_sumexp,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  // Used for tracking usage statistics
+  C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
+
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(
+        dropout_p == 0.0,
+        "Efficient attention cannot produce valid seed and offset outputs when "
+        "the batch size exceeds (",
+        MAX_BATCH_SIZE,
+        ").");
+  }
+  auto res = at::_scaled_dot_product_attention_math(
+      query,
+      key,
+      value,
+      attn_bias,
+      dropout_p,
+      is_causal,
+      std::nullopt, /*dropout_mask*/
+      scale,
+      true);
+  return std::make_tuple(
+      std::get<0>(res), std::get<1>(res), Tensor(), Tensor());
+}
 
 } // namespace native
 } // namespace at
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index a3281791de..11bb54535f 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -7922,6 +7922,17 @@
     SparseCsrXPU: angle_sparse_csr_out
   tags: pointwise
 
+- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+  variants: function
+  dispatch:
+    XPU: _fill_mem_eff_dropout_mask_
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+  dispatch:
+    XPU: _scaled_dot_product_efficient_attention_xpu
+  tags: nondeterministic_seeded
+
 - func: special_airy_ai(Tensor x) -> Tensor
   python_module: special
   structured_delegate: special_airy_ai.out
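
For context (not part of the patch), here is a minimal sketch of how the newly registered fallback could be exercised through the generated ATen C++ API. It assumes an XPU-enabled libtorch build where `at::_scaled_dot_product_efficient_attention` is generated from the yaml entry above and the `"xpu"` device is available; the tensor shapes and dtype are arbitrary choices for illustration.

```cpp
// Sketch only: drives the XPU fallback registered above. Assumes an
// XPU-enabled build; at::_scaled_dot_product_efficient_attention is the
// dispatcher-generated wrapper for the yaml entry in this patch.
#include <ATen/ATen.h>
#include <iostream>
#include <optional>

int main() {
  // [batch, num_heads, seq_len, head_dim] -- illustrative shapes only.
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::Device("xpu"));
  auto q = at::randn({2, 4, 16, 32}, opts);
  auto k = at::randn({2, 4, 16, 32}, opts);
  auto v = at::randn({2, 4, 16, 32}, opts);

  // The fallback routes through at::_scaled_dot_product_attention_math and
  // returns empty philox_seed / philox_offset tensors.
  auto [out, log_sumexp, philox_seed, philox_offset] =
      at::_scaled_dot_product_efficient_attention(
          q, k, v, /*attn_bias=*/std::nullopt,
          /*compute_log_sumexp=*/false, /*dropout_p=*/0.0,
          /*is_causal=*/false, /*scale=*/std::nullopt);

  std::cout << out.sizes() << std::endl;
  return 0;
}
```

Since the fallback delegates to the math reference kernel and does not consume a Philox RNG state, the last two outputs come back as empty tensors, matching the `Tensor(), Tensor()` returned in the C++ code above.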