-
Notifications
You must be signed in to change notification settings - Fork 67
Bitonic_Sort #943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Bitonic_Sort #943
Changes from 23 commits
5d6322a
264650c
268949e
73aa820
dcb5e6e
b84a4bd
779815e
78a307a
ad7a4c5
b80283a
7c91744
4d253f3
f03b8b2
555dcbe
0eb301e
52a7cb8
0b60d0c
cbfa188
0de0167
1b1ba15
a20ba6e
8c7b7e5
06af50b
ecb7182
08867d6
e2937ce
686618c
c8e990d
55b7813
d0cd7a3
e8f6134
034cd33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| #ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_ | ||
| #define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_ | ||
|
|
||
| #include <nbl/builtin/hlsl/cpp_compat.hlsl> | ||
| #include <nbl/builtin/hlsl/concepts.hlsl> | ||
| #include <nbl/builtin/hlsl/math/intutil.hlsl> | ||
| #include <nbl/builtin/hlsl/pair.hlsl> | ||
|
|
||
| namespace nbl | ||
| { | ||
| namespace hlsl | ||
| { | ||
| namespace bitonic_sort | ||
| { | ||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareExchangeWithPartner( | ||
| bool takeLarger, | ||
| NBL_REF_ARG(KeyType) loKey, | ||
| NBL_CONST_REF_ARG(KeyType) partnerLoKey, | ||
| NBL_REF_ARG(KeyType) hiKey, | ||
| NBL_CONST_REF_ARG(KeyType) partnerHiKey, | ||
| NBL_REF_ARG(ValueType) loVal, | ||
| NBL_CONST_REF_ARG(ValueType) partnerLoVal, | ||
| NBL_REF_ARG(ValueType) hiVal, | ||
| NBL_CONST_REF_ARG(ValueType) partnerHiVal, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool loSelfSmaller = comp(loKey, partnerLoKey); | ||
| const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller; | ||
| loKey = takePartnerLo ? partnerLoKey : loKey; | ||
| loVal = takePartnerLo ? partnerLoVal : loVal; | ||
|
|
||
| const bool hiSelfSmaller = comp(hiKey, partnerHiKey); | ||
| const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller; | ||
| hiKey = takePartnerHi ? partnerHiKey : hiKey; | ||
| hiVal = takePartnerHi ? partnerHiVal : hiVal; | ||
| } | ||
|
|
||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareSwap( | ||
| bool ascending, | ||
| NBL_REF_ARG(KeyType) loKey, | ||
| NBL_REF_ARG(KeyType) hiKey, | ||
| NBL_REF_ARG(ValueType) loVal, | ||
| NBL_REF_ARG(ValueType) hiVal, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool shouldSwap = comp(hiKey, loKey); | ||
|
|
||
| const bool doSwap = (shouldSwap == ascending); | ||
|
|
||
| KeyType tempKey = loKey; | ||
| loKey = doSwap ? hiKey : loKey; | ||
| hiKey = doSwap ? tempKey : hiKey; | ||
|
|
||
| ValueType tempVal = loVal; | ||
| loVal = doSwap ? hiVal : loVal; | ||
| hiVal = doSwap ? tempVal : hiVal; | ||
| } | ||
|
|
||
| template<typename KeyType, typename ValueType> | ||
| void swap( | ||
| NBL_REF_ARG(KeyType) loKey, | ||
| NBL_REF_ARG(KeyType) hiKey, | ||
| NBL_REF_ARG(ValueType) loVal, | ||
| NBL_REF_ARG(ValueType) hiVal) | ||
| { | ||
| KeyType tempKey = loKey; | ||
| loKey = hiKey; | ||
| hiKey = tempKey; | ||
|
|
||
| ValueType tempVal = loVal; | ||
| loVal = hiVal; | ||
| hiVal = tempVal; | ||
| } | ||
|
|
||
|
|
||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareExchangeWithPartner( | ||
| bool takeLarger, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerLoPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerHiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool loSelfSmaller = comp(loPair.first, partnerLoPair.first); | ||
| const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller; | ||
| loPair.first = takePartnerLo ? partnerLoPair.first : loPair.first; | ||
| loPair.second = takePartnerLo ? partnerLoPair.second : loPair.second; | ||
|
|
||
| const bool hiSelfSmaller = comp(hiPair.first, partnerHiPair.first); | ||
| const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller; | ||
| hiPair.first = takePartnerHi ? partnerHiPair.first : hiPair.first; | ||
| hiPair.second = takePartnerHi ? partnerHiPair.second : hiPair.second; | ||
| } | ||
|
||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareSwap( | ||
| bool ascending, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool shouldSwap = comp(hiPair.first, loPair.first); | ||
| const bool doSwap = (shouldSwap == ascending); | ||
|
|
||
| KeyType tempKey = loPair.first; | ||
| ValueType tempVal = loPair.second; | ||
| loPair.first = doSwap ? hiPair.first : loPair.first; | ||
| loPair.second = doSwap ? hiPair.second : loPair.second; | ||
| hiPair.first = doSwap ? tempKey : hiPair.first; | ||
| hiPair.second = doSwap ? tempVal : hiPair.second; | ||
| } | ||
devshgraphicsprogramming marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| template<typename KeyType, typename ValueType> | ||
| void swap( | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair) | ||
| { | ||
| pair<KeyType, ValueType> temp = loPair; | ||
| loPair = hiPair; | ||
| hiPair = temp; | ||
| } | ||
|
||
| } | ||
| } | ||
| } | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| #ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_BITONIC_SORT_INCLUDED_ | ||
| #define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_BITONIC_SORT_INCLUDED_ | ||
|
|
||
| #include "nbl/builtin/hlsl/concepts/accessors/generic_shared_data.hlsl" | ||
|
|
||
| namespace nbl | ||
| { | ||
| namespace hlsl | ||
| { | ||
| namespace workgroup | ||
| { | ||
| namespace bitonic_sort | ||
| { | ||
| // The SharedMemoryAccessor MUST provide the following methods: | ||
| // * void get(uint32_t index, NBL_REF_ARG(uint32_t) value); | ||
| // * void set(uint32_t index, in uint32_t value); | ||
| // * void workgroupExecutionAndMemoryBarrier(); | ||
| template<typename T, typename V = uint32_t, typename I = uint32_t> | ||
| NBL_BOOL_CONCEPT BitonicSortSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T, V, I>; | ||
|
|
||
| // The Accessor MUST provide the following methods: | ||
| // * void get(uint32_t index, NBL_REF_ARG(pair<KeyType, ValueType>) value); | ||
| // * void set(uint32_t index, in pair<KeyType, ValueType> value); | ||
| template<typename T, typename KeyType, typename ValueType, typename I = uint32_t> | ||
| NBL_BOOL_CONCEPT BitonicSortAccessor = concepts::accessors::GenericDataAccessor<T, pair<KeyType, ValueType>, I>; | ||
|
Comment on lines
+14
to
+25
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| } | ||
| } | ||
| } | ||
| } | ||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| // Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. | ||
| // This file is part of the "Nabla Engine". | ||
| // For conditions of distribution and use, see copyright notice in nabla.h | ||
| #ifndef _NBL_BUILTIN_HLSL_PAIR_INCLUDED_ | ||
| #define _NBL_BUILTIN_HLSL_PAIR_INCLUDED_ | ||
|
|
||
| #include "nbl/builtin/hlsl/type_traits.hlsl" | ||
|
|
||
| namespace nbl | ||
| { | ||
| namespace hlsl | ||
| { | ||
|
|
||
| template<typename T1, typename T2> | ||
| struct pair | ||
| { | ||
| using first_type = T1; | ||
| using second_type = T2; | ||
|
|
||
| first_type first; | ||
| second_type second; | ||
| }; | ||
|
|
||
|
|
||
| // Helper to make a pair (similar to std::make_pair) | ||
| template<typename T1, typename T2> | ||
| pair<T1, T2> make_pair(T1 f, T2 s) | ||
| { | ||
| pair<T1, T2> p; | ||
| p.first = f; | ||
| p.second = s; | ||
| return p; | ||
| } | ||
devshgraphicsprogramming marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| } | ||
| } | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| #ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED | ||
| #define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED | ||
| #include "nbl/builtin/hlsl/bitonic_sort/common.hlsl" | ||
| #include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" | ||
| #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" | ||
| #include "nbl/builtin/hlsl/functional.hlsl" | ||
| namespace nbl | ||
| { | ||
| namespace hlsl | ||
| { | ||
| namespace subgroup | ||
| { | ||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> > | ||
| struct bitonic_sort_config | ||
| { | ||
|
Comment on lines
+14
to
+16
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you also want to handle "multiple items per thread" in the subgroup sort |
||
| using key_t = KeyType; | ||
| using value_t = ValueType; | ||
| using comparator_t = Comparator; | ||
| }; | ||
|
Comment on lines
+14
to
+20
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where's the SubgroupSizeLog2 ? |
||
|
|
||
| template<typename Config, class device_capabilities = void> | ||
| struct bitonic_sort; | ||
|
|
||
| template<typename KeyType, typename ValueType, typename Comparator, class device_capabilities> | ||
| struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities> | ||
| { | ||
| using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>; | ||
| using key_t = typename config_t::key_t; | ||
| using value_t = typename config_t::value_t; | ||
| using comparator_t = typename config_t::comparator_t; | ||
|
|
||
| static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey, | ||
| NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal) | ||
devshgraphicsprogramming marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
| comparator_t comp; | ||
|
|
||
| [unroll] | ||
| for (uint32_t pass = 0; pass <= stage; pass++) | ||
| { | ||
| const uint32_t stride = 1u << (stage - pass); // Element stride | ||
|
Comment on lines
+39
to
+41
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not write the loop as |
||
| const uint32_t threadStride = stride >> 1; | ||
| if (threadStride == 0) | ||
| { | ||
| // Local compare and swap for stage 0 | ||
| nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loKey, hiKey, loVal, hiVal, comp); | ||
| } | ||
| else | ||
| { | ||
| // Shuffle from partner using XOR | ||
| const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride); | ||
| const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride); | ||
| const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride); | ||
| const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride); | ||
devshgraphicsprogramming marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| const bool isUpper = bool(invocationID & threadStride); | ||
| const bool takeLarger = isUpper == bitonicAscending; | ||
|
|
||
| nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loKey, pLoKey, hiKey, pHiKey, loVal, pLoVal, hiVal, pHiVal, comp); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| static void __call(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey, | ||
devshgraphicsprogramming marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal) | ||
| { | ||
| const uint32_t invocationID = glsl::gl_SubgroupInvocationID(); | ||
| const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needs to be compile time constant from config like the FFT |
||
| [unroll] | ||
| for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually I'd ask for an even cooler feature, would be nice if Why? Because it would be super useful for things like sorting arrays which are smaller than what a subgroup can process (so we can pack multiple arrays to be sorted independently into a single subgroup)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the |
||
| { | ||
| const bool bitonicAscending = (stage == subgroupSizeLog2) ? ascending : !bool(invocationID & (1u << stage)); | ||
| mergeStage(stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } | ||
| } | ||
| } | ||
| #endif | ||
Uh oh!
There was an error while loading. Please reload this page.