Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5d6322a
created a structured like fft
CrabExtra Oct 2, 2025
264650c
config question
CrabExtra Oct 2, 2025
268949e
subgroupsort
CrabExtra Oct 5, 2025
73aa820
added bitonic_sort name space
CrabExtra Oct 5, 2025
dcb5e6e
Merge branch 'Devsh-Graphics-Programming:master' into master
CrabExtra Oct 7, 2025
b84a4bd
subgroup changes
CrabExtra Oct 7, 2025
779815e
removed unused
CrabExtra Oct 7, 2025
78a307a
Merge branch 'master' of https://github.com/CrabExtra/Nabla
CrabExtra Oct 7, 2025
ad7a4c5
added last merge step as a function
CrabExtra Oct 19, 2025
b80283a
uncomplete workgroup fn
CrabExtra Oct 19, 2025
7c91744
complete the logic for some pr questions
CrabExtra Oct 19, 2025
4d253f3
Refactor bitonic sort for workgroup + Accessor support
CrabExtra Oct 22, 2025
f03b8b2
Update bitonic_sort.hlsl
CrabExtra Oct 22, 2025
555dcbe
VT implumentation
CrabExtra Oct 27, 2025
0eb301e
Update bitonic_sort.hlsl
CrabExtra Oct 29, 2025
52a7cb8
Update common.hlsl
CrabExtra Oct 29, 2025
0b60d0c
Update bitonic_sort.hlsl
CrabExtra Oct 29, 2025
cbfa188
Update common.hlsl
CrabExtra Nov 2, 2025
0de0167
Update bitonic_sort.hlsl
CrabExtra Nov 2, 2025
1b1ba15
Update CMakeLists.txt
CrabExtra Nov 2, 2025
a20ba6e
pair added
CrabExtra Nov 2, 2025
8c7b7e5
comment outdated pair impl
CrabExtra Nov 2, 2025
06af50b
bitonic sort acessor added
CrabExtra Nov 2, 2025
ecb7182
Update common.hlsl
CrabExtra Nov 5, 2025
08867d6
Update bitonic_sort.hlsl
CrabExtra Nov 5, 2025
e2937ce
Update bitonic_sort.hlsl
CrabExtra Nov 5, 2025
686618c
Update utility.hlsl
CrabExtra Nov 5, 2025
c8e990d
Delete include/nbl/builtin/hlsl/utility.hlsl
CrabExtra Nov 5, 2025
55b7813
Delete include/nbl/builtin/hlsl/pair.hlsl
CrabExtra Nov 5, 2025
d0cd7a3
Add files via upload
CrabExtra Nov 5, 2025
e8f6134
Update CMakeLists.txt
CrabExtra Nov 5, 2025
034cd33
Remove unused pair struct from memory_accessor.hlsl
CrabExtra Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions include/nbl/builtin/hlsl/bitonic_sort/common.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/math/intutil.hlsl>
#include <nbl/builtin/hlsl/pair.hlsl>

namespace nbl
{
namespace hlsl
{
namespace bitonic_sort
{

// Compare-exchange of two key/value slots ("lo" and "hi") against the elements held by a
// partner invocation, as used by the cross-invocation passes of a bitonic merge.
//
// Each slot is handled independently: the slot either keeps its own key/value or is
// overwritten with the partner's key/value.
//
// Parameters:
// * takeLarger              - true: keep whichever element orders later under `comp`;
//                             false: keep whichever element orders earlier.
// * loKey / loVal           - own "lo" element, updated in place.
// * partnerLoKey / -Val     - partner's "lo" element, read only.
// * hiKey / hiVal           - own "hi" element, updated in place.
// * partnerHiKey / -Val     - partner's "hi" element, read only.
// * comp                    - strict ordering predicate, `comp(a, b)` is true when `a` orders before `b`.
//
// The partner is expected to run the same exchange with the roles swapped and the opposite
// `takeLarger`. For that to leave one copy of every element, the two sides must reach
// complementary decisions. Both directions therefore use a strict comparison: with
// equivalent keys (`comp` false both ways) each side keeps its own element. Deriving the
// "take smaller" decision as `!comp(self, partner)` would make that side grab the partner's
// element while the partner keeps it too, duplicating one value and dropping the other.
template<typename KeyType, typename ValueType, typename Comparator>
void compareExchangeWithPartner(
bool takeLarger,
NBL_REF_ARG(KeyType) loKey,
NBL_CONST_REF_ARG(KeyType) partnerLoKey,
NBL_REF_ARG(KeyType) hiKey,
NBL_CONST_REF_ARG(KeyType) partnerHiKey,
NBL_REF_ARG(ValueType) loVal,
NBL_CONST_REF_ARG(ValueType) partnerLoVal,
NBL_REF_ARG(ValueType) hiVal,
NBL_CONST_REF_ARG(ValueType) partnerHiVal,
NBL_CONST_REF_ARG(Comparator) comp)
{
    // Take the partner's element only when it is strictly better for the requested direction.
    const bool takePartnerLo = takeLarger ? comp(loKey, partnerLoKey) : comp(partnerLoKey, loKey);
    loKey = takePartnerLo ? partnerLoKey : loKey;
    loVal = takePartnerLo ? partnerLoVal : loVal;

    const bool takePartnerHi = takeLarger ? comp(hiKey, partnerHiKey) : comp(partnerHiKey, hiKey);
    hiKey = takePartnerHi ? partnerHiKey : hiKey;
    hiVal = takePartnerHi ? partnerHiVal : hiVal;
}


// Conditionally swaps two key/value elements held by the same invocation.
//
// `comp(hiKey, loKey)` is evaluated once; the elements are exchanged when that result equals
// `ascending`. So with `ascending == true` a swap happens when the "hi" key orders before
// the "lo" key, and with `ascending == false` a swap happens whenever it does not (which
// includes keys that `comp` considers equivalent).
//
// Parameters:
// * ascending       - requested direction, see above.
// * loKey / loVal   - first element, updated in place.
// * hiKey / hiVal   - second element, updated in place.
// * comp            - ordering predicate invoked as `comp(hiKey, loKey)`.
template<typename KeyType, typename ValueType, typename Comparator>
void compareSwap(
bool ascending,
NBL_REF_ARG(KeyType) loKey,
NBL_REF_ARG(KeyType) hiKey,
NBL_REF_ARG(ValueType) loVal,
NBL_REF_ARG(ValueType) hiVal,
NBL_CONST_REF_ARG(Comparator) comp)
{
    const bool hiOrdersFirst = comp(hiKey, loKey);

    if (hiOrdersFirst == ascending)
    {
        // Keys and values travel together so each value stays attached to its key.
        const KeyType previousLoKey = loKey;
        const ValueType previousLoVal = loVal;

        loKey = hiKey;
        loVal = hiVal;

        hiKey = previousLoKey;
        hiVal = previousLoVal;
    }
}

// Unconditionally exchanges two key/value elements: after the call the "lo" slot holds what
// the "hi" slot held before and vice versa, for both the key and the value.
template<typename KeyType, typename ValueType>
void swap(
NBL_REF_ARG(KeyType) loKey,
NBL_REF_ARG(KeyType) hiKey,
NBL_REF_ARG(ValueType) loVal,
NBL_REF_ARG(ValueType) hiVal)
{
    // Values first, then keys; the two exchanges are independent of each other.
    const ValueType previousHiVal = hiVal;
    hiVal = loVal;
    loVal = previousHiVal;

    const KeyType previousHiKey = hiKey;
    hiKey = loKey;
    loKey = previousHiKey;
}



// Pair-based overload of the partner compare-exchange: same contract as the scalar
// key/value version, with each element passed as a pair whose `first` is the key and
// `second` is the value.
//
// Parameters:
// * takeLarger       - true: keep whichever element orders later under `comp`;
//                      false: keep whichever element orders earlier.
// * loPair / hiPair  - own elements, updated in place.
// * partnerLoPair /
//   partnerHiPair    - partner's elements, read only.
// * comp             - strict ordering predicate applied to the keys (`first`).
//
// The partner is expected to run the same exchange with the roles swapped and the opposite
// `takeLarger`. Both directions use a strict comparison so the two sides always reach
// complementary decisions: with equivalent keys each side keeps its own pair. Deriving the
// "take smaller" decision as `!comp(self, partner)` would make that side copy the partner's
// pair while the partner keeps it too, duplicating one value and dropping the other.
template<typename KeyType, typename ValueType, typename Comparator>
void compareExchangeWithPartner(
bool takeLarger,
NBL_REF_ARG(pair<KeyType, ValueType>) loPair,
NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerLoPair,
NBL_REF_ARG(pair<KeyType, ValueType>) hiPair,
NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerHiPair,
NBL_CONST_REF_ARG(Comparator) comp)
{
    // Take the partner's pair only when it is strictly better for the requested direction.
    const bool takePartnerLo = takeLarger ? comp(loPair.first, partnerLoPair.first) : comp(partnerLoPair.first, loPair.first);
    // Plain struct assignment copies key and value together.
    if (takePartnerLo)
        loPair = partnerLoPair;

    const bool takePartnerHi = takeLarger ? comp(hiPair.first, partnerHiPair.first) : comp(partnerHiPair.first, hiPair.first);
    if (takePartnerHi)
        hiPair = partnerHiPair;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define an operator= for pair, then this becomes

if(takePartnerLo)
    loPair = pLoPair;
if(takePartnerHi)
    hiPair = pHiPair;

don't worry about branching, since everything's an assignment it'll just be OpSelects under the hood

Copy link
Member

@devshgraphicsprogramming devshgraphicsprogramming Nov 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't define an operator= in HLSL: none of the operators that should return references can be defined in HLSL, which rules out assignment and array indexing (as well as compound assignment).

However all structs in HLSL are trivial so can be assigned with =


// Pair-based conditional swap of two elements held by the same invocation.
//
// `comp(hiPair.first, loPair.first)` is evaluated once; the pairs are exchanged when that
// result equals `ascending`. With `ascending == false` this includes keys that `comp`
// considers equivalent.
//
// Parameters:
// * ascending        - requested direction, see above.
// * loPair / hiPair  - the two elements (`first` = key, `second` = value), updated in place.
// * comp             - ordering predicate applied to the keys.
template<typename KeyType, typename ValueType, typename Comparator>
void compareSwap(
bool ascending,
NBL_REF_ARG(pair<KeyType, ValueType>) loPair,
NBL_REF_ARG(pair<KeyType, ValueType>) hiPair,
NBL_CONST_REF_ARG(Comparator) comp)
{
    const bool hiOrdersFirst = comp(hiPair.first, loPair.first);

    if (hiOrdersFirst == ascending)
    {
        // Whole-pair copies move key and value together.
        const pair<KeyType, ValueType> previousLo = loPair;
        loPair = hiPair;
        hiPair = previousLo;
    }
}

// Unconditionally exchanges two pairs: afterwards `loPair` holds the former contents of
// `hiPair` and `hiPair` holds the former contents of `loPair`.
template<typename KeyType, typename ValueType>
void swap(
NBL_REF_ARG(pair<KeyType, ValueType>) loPair,
NBL_REF_ARG(pair<KeyType, ValueType>) hiPair)
{
    const pair<KeyType, ValueType> previousHi = hiPair;
    hiPair = loPair;
    loPair = previousHi;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work without a definition of operator= for pair. Your code only compiles because nothing calls it right now. We want to keep this version and rewrite all the other swaps to go through it using pairs.

The definition for pair, the overload for operator= and this swap method all belong in https://github.com/Devsh-Graphics-Programming/Nabla/blob/master/include/nbl/builtin/hlsl/utility.hlsl, mimicking how std::pair is in the <utility> header in cpp. Move them over there.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

operator= can only be defined in a #ifndef __HLSL_VERSION macro block, for DXC reasons

}
}
}

#endif
31 changes: 31 additions & 0 deletions include/nbl/builtin/hlsl/concepts/accessors/bitonic_sort.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_BITONIC_SORT_INCLUDED_
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_BITONIC_SORT_INCLUDED_

#include "nbl/builtin/hlsl/concepts/accessors/generic_shared_data.hlsl"

namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace bitonic_sort
{
// Concept for the shared-memory accessor type `T` used by the workgroup bitonic sort.
// It is a direct alias of `concepts::accessors::GenericSharedMemoryAccessor<T, V, I>`;
// `V` and `I` both default to `uint32_t`.
//
// The SharedMemoryAccessor MUST provide the following methods:
// * void get(uint32_t index, NBL_REF_ARG(uint32_t) value);
// * void set(uint32_t index, in uint32_t value);
// * void workgroupExecutionAndMemoryBarrier();
//
// NOTE(review): the signatures above are spelled for the default `V`/`I`; presumably other
// choices substitute those types for `uint32_t` - confirm against generic_shared_data.hlsl.
template<typename T, typename V = uint32_t, typename I = uint32_t>
NBL_BOOL_CONCEPT BitonicSortSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T, V, I>;

// Concept for the data accessor type `T` that the sort reads its elements from and writes
// them back to. It is a direct alias of
// `concepts::accessors::GenericDataAccessor<T, pair<KeyType, ValueType>, I>`, i.e. the
// element type is a key/value pair; `I` defaults to `uint32_t`.
//
// The Accessor MUST provide the following methods:
// * void get(uint32_t index, NBL_REF_ARG(pair<KeyType, ValueType>) value);
// * void set(uint32_t index, in pair<KeyType, ValueType> value);
//
// NOTE(review): `pair` is not declared by the single header included in this file; confirm
// it is reachable transitively or include the header that defines it.
template<typename T, typename KeyType, typename ValueType, typename I = uint32_t>
NBL_BOOL_CONCEPT BitonicSortAccessor = concepts::accessors::GenericDataAccessor<T, pair<KeyType, ValueType>, I>;
Comment on lines +14 to +25

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bitonic_sort::BitonicSort... is a tautology, drop the BitonicSort prefix from the names


}
}
}
}
#endif
20 changes: 10 additions & 10 deletions include/nbl/builtin/hlsl/memory_accessor.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ namespace hlsl
{

// TODO: flesh out and move to `nbl/builtin/hlsl/utility.hlsl`
template<typename T1, typename T2>
struct pair
{
using first_type = T1;
using second_type = T2;

first_type first;
second_type second;
};
//template<typename T1, typename T2>
//struct pair
//{
// using first_type = T1;
// using second_type = T2;
//
// first_type first;
// second_type second;
//};

namespace accessor_adaptors
{
Expand Down Expand Up @@ -227,4 +227,4 @@ struct Offset : impl::OffsetBase<IndexType,_Offset>
}
}
}
#endif
#endif
38 changes: 38 additions & 0 deletions include/nbl/builtin/hlsl/pair.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_PAIR_INCLUDED_
#define _NBL_BUILTIN_HLSL_PAIR_INCLUDED_

#include "nbl/builtin/hlsl/type_traits.hlsl"

namespace nbl
{
namespace hlsl
{

// Plain aggregate holding two values of possibly different types.
// The member and alias names (`first`, `second`, `first_type`, `second_type`) follow
// std::pair, as the make_pair helper in this file notes.
template<typename T1, typename T2>
struct pair
{
// Aliases exposing the two template arguments.
using first_type = T1;
using second_type = T2;

// The stored values; both are public data members.
first_type first;
second_type second;
};


// Convenience factory in the spirit of std::make_pair: returns a pair<T1, T2> whose
// `first` is `f` and whose `second` is `s`, with both types deduced from the arguments.
template<typename T1, typename T2>
pair<T1, T2> make_pair(T1 f, T2 s)
{
    pair<T1, T2> result;
    result.second = s;
    result.first = f;
    return result;
}

}
}

#endif
81 changes: 81 additions & 0 deletions include/nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
namespace nbl
{
namespace hlsl
{
namespace subgroup
{

// Compile-time configuration for the subgroup bitonic sort.
// Carries only types: the key type, the value type and the comparator type, the latter
// defaulting to `less<KeyType>`. It has no data members.
template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
Comment on lines +14 to +16

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you also want to handle "multiple items per thread" in the subgroup sort

// Re-exported template arguments, read back by the bitonic_sort specialization.
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;
};
Comment on lines +14 to +20

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where's the SubgroupSizeLog2 ?


// Primary template, declared but not defined here: only the partial specialization for
// `bitonic_sort_config<...>` provides an implementation. `device_capabilities` defaults
// to `void`.
template<typename Config, class device_capabilities = void>
struct bitonic_sort;

// Subgroup bitonic sort, specialized for `bitonic_sort_config<KeyType, ValueType, Comparator>`.
// Every invocation passes in two key/value elements ("lo" and "hi"); they are exchanged
// in place, either locally or with another invocation through XOR subgroup shuffles.
// `device_capabilities` is not referenced anywhere in this specialization.
template<typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
{
// Types pulled back out of the config.
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
using key_t = typename config_t::key_t;
using value_t = typename config_t::value_t;
using comparator_t = typename config_t::comparator_t;

// Runs one merge stage as `stage + 1` passes whose element stride halves every pass,
// from `1 << stage` down to 1.
// * stage             - index of the stage to run.
// * bitonicAscending  - direction this invocation merges in for this stage.
// * invocationID      - caller-supplied invocation index; its bits pick the upper/lower role.
// * loKey/hiKey/loVal/hiVal - this invocation's two elements, updated in place.
static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
// Comparator instance declared without an initializer; it is only used through its call operator.
comparator_t comp;

[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
Comment on lines +39 to +41

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not write the loop as for (uint32_t stride=1u<<stage; stride; stride=stride>>1) and forget the pass ?

// Half the element stride: the XOR mask used to pick the partner invocation.
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Element stride of 1 (the last pass of every stage): compare and swap this invocation's own lo/hi elements
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loKey, hiKey, loVal, hiVal, comp);
}
else
{
// Fetch the partner's four values; the partner is the invocation whose ID differs by XOR with threadStride
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);

// The invocation with the threadStride bit set is the "upper" one of the pair; it takes
// the larger elements when merging ascending and the smaller ones otherwise.
const bool isUpper = bool(invocationID & threadStride);
const bool takeLarger = isUpper == bitonicAscending;

nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loKey, pLoKey, hiKey, pHiKey, loVal, pLoVal, hiVal, pHiVal, comp);
}
}
}

// Entry point: runs stages 0 through subgroupSizeLog2 (inclusive) on this invocation's
// two elements.
// * ascending - direction used for the final stage.
// * loKey/hiKey/loVal/hiVal - this invocation's two elements, updated in place.
// NOTE(review): presumably this leaves the elements held across the subgroup sorted in
// the requested direction - confirm against the intended element-to-invocation layout.
static void __call(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
// NOTE(review): this bound is obtained through a function call while the loop below is
// marked [unroll] - confirm the compiler can resolve it at compile time.
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs to be compile time constant from config like the FFT

[unroll]
for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually I'd ask for an even cooler feature, would be nice if stage<=sortSizeLog2 and __call had a last argument with a default const uint32_t sortSizeLog2=Config::SubgroupSizeLog2

Why? Because it would be super useful for things like sorting arrays which are smaller than what a subgroup can process (so we can pack multiple arrays to be sorted independently into a single subgroup)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the const bool bitonicAscending = (stage == subgroupSizeLog2) ? would have to change to const bool bitonicAscending = (stage == sortSizeLog2 ) ?

{
// Final stage: use the caller's direction. Earlier stages: ascending when bit `stage` of
// the invocation ID is clear, descending when it is set.
const bool bitonicAscending = (stage == subgroupSizeLog2) ? ascending : !bool(invocationID & (1u << stage));
mergeStage(stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);
}
}
};

}
}
}
#endif
Loading