Skip to content

Commit f32409f

Browse files
authored
[VULKAN] Add support for specialization constants (#499)
This commit introduces support for specialization constants in the Vulkan backend. Key changes: - Added struct to to represent a specialization constant with its ID, type, and value. - Updated YAML mapping in to parse specialization constants from the test configuration. - Modified to create and use when creating the compute pipeline, allowing specialization constants to be passed to the shader. - Added a new test case in to verify the functionality of specialization constants with various data types (bool, int, uint, float). Fixes llvm/llvm-project#142992
1 parent a3148b8 commit f32409f

File tree

14 files changed

+652
-8
lines changed

14 files changed

+652
-8
lines changed

include/Support/Pipeline.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,18 @@ struct IOBindings {
282282
}
283283
};
284284

285+
struct SpecializationConstant {
286+
uint32_t ConstantID;
287+
DataFormat Type;
288+
std::string Value;
289+
};
290+
285291
struct Shader {
286292
Stages Stage;
287293
std::string Entry;
288294
std::unique_ptr<llvm::MemoryBuffer> Shader;
289295
int DispatchSize[3];
296+
llvm::SmallVector<SpecializationConstant> SpecializationConstants;
290297
};
291298

292299
struct Pipeline {
@@ -335,6 +342,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::Shader)
335342
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::dx::RootParameter)
336343
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::Result)
337344
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::VertexAttribute)
345+
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SpecializationConstant)
338346

339347
namespace llvm {
340348
namespace yaml {
@@ -399,6 +407,10 @@ template <> struct MappingTraits<offloadtest::RuntimeSettings> {
399407
static void mapping(IO &I, offloadtest::RuntimeSettings &S);
400408
};
401409

410+
template <> struct MappingTraits<offloadtest::SpecializationConstant> {
411+
static void mapping(IO &I, offloadtest::SpecializationConstant &C);
412+
};
413+
402414
template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
403415
static void enumeration(IO &I, offloadtest::Rule &V) {
404416
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Rule::Val)

lib/API/VK/Device.cpp

Lines changed: 148 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "API/Device.h"
1313
#include "Support/Pipeline.h"
14+
#include "llvm/ADT/DenseSet.h"
1415
#include "llvm/Support/Error.h"
1516

1617
#include <memory>
@@ -20,22 +21,34 @@
2021

2122
using namespace offloadtest;
2223

23-
#define VKFormats(FMT) \
24+
#define VKFormats(FMT, BITS) \
2425
if (Channels == 1) \
25-
return VK_FORMAT_R32_##FMT; \
26+
return VK_FORMAT_R##BITS##_##FMT; \
2627
if (Channels == 2) \
27-
return VK_FORMAT_R32G32_##FMT; \
28+
return VK_FORMAT_R##BITS##G##BITS##_##FMT; \
2829
if (Channels == 3) \
29-
return VK_FORMAT_R32G32B32_##FMT; \
30+
return VK_FORMAT_R##BITS##G##BITS##B##BITS##_##FMT; \
3031
if (Channels == 4) \
31-
return VK_FORMAT_R32G32B32A32_##FMT;
32+
return VK_FORMAT_R##BITS##G##BITS##B##BITS##A##BITS##_##FMT;
3233

3334
static VkFormat getVKFormat(DataFormat Format, int Channels) {
3435
switch (Format) {
36+
case DataFormat::Int16:
37+
VKFormats(SINT, 16) break;
38+
case DataFormat::UInt16:
39+
VKFormats(UINT, 16) break;
3540
case DataFormat::Int32:
36-
VKFormats(SINT) break;
41+
VKFormats(SINT, 32) break;
42+
case DataFormat::UInt32:
43+
VKFormats(UINT, 32) break;
3744
case DataFormat::Float32:
38-
VKFormats(SFLOAT) break;
45+
VKFormats(SFLOAT, 32) break;
46+
case DataFormat::Int64:
47+
VKFormats(SINT, 64) break;
48+
case DataFormat::UInt64:
49+
VKFormats(UINT, 64) break;
50+
case DataFormat::Float64:
51+
VKFormats(SFLOAT, 64) break;
3952
default:
4053
llvm_unreachable("Unsupported Resource format specified");
4154
}
@@ -1273,6 +1286,105 @@ class VKDevice : public offloadtest::Device {
12731286
return llvm::Error::success();
12741287
}
12751288

1289+
static llvm::Error
1290+
parseSpecializationConstant(const SpecializationConstant &SpecConst,
1291+
VkSpecializationMapEntry &Entry,
1292+
llvm::SmallVector<char> &SpecData) {
1293+
Entry.constantID = SpecConst.ConstantID;
1294+
Entry.offset = SpecData.size();
1295+
switch (SpecConst.Type) {
1296+
case DataFormat::Float32: {
1297+
float Value = 0.0f;
1298+
double Tmp = 0.0;
1299+
if (llvm::StringRef(SpecConst.Value).getAsDouble(Tmp))
1300+
return llvm::createStringError(
1301+
std::errc::invalid_argument,
1302+
"Invalid float value for specialization constant '%s'",
1303+
SpecConst.Value.c_str());
1304+
Value = static_cast<float>(Tmp);
1305+
Entry.size = sizeof(float);
1306+
SpecData.resize(SpecData.size() + sizeof(float));
1307+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(float));
1308+
break;
1309+
}
1310+
case DataFormat::Float64: {
1311+
double Value = 0.0;
1312+
if (llvm::StringRef(SpecConst.Value).getAsDouble(Value))
1313+
return llvm::createStringError(
1314+
std::errc::invalid_argument,
1315+
"Invalid double value for specialization constant '%s'",
1316+
SpecConst.Value.c_str());
1317+
Entry.size = sizeof(double);
1318+
SpecData.resize(SpecData.size() + sizeof(double));
1319+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(double));
1320+
break;
1321+
}
1322+
case DataFormat::Int16: {
1323+
int16_t Value = 0;
1324+
if (llvm::StringRef(SpecConst.Value).getAsInteger(0, Value))
1325+
return llvm::createStringError(
1326+
std::errc::invalid_argument,
1327+
"Invalid int16 value for specialization constant '%s'",
1328+
SpecConst.Value.c_str());
1329+
Entry.size = sizeof(int16_t);
1330+
SpecData.resize(SpecData.size() + sizeof(int16_t));
1331+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(int16_t));
1332+
break;
1333+
}
1334+
case DataFormat::UInt16: {
1335+
uint16_t Value = 0;
1336+
if (llvm::StringRef(SpecConst.Value).getAsInteger(0, Value))
1337+
return llvm::createStringError(
1338+
std::errc::invalid_argument,
1339+
"Invalid uint16 value for specialization constant '%s'",
1340+
SpecConst.Value.c_str());
1341+
Entry.size = sizeof(uint16_t);
1342+
SpecData.resize(SpecData.size() + sizeof(uint16_t));
1343+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(uint16_t));
1344+
break;
1345+
}
1346+
case DataFormat::Int32: {
1347+
int32_t Value = 0;
1348+
if (llvm::StringRef(SpecConst.Value).getAsInteger(0, Value))
1349+
return llvm::createStringError(
1350+
std::errc::invalid_argument,
1351+
"Invalid int32 value for specialization constant '%s'",
1352+
SpecConst.Value.c_str());
1353+
Entry.size = sizeof(int32_t);
1354+
SpecData.resize(SpecData.size() + sizeof(int32_t));
1355+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(int32_t));
1356+
break;
1357+
}
1358+
case DataFormat::UInt32: {
1359+
uint32_t Value = 0;
1360+
if (llvm::StringRef(SpecConst.Value).getAsInteger(0, Value))
1361+
return llvm::createStringError(
1362+
std::errc::invalid_argument,
1363+
"Invalid uint32 value for specialization constant '%s'",
1364+
SpecConst.Value.c_str());
1365+
Entry.size = sizeof(uint32_t);
1366+
SpecData.resize(SpecData.size() + sizeof(uint32_t));
1367+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(uint32_t));
1368+
break;
1369+
}
1370+
case DataFormat::Bool: {
1371+
bool Value = false;
1372+
if (llvm::StringRef(SpecConst.Value).getAsInteger(0, Value))
1373+
return llvm::createStringError(
1374+
std::errc::invalid_argument,
1375+
"Invalid bool value for specialization constant '%s'",
1376+
SpecConst.Value.c_str());
1377+
Entry.size = sizeof(bool);
1378+
SpecData.resize(SpecData.size() + sizeof(bool));
1379+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(bool));
1380+
break;
1381+
}
1382+
default:
1383+
llvm_unreachable("Unsupported specialization constant type");
1384+
}
1385+
return llvm::Error::success();
1386+
}
1387+
12761388
llvm::Error createPipeline(Pipeline &P, InvocationState &IS) {
12771389
VkPipelineCacheCreateInfo CacheCreateInfo = {};
12781390
CacheCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO;
@@ -1282,15 +1394,43 @@ class VKDevice : public offloadtest::Device {
12821394
"Failed to create pipeline cache.");
12831395

12841396
if (P.isCompute()) {
1285-
const CompiledShader &S = IS.Shaders[0];
1397+
const offloadtest::Shader &Shader = P.Shaders[0];
12861398
assert(IS.Shaders.size() == 1 &&
12871399
"Currently only support one compute shader");
1400+
const CompiledShader &S = IS.Shaders[0];
12881401
VkPipelineShaderStageCreateInfo StageInfo = {};
12891402
StageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
12901403
StageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
12911404
StageInfo.module = S.Shader;
12921405
StageInfo.pName = S.Entry.c_str();
12931406

1407+
llvm::SmallVector<VkSpecializationMapEntry> SpecEntries;
1408+
llvm::SmallVector<char> SpecData;
1409+
VkSpecializationInfo SpecInfo = {};
1410+
if (!Shader.SpecializationConstants.empty()) {
1411+
llvm::DenseSet<uint32_t> SeenConstantIDs;
1412+
for (const auto &SpecConst : Shader.SpecializationConstants) {
1413+
if (!SeenConstantIDs.insert(SpecConst.ConstantID).second)
1414+
return llvm::createStringError(
1415+
std::errc::invalid_argument,
1416+
"Test configuration contains multiple entries for "
1417+
"specialization constant ID %u.",
1418+
SpecConst.ConstantID);
1419+
1420+
VkSpecializationMapEntry Entry;
1421+
if (auto Err =
1422+
parseSpecializationConstant(SpecConst, Entry, SpecData))
1423+
return Err;
1424+
SpecEntries.push_back(Entry);
1425+
}
1426+
1427+
SpecInfo.mapEntryCount = SpecEntries.size();
1428+
SpecInfo.pMapEntries = SpecEntries.data();
1429+
SpecInfo.dataSize = SpecData.size();
1430+
SpecInfo.pData = SpecData.data();
1431+
StageInfo.pSpecializationInfo = &SpecInfo;
1432+
}
1433+
12941434
VkComputePipelineCreateInfo PipelineCreateInfo = {};
12951435
PipelineCreateInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
12961436
PipelineCreateInfo.stage = StageInfo;

lib/Support/Pipeline.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ void MappingTraits<offloadtest::Shader>::mapping(IO &I,
372372
offloadtest::Shader &S) {
373373
I.mapRequired("Stage", S.Stage);
374374
I.mapRequired("Entry", S.Entry);
375+
I.mapOptional("SpecializationConstants", S.SpecializationConstants);
375376

376377
if (S.Stage == Stages::Compute) {
377378
// Stage-specific data, not sure if this should be optional
@@ -380,6 +381,7 @@ void MappingTraits<offloadtest::Shader>::mapping(IO &I,
380381
I.mapRequired("DispatchSize", MutableDispatchSize);
381382
}
382383
}
384+
383385
void MappingTraits<offloadtest::Result>::mapping(IO &I,
384386
offloadtest::Result &R) {
385387
I.mapRequired("Result", R.Name);
@@ -402,5 +404,13 @@ void MappingTraits<offloadtest::Result>::mapping(IO &I,
402404
break;
403405
}
404406
}
407+
408+
void MappingTraits<offloadtest::SpecializationConstant>::mapping(
409+
IO &I, offloadtest::SpecializationConstant &C) {
410+
I.mapRequired("ConstantID", C.ConstantID);
411+
I.mapRequired("Type", C.Type);
412+
I.mapRequired("Value", C.Value);
413+
}
414+
405415
} // namespace yaml
406416
} // namespace llvm
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#--- duplicate_spec_id.hlsl
2+
[[vk::constant_id(0)]]
3+
const int spec_int_A = 0;
4+
[[vk::constant_id(0)]]
5+
const int spec_int_B = 0;
6+
7+
RWStructuredBuffer<int> Out : register(u0, space0);
8+
9+
[numthreads(1,1,1)]
10+
void main(uint GI : SV_GroupIndex) {
11+
Out[0] = spec_int_A;
12+
Out[1] = spec_int_B;
13+
}
14+
#--- duplicate_spec_id.yaml
15+
---
16+
Shaders:
17+
- Stage: Compute
18+
Entry: main
19+
DispatchSize: [1, 1, 1]
20+
SpecializationConstants:
21+
- { ConstantID: 0, Value: 123, Type: Int32 }
22+
Buffers:
23+
- { Name: Out, Format: Int32, Stride: 4, FillSize: 8 }
24+
DescriptorSets:
25+
- Resources:
26+
- Name: Out
27+
Kind: RWStructuredBuffer
28+
DirectXBinding: { Register: 0, Space: 0 }
29+
VulkanBinding: { Binding: 0 }
30+
...
31+
#--- end
32+
33+
# REQUIRES: Vulkan
34+
35+
# RUN: split-file %s %t
36+
# RUN: %dxc_target -T cs_6_2 -Fo %t.o %t/duplicate_spec_id.hlsl
37+
# RUN: %offloader %t/duplicate_spec_id.yaml %t.o | FileCheck %s
38+
39+
# CHECK: Data: [ 123, 123 ]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#--- duplicate_spec_id_in_config.hlsl
2+
[[vk::constant_id(0)]]
3+
const int spec_int = 0;
4+
5+
RWStructuredBuffer<int> Out : register(u0, space0);
6+
7+
[numthreads(1,1,1)]
8+
void main(uint GI : SV_GroupIndex) {
9+
Out[GI] = spec_int;
10+
}
11+
#--- duplicate_spec_id_in_config.yaml
12+
---
13+
Shaders:
14+
- Stage: Compute
15+
Entry: main
16+
DispatchSize: [1, 1, 1]
17+
SpecializationConstants:
18+
- { ConstantID: 0, Value: 123, Type: Int32 }
19+
- { ConstantID: 0, Value: 456, Type: Int32 }
20+
Buffers:
21+
- { Name: Out, Format: Int32, Stride: 4, FillSize: 4 }
22+
DescriptorSets:
23+
- Resources:
24+
- Name: Out
25+
Kind: RWStructuredBuffer
26+
DirectXBinding: { Register: 0, Space: 0 }
27+
VulkanBinding: { Binding: 0 }
28+
...
29+
#--- end
30+
31+
# REQUIRES: Vulkan
32+
33+
# RUN: split-file %s %t
34+
# RUN: %dxc_target -T cs_6_2 -Fo %t.o %t/duplicate_spec_id_in_config.hlsl
35+
# RUN: not %offloader %t/duplicate_spec_id_in_config.yaml %t.o 2>&1 | FileCheck %s
36+
37+
# CHECK: gpu-exec: error: Test configuration contains multiple entries for specialization constant ID 0.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#--- invalid_bool.hlsl
2+
[[vk::constant_id(0)]]
3+
const bool spec_bool = false;
4+
RWStructuredBuffer<uint> OutBool : register(u0, space0);
5+
6+
[numthreads(1,1,1)]
7+
void main(uint GI : SV_GroupIndex) {
8+
OutBool[GI] = (uint)spec_bool;
9+
}
10+
#--- invalid_bool.yaml
11+
---
12+
Shaders:
13+
- Stage: Compute
14+
Entry: main
15+
DispatchSize: [1, 1, 1]
16+
SpecializationConstants:
17+
- { ConstantID: 0, Value: "not a number", Type: Bool }
18+
Buffers:
19+
- { Name: OutBool, Format: Int32, Stride: 4, FillSize: 4 }
20+
DescriptorSets:
21+
- Resources:
22+
- Name: OutBool
23+
Kind: RWStructuredBuffer
24+
DirectXBinding: { Register: 0, Space: 0 }
25+
VulkanBinding: { Binding: 0 }
26+
...
27+
#--- end
28+
29+
# REQUIRES: Vulkan
30+
31+
# RUN: split-file %s %t
32+
# RUN: %dxc_target -T cs_6_2 -Fo %t.o %t/invalid_bool.hlsl
33+
# RUN: not %offloader %t/invalid_bool.yaml %t.o 2>&1 | FileCheck %s
34+
35+
# CHECK: Invalid bool value for specialization constant 'not a number'

0 commit comments

Comments
 (0)