
Commit 9185bb2 (1 parent: 37e5d9f)

Feat: Load/Store masked API

1. Adds a compile-time-masked load/store API (load_masked and store_masked)
2. Optimizes the general (unmasked) use case
3. Adds new tests
4. Adds x86 kernels
5. Adds <bit>-style convenience APIs to batch_bool_constant
6. Tests the new APIs

14 files changed: +1533 −12 lines
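
A minimal usage sketch of the new API, calling the kernel entry points the hunks below define. The public-facing wrapper is not part of this excerpt; the xsimd::kernel namespace, the sse2 target (4 float lanes), and the exact call shape are assumptions inferred from the kernel declarations.

    #include <xsimd/xsimd.hpp>

    int main()
    {
        using A = xsimd::sse2;
        // compile-time mask: lanes 0 and 1 selected, lanes 2 and 3 ignored
        constexpr xsimd::batch_bool_constant<float, A, true, true, false, false> mask {};

        alignas(16) float in[4] = { 1.f, 2.f, 3.f, 4.f };
        alignas(16) float out[4] = { -1.f, -1.f, -1.f, -1.f };

        // unselected lanes are zero-filled on load (see the common kernel below)
        auto v = xsimd::kernel::load_masked<A>(in, mask, xsimd::kernel::convert<float> {},
                                               xsimd::aligned_mode {}, A {});

        // only selected lanes are written back: out ends up as { 1, 2, -1, -1 }
        xsimd::kernel::store_masked<A>(out, v, mask, xsimd::aligned_mode {}, A {});
    }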

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ Data transfer
 From memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`load`                      | load values from memory                            |
+| :cpp:func:`load`                      | load values from memory (optionally masked)        |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`load_aligned`              | load values from aligned memory                    |
 +---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
 To memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`store`                     | store values to memory                             |
+| :cpp:func:`store`                     | store values to memory (optionally masked)         |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`store_aligned`             | store values to aligned memory                     |
 +---------------------------------------+----------------------------------------------------+
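
For reference, the documented unmasked entry points dispatch on a mode tag; the next file's hunks add the matching load(mem, aligned_mode/unaligned_mode) kernels. A minimal sketch using only the batch member functions that appear in these hunks:

    #include <xsimd/xsimd.hpp>

    // copies one batch worth of floats; the tag picks the (un)aligned kernel
    void copy(float const* in, float* out)
    {
        auto v = xsimd::batch<float>::load(in, xsimd::unaligned_mode {});
        v.store(out, xsimd::unaligned_mode {});
    }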

include/xsimd/arch/common/xsimd_common_arithmetic.hpp

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 #include <limits>
 #include <type_traits>
 
+#include "../../types/xsimd_batch_constant.hpp"
 #include "./xsimd_common_details.hpp"
 
 namespace xsimd

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 97 additions & 0 deletions

@@ -13,6 +13,7 @@
 #define XSIMD_COMMON_MEMORY_HPP
 
 #include <algorithm>
+#include <array>
 #include <complex>
 #include <stdexcept>

@@ -341,6 +342,102 @@ namespace xsimd
         return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
     }
 
+    template <class A, class T>
+    XSIMD_INLINE batch<T, A> load(T const* mem, aligned_mode, requires_arch<A>) noexcept
+    {
+        return load_aligned<A>(mem, convert<T> {}, A {});
+    }
+
+    template <class A, class T>
+    XSIMD_INLINE batch<T, A> load(T const* mem, unaligned_mode, requires_arch<A>) noexcept
+    {
+        return load_unaligned<A>(mem, convert<T> {}, A {});
+    }
+
+    template <class A, class T_in, class T_out, bool... Values, class alignment>
+    XSIMD_INLINE batch<T_out, A>
+    load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...>, convert<T_out>, alignment, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_out, A>::size;
+        alignas(A::alignment()) std::array<T_out, size> buffer {};
+        constexpr bool mask[size] = { Values... };
+
+        for (std::size_t i = 0; i < size; ++i)
+            buffer[i] = mask[i] ? static_cast<T_out>(mem[i]) : T_out(0);
+
+        return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+    }
+
+    template <class A, class T_in, class T_out, bool... Values, class alignment>
+    XSIMD_INLINE void
+    store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_in, A>::size;
+        constexpr bool mask[size] = { Values... };
+
+        for (std::size_t i = 0; i < size; ++i)
+            if (mask[i])
+            {
+                mem[i] = static_cast<T_out>(src.get(i));
+            }
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem, batch_bool_constant<int32_t, A, Values...>, convert<int32_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto f = load_masked<A>(reinterpret_cast<const float*>(mem), batch_bool_constant<float, A, Values...> {}, convert<float> {}, Mode {}, A {});
+        return bitwise_cast<int32_t>(f);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<uint32_t, A> load_masked(uint32_t const* mem, batch_bool_constant<uint32_t, A, Values...>, convert<uint32_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto f = load_masked<A>(reinterpret_cast<const float*>(mem), batch_bool_constant<float, A, Values...> {}, convert<float> {}, Mode {}, A {});
+        return bitwise_cast<uint32_t>(f);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<types::has_simd_register<double, A>::value, batch<int64_t, A>>::type
+    load_masked(int64_t const* mem, batch_bool_constant<int64_t, A, Values...>, convert<int64_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto d = load_masked<A>(reinterpret_cast<const double*>(mem), batch_bool_constant<double, A, Values...> {}, convert<double> {}, Mode {}, A {});
+        return bitwise_cast<int64_t>(d);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<types::has_simd_register<double, A>::value, batch<uint64_t, A>>::type
+    load_masked(uint64_t const* mem, batch_bool_constant<uint64_t, A, Values...>, convert<uint64_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto d = load_masked<A>(reinterpret_cast<const double*>(mem), batch_bool_constant<double, A, Values...> {}, convert<double> {}, Mode {}, A {});
+        return bitwise_cast<uint64_t>(d);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(int32_t* mem, batch<int32_t, A> const& src, batch_bool_constant<int32_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<float*>(mem), bitwise_cast<float>(src), batch_bool_constant<float, A, Values...> {}, Mode {}, A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(uint32_t* mem, batch<uint32_t, A> const& src, batch_bool_constant<uint32_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<float*>(mem), bitwise_cast<float>(src), batch_bool_constant<float, A, Values...> {}, Mode {}, A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<types::has_simd_register<double, A>::value, void>::type
+    store_masked(int64_t* mem, batch<int64_t, A> const& src, batch_bool_constant<int64_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<double*>(mem), bitwise_cast<double>(src), batch_bool_constant<double, A, Values...> {}, Mode {}, A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<types::has_simd_register<double, A>::value, void>::type
+    store_masked(uint64_t* mem, batch<uint64_t, A> const& src, batch_bool_constant<uint64_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<double*>(mem), bitwise_cast<double>(src), batch_bool_constant<double, A, Values...> {}, Mode {}, A {});
+    }
+
     // rotate_right
     template <size_t N, class A, class T>
     XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept
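
Why the bitwise reinterpretation in the integer overloads above is safe: a masked transfer only moves or zeroes whole 32- or 64-bit lanes and never interprets them numerically, so routing an integer load through the float or double kernel preserves every bit pattern. A small sketch of the invariant, assuming an sse2 target; bitwise_cast is the same helper the hunk calls:

    #include <cstdint>
    #include <xsimd/xsimd.hpp>

    // an int32 batch survives a float-lane round trip bit-for-bit
    xsimd::batch<int32_t, xsimd::sse2> roundtrip(xsimd::batch<int32_t, xsimd::sse2> v)
    {
        auto f = xsimd::bitwise_cast<float>(v); // reinterpretation, no value conversion
        return xsimd::bitwise_cast<int32_t>(f); // recovers v exactly
    }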

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 127 additions & 0 deletions

@@ -18,6 +18,7 @@
 #include <type_traits>
 
 #include "../types/xsimd_avx_register.hpp"
+#include "../types/xsimd_batch_constant.hpp"
 
 namespace xsimd
 {
@@ -871,6 +872,132 @@ namespace xsimd
         return _mm256_loadu_pd(mem);
     }
 
+    // AVX helpers to avoid type-based branching in the generic load_masked
+    namespace detail
+    {
+        template <class A>
+        XSIMD_INLINE batch<float, A> maskload(float const* mem, batch<as_integer_t<float>, A> const& mask) noexcept
+        {
+            return _mm256_maskload_ps(mem, mask);
+        }
+
+        template <class A>
+        XSIMD_INLINE batch<double, A> maskload(double const* mem, batch<as_integer_t<double>, A> const& mask) noexcept
+        {
+            return _mm256_maskload_pd(mem, mask);
+        }
+
+        template <class A>
+        XSIMD_INLINE batch<float, A> zero_extend(batch<float, A> const& hi) noexcept
+        {
+            return _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 1);
+        }
+
+        template <class A>
+        XSIMD_INLINE batch<double, A> zero_extend(batch<double, A> const& hi) noexcept
+        {
+            return _mm256_insertf128_pd(_mm256_setzero_pd(), hi, 1);
+        }
+
+        // allow inserting a 128-bit SSE batch into the upper half of an AVX batch
+        template <class A, class SrcA>
+        XSIMD_INLINE batch<float, A> zero_extend(batch<float, SrcA> const& hi) noexcept
+        {
+            return _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 1);
+        }
+
+        template <class A, class SrcA>
+        XSIMD_INLINE batch<double, A> zero_extend(batch<double, SrcA> const& hi) noexcept
+        {
+            return _mm256_insertf128_pd(_mm256_setzero_pd(), hi, 1);
+        }
+    }
+
+    // load_masked (single overload for float/double)
+    template <class A, class T, bool... Values, class Mode, class = typename std::enable_if<std::is_floating_point<T>::value>::type>
+    XSIMD_INLINE batch<T, A> load_masked(T const* mem, batch_bool_constant<T, A, Values...> mask, convert<T>, Mode, requires_arch<avx>) noexcept
+    {
+        using int_t = as_integer_t<T>;
+        constexpr size_t half_size = batch<T, A>::size / 2;
+
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return batch<T, A>(T { 0 });
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            return load<A>(mem, Mode {});
+        }
+        // confined to the lower 128-bit half → forward to the SSE kernel
+        else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half_size)
+        {
+            constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(batch_bool_constant<int_t, A, Values...> {});
+            const auto lo = load_masked(reinterpret_cast<int_t const*>(mem), mlo, convert<int_t> {}, Mode {}, sse4_2 {});
+            return bitwise_cast<T>(batch<int_t, A>(_mm256_zextsi128_si256(lo)));
+        }
+        // confined to the upper 128-bit half → forward to the SSE kernel
+        else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= half_size)
+        {
+            constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask);
+            const auto hi = load_masked(mem + half_size, mhi, convert<T> {}, Mode {}, sse4_2 {});
+            return detail::zero_extend<A>(hi);
+        }
+        else
+        {
+            // crossing the 128-bit boundary → use a 256-bit masked load
+            return detail::maskload<A>(mem, mask.as_batch());
+        }
+    }
+
+    // store_masked
+    namespace detail
+    {
+        template <class A>
+        XSIMD_INLINE void maskstore(float* mem, batch_bool<float, A> const& mask, batch<float, A> const& src) noexcept
+        {
+            _mm256_maskstore_ps(mem, mask, src);
+        }
+
+        template <class A>
+        XSIMD_INLINE void maskstore(double* mem, batch_bool<double, A> const& mask, batch<double, A> const& src) noexcept
+        {
+            _mm256_maskstore_pd(mem, mask, src);
+        }
+    }
+
+    template <class A, class T, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(T* mem, batch<T, A> const& src, batch_bool_constant<T, A, Values...> mask, Mode, requires_arch<avx>) noexcept
+    {
+        constexpr size_t half_size = batch<T, A>::size / 2;
+
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return;
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            src.store(mem, Mode {});
+        }
+        // confined to the lower 128-bit half → forward to the SSE kernel
+        else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half_size)
+        {
+            constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(mask);
+            const auto lo = detail::lower_half(src);
+            store_masked<sse4_2>(mem, lo, mlo, Mode {}, sse4_2 {});
+        }
+        // confined to the upper 128-bit half → forward to the SSE kernel
+        else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= half_size)
+        {
+            constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask);
+            const auto hi = detail::upper_half(src);
+            store_masked<sse4_2>(mem + half_size, hi, mhi, Mode {}, sse4_2 {});
+        }
+        else
+        {
+            detail::maskstore(mem, mask.as_batch(), src);
+        }
+    }
+
     // lt
     template <class A>
     XSIMD_INLINE batch_bool<float, A> lt(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx>) noexcept
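
The dispatch in both AVX kernels is driven entirely by the <bit>-style queries that item 5 of the commit message adds to batch_bool_constant. A sketch of how the four paths are selected, assuming an AVX float batch of 8 lanes and that the queries are constexpr, as their use under XSIMD_IF_CONSTEXPR suggests:

    #include <xsimd/xsimd.hpp>

    using A = xsimd::avx;
    // lanes 0-3 active: the upper 128-bit half is idle
    using lo_mask = xsimd::batch_bool_constant<float, A, true, true, true, true, false, false, false, false>;
    // lanes 4-7 active: the lower 128-bit half is idle
    using hi_mask = xsimd::batch_bool_constant<float, A, false, false, false, false, true, true, true, true>;
    // lanes 3-4 active: the mask straddles the 128-bit boundary
    using mid_mask = xsimd::batch_bool_constant<float, A, false, false, false, true, true, false, false, false>;

    static_assert(lo_mask {}.countl_zero() >= 4, "lower half only -> forwarded to the SSE kernel");
    static_assert(hi_mask {}.countr_zero() >= 4, "upper half only -> SSE kernel on mem + half_size");
    static_assert(mid_mask {}.countl_zero() < 4 && mid_mask {}.countr_zero() < 4,
                  "boundary-crossing -> single _mm256_maskload_ps / _mm256_maskstore_ps");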
