Commit 4393266

Add 8/16 bits AVX2 swizzle (#1197)

* Add dynamic swizzle for uint8_t on avx2
* Add dynamic swizzle for uint16_t on avx2
* Add fallback AVX2 swizzle constant mask

1 parent 9d41ad9 commit 4393266
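
These kernels back the dynamic-mask form of xsimd::swizzle for 8- and 16-bit elements on AVX2. A minimal usage sketch, assuming the public dynamic-mask swizzle(batch, index_batch) overload dispatches to the new kernels on an AVX2 target; the payload values and the reversal mask are illustrative, not from the patch:

    #include <cstdint>
    #include <cstdio>
    #include <xsimd/xsimd.hpp>

    int main()
    {
        using b8 = xsimd::batch<uint8_t, xsimd::avx2>; // 32 lanes on AVX2

        alignas(32) uint8_t data[32];
        alignas(32) uint8_t idx[32];
        for (int i = 0; i < 32; ++i)
        {
            data[i] = static_cast<uint8_t>(3 * i); // arbitrary payload
            idx[i] = static_cast<uint8_t>(31 - i); // runtime mask: reverse the batch
        }

        b8 self = b8::load_aligned(data);
        b8 mask = b8::load_aligned(idx);

        // Dynamic swizzle: result[i] = self[mask[i]]
        b8 result = xsimd::swizzle(self, mask);

        alignas(32) uint8_t out[32];
        result.store_aligned(out);
        std::printf("%u %u\n", out[0], out[31]); // expects 93 and 0
        return 0;
    }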

File tree

1 file changed: +45 -20 lines changed

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 45 additions & 20 deletions
@@ -1095,42 +1095,67 @@ namespace xsimd
             }
         }
 
-        // swizzle (dynamic mask)
+        // swizzle (dynamic mask) on 8 and 16 bits; see avx for 32 and 64 bits versions
         template <class A>
-        XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self, batch<uint32_t, A> mask, requires_arch<avx2>) noexcept
+        XSIMD_INLINE batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch<uint8_t, A> mask, requires_arch<avx2>) noexcept
         {
-            return swizzle(self, mask, avx {});
+            // swap lanes
+            __m256i swapped = _mm256_permute2x128_si256(self, self, 0x01); // [high | low]
+
+            // normalize mask taking modulo 16
+            batch<uint8_t, A> half_mask = mask & 0b1111u;
+
+            // permute bytes within each lane (AVX2 only)
+            __m256i r0 = _mm256_shuffle_epi8(self, half_mask);
+            __m256i r1 = _mm256_shuffle_epi8(swapped, half_mask);
+
+            // select lane by the mask index divided by 16
+            constexpr auto lane = batch_constant<
+                uint8_t, A,
+                00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00,
+                16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16> {};
+            batch_bool<uint8_t, A> blend_mask = (mask & 0b10000u) != lane;
+            return _mm256_blendv_epi8(r0, r1, blend_mask);
         }
-        template <class A>
-        XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self, batch<uint64_t, A> mask, requires_arch<avx2>) noexcept
+
+        template <class A, typename T, detail::enable_sized_t<T, 1> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch<uint8_t, A> const& mask, requires_arch<avx2> req) noexcept
         {
-            batch<uint32_t, A> broadcaster = { 0, 1, 0, 1, 0, 1, 0, 1 };
-            constexpr uint64_t comb = 0x0000000100000001ul * 2;
-            return bitwise_cast<double>(swizzle(bitwise_cast<float>(self), bitwise_cast<uint32_t>(mask * comb) + broadcaster, avx2 {}));
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint8_t>(self), mask, req));
         }
 
         template <class A>
-        XSIMD_INLINE batch<uint64_t, A> swizzle(batch<uint64_t, A> const& self, batch<uint64_t, A> mask, requires_arch<avx2>) noexcept
+        XSIMD_INLINE batch<uint16_t, A> swizzle(
+            batch<uint16_t, A> const& self, batch<uint16_t, A> mask, requires_arch<avx2> req) noexcept
         {
-            return bitwise_cast<uint64_t>(swizzle(bitwise_cast<double>(self), mask, avx2 {}));
+            // No blend/shuffle for 16 bits, we need to use the 8 bits version
+            const auto self_bytes = bitwise_cast<uint8_t>(self);
+            // If a mask entry is k, we want 2k in low byte and 2k+1 in high byte
+            const auto mask_2k_2kp1 = bitwise_cast<uint8_t>((mask << 1) | (mask << 9) | 0x100);
+            return bitwise_cast<uint16_t>(swizzle(self_bytes, mask_2k_2kp1, req));
         }
-        template <class A>
-        XSIMD_INLINE batch<int64_t, A> swizzle(batch<int64_t, A> const& self, batch<uint64_t, A> mask, requires_arch<avx2>) noexcept
+
+        template <class A, typename T, detail::enable_sized_t<T, 2> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch<uint16_t, A> const& mask, requires_arch<avx2> req) noexcept
         {
-            return bitwise_cast<int64_t>(swizzle(bitwise_cast<double>(self), mask, avx2 {}));
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, req));
         }
-        template <class A>
-        XSIMD_INLINE batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch<uint32_t, A> mask, requires_arch<avx2>) noexcept
+
+        // swizzle (constant mask)
+        template <class A, typename T, uint8_t... Vals, detail::enable_sized_t<T, 1> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
         {
-            return swizzle(self, mask, avx {});
+            static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
+            return swizzle(self, mask.as_batch(), req);
         }
-        template <class A>
-        XSIMD_INLINE batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch<uint32_t, A> mask, requires_arch<avx2>) noexcept
+
+        template <class A, typename T, uint16_t... Vals, detail::enable_sized_t<T, 2> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
         {
-            return bitwise_cast<int32_t>(swizzle(bitwise_cast<uint32_t>(self), mask, avx2 {}));
+            static_assert(sizeof...(Vals) == 16, "Must contain as many uint16_t as can fit in avx register");
+            return swizzle(self, mask.as_batch(), req);
         }
 
-        // swizzle (constant mask)
         template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
         XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
         {
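
Why the extra lane work in the uint8_t kernel: _mm256_shuffle_epi8 can only pick bytes within each 128-bit lane, so the kernel shuffles both the original register and a lane-swapped copy, then blends on bit 4 of each index. A scalar model of what the kernel computes, illustrative only and not part of the patch:

    #include <array>
    #include <cstdint>

    // Scalar model of the new uint8_t dynamic swizzle: out[i] = self[mask[i] % 32].
    // It mirrors the kernel's structure: a per-lane 16-byte lookup (what
    // _mm256_shuffle_epi8 can do) plus a blend that picks the right 128-bit lane.
    std::array<uint8_t, 32> swizzle_u8_model(std::array<uint8_t, 32> const& self,
                                             std::array<uint8_t, 32> const& mask)
    {
        std::array<uint8_t, 32> out {};
        for (int i = 0; i < 32; ++i)
        {
            int out_lane = i / 16;    // which 128-bit lane the output byte lives in
            int idx = mask[i] & 0x0f; // "half_mask": index modulo 16
            uint8_t from_same_lane = self[out_lane * 16 + idx];        // r0
            uint8_t from_other_lane = self[(1 - out_lane) * 16 + idx]; // r1 (lane-swapped input)
            // Bit 4 of the index says which lane the requested byte is in; comparing
            // it with the {0,...,0,16,...,16} constant reproduces the blend mask.
            bool take_other = (mask[i] & 0x10) != (out_lane == 0 ? 0 : 16);
            out[i] = take_other ? from_other_lane : from_same_lane;
        }
        return out;
    }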

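The uint16_t kernel reduces to the byte kernel by expanding each 16-bit index k into the byte pair (2k, 2k+1): mask << 1 leaves 2k in the low byte, mask << 9 places 2k in the high byte, and | 0x100 turns that high byte into 2k+1. A small check of that arithmetic for a single element, illustrative only; it relies on the little-endian byte order that bitwise_cast<uint8_t> exposes on x86:

    #include <cassert>
    #include <cstdint>

    int main()
    {
        for (uint16_t k = 0; k < 16; ++k)
        {
            // The per-element computation performed on the uint16_t mask.
            uint16_t m = static_cast<uint16_t>((k << 1) | (k << 9) | 0x100);

            // Reinterpreting the 16-bit element as two bytes (little-endian):
            uint8_t low = static_cast<uint8_t>(m & 0xff); // byte index for the low half
            uint8_t high = static_cast<uint8_t>(m >> 8);  // byte index for the high half

            assert(low == 2 * k);      // picks the low byte of source element k
            assert(high == 2 * k + 1); // picks the high byte of source element k
        }
        return 0;
    }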