1111
1212#include " API/Device.h"
1313#include " Support/Pipeline.h"
14+ #include " llvm/ADT/DenseSet.h"
1415#include " llvm/Support/Error.h"
1516
1617#include < memory>
2021
2122using namespace offloadtest ;
2223
23- #define VKFormats (FMT ) \
24+ #define VKFormats (FMT, BITS ) \
2425 if (Channels == 1 ) \
25- return VK_FORMAT_R32_## FMT; \
26+ return VK_FORMAT_R##BITS##_## FMT; \
2627 if (Channels == 2 ) \
27- return VK_FORMAT_R32G32_## FMT; \
28+ return VK_FORMAT_R##BITS##G##BITS##_## FMT; \
2829 if (Channels == 3 ) \
29- return VK_FORMAT_R32G32B32_## FMT; \
30+ return VK_FORMAT_R##BITS##G##BITS##B##BITS##_## FMT; \
3031 if (Channels == 4 ) \
31- return VK_FORMAT_R32G32B32A32_ ##FMT;
32+ return VK_FORMAT_R##BITS##G##BITS##B##BITS##A##BITS##_ ##FMT;
3233
3334static VkFormat getVKFormat (DataFormat Format, int Channels) {
3435 switch (Format) {
36+ case DataFormat::Int16:
37+ VKFormats (SINT, 16 ) break ;
38+ case DataFormat::UInt16:
39+ VKFormats (UINT, 16 ) break ;
3540 case DataFormat::Int32:
36- VKFormats (SINT) break ;
41+ VKFormats (SINT, 32 ) break ;
42+ case DataFormat::UInt32:
43+ VKFormats (UINT, 32 ) break ;
3744 case DataFormat::Float32:
38- VKFormats (SFLOAT) break ;
45+ VKFormats (SFLOAT, 32 ) break ;
46+ case DataFormat::Int64:
47+ VKFormats (SINT, 64 ) break ;
48+ case DataFormat::UInt64:
49+ VKFormats (UINT, 64 ) break ;
50+ case DataFormat::Float64:
51+ VKFormats (SFLOAT, 64 ) break ;
3952 default :
4053 llvm_unreachable (" Unsupported Resource format specified" );
4154 }
@@ -1273,6 +1286,105 @@ class VKDevice : public offloadtest::Device {
12731286 return llvm::Error::success ();
12741287 }
12751288
1289+ static llvm::Error
1290+ parseSpecializationConstant (const SpecializationConstant &SpecConst,
1291+ VkSpecializationMapEntry &Entry,
1292+ llvm::SmallVector<char > &SpecData) {
1293+ Entry.constantID = SpecConst.ConstantID ;
1294+ Entry.offset = SpecData.size ();
1295+ switch (SpecConst.Type ) {
1296+ case DataFormat::Float32: {
1297+ float Value = 0 .0f ;
1298+ double Tmp = 0.0 ;
1299+ if (llvm::StringRef (SpecConst.Value ).getAsDouble (Tmp))
1300+ return llvm::createStringError (
1301+ std::errc::invalid_argument,
1302+ " Invalid float value for specialization constant '%s'" ,
1303+ SpecConst.Value .c_str ());
1304+ Value = static_cast <float >(Tmp);
1305+ Entry.size = sizeof (float );
1306+ SpecData.resize (SpecData.size () + sizeof (float ));
1307+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (float ));
1308+ break ;
1309+ }
1310+ case DataFormat::Float64: {
1311+ double Value = 0.0 ;
1312+ if (llvm::StringRef (SpecConst.Value ).getAsDouble (Value))
1313+ return llvm::createStringError (
1314+ std::errc::invalid_argument,
1315+ " Invalid double value for specialization constant '%s'" ,
1316+ SpecConst.Value .c_str ());
1317+ Entry.size = sizeof (double );
1318+ SpecData.resize (SpecData.size () + sizeof (double ));
1319+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (double ));
1320+ break ;
1321+ }
1322+ case DataFormat::Int16: {
1323+ int16_t Value = 0 ;
1324+ if (llvm::StringRef (SpecConst.Value ).getAsInteger (0 , Value))
1325+ return llvm::createStringError (
1326+ std::errc::invalid_argument,
1327+ " Invalid int16 value for specialization constant '%s'" ,
1328+ SpecConst.Value .c_str ());
1329+ Entry.size = sizeof (int16_t );
1330+ SpecData.resize (SpecData.size () + sizeof (int16_t ));
1331+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (int16_t ));
1332+ break ;
1333+ }
1334+ case DataFormat::UInt16: {
1335+ uint16_t Value = 0 ;
1336+ if (llvm::StringRef (SpecConst.Value ).getAsInteger (0 , Value))
1337+ return llvm::createStringError (
1338+ std::errc::invalid_argument,
1339+ " Invalid uint16 value for specialization constant '%s'" ,
1340+ SpecConst.Value .c_str ());
1341+ Entry.size = sizeof (uint16_t );
1342+ SpecData.resize (SpecData.size () + sizeof (uint16_t ));
1343+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (uint16_t ));
1344+ break ;
1345+ }
1346+ case DataFormat::Int32: {
1347+ int32_t Value = 0 ;
1348+ if (llvm::StringRef (SpecConst.Value ).getAsInteger (0 , Value))
1349+ return llvm::createStringError (
1350+ std::errc::invalid_argument,
1351+ " Invalid int32 value for specialization constant '%s'" ,
1352+ SpecConst.Value .c_str ());
1353+ Entry.size = sizeof (int32_t );
1354+ SpecData.resize (SpecData.size () + sizeof (int32_t ));
1355+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (int32_t ));
1356+ break ;
1357+ }
1358+ case DataFormat::UInt32: {
1359+ uint32_t Value = 0 ;
1360+ if (llvm::StringRef (SpecConst.Value ).getAsInteger (0 , Value))
1361+ return llvm::createStringError (
1362+ std::errc::invalid_argument,
1363+ " Invalid uint32 value for specialization constant '%s'" ,
1364+ SpecConst.Value .c_str ());
1365+ Entry.size = sizeof (uint32_t );
1366+ SpecData.resize (SpecData.size () + sizeof (uint32_t ));
1367+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (uint32_t ));
1368+ break ;
1369+ }
1370+ case DataFormat::Bool: {
1371+ bool Value = false ;
1372+ if (llvm::StringRef (SpecConst.Value ).getAsInteger (0 , Value))
1373+ return llvm::createStringError (
1374+ std::errc::invalid_argument,
1375+ " Invalid bool value for specialization constant '%s'" ,
1376+ SpecConst.Value .c_str ());
1377+ Entry.size = sizeof (bool );
1378+ SpecData.resize (SpecData.size () + sizeof (bool ));
1379+ memcpy (SpecData.data () + Entry.offset , &Value, sizeof (bool ));
1380+ break ;
1381+ }
1382+ default :
1383+ llvm_unreachable (" Unsupported specialization constant type" );
1384+ }
1385+ return llvm::Error::success ();
1386+ }
1387+
12761388 llvm::Error createPipeline (Pipeline &P, InvocationState &IS) {
12771389 VkPipelineCacheCreateInfo CacheCreateInfo = {};
12781390 CacheCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO;
@@ -1282,15 +1394,43 @@ class VKDevice : public offloadtest::Device {
12821394 " Failed to create pipeline cache." );
12831395
12841396 if (P.isCompute ()) {
1285- const CompiledShader &S = IS .Shaders [0 ];
1397+ const offloadtest::Shader &Shader = P .Shaders [0 ];
12861398 assert (IS.Shaders .size () == 1 &&
12871399 " Currently only support one compute shader" );
1400+ const CompiledShader &S = IS.Shaders [0 ];
12881401 VkPipelineShaderStageCreateInfo StageInfo = {};
12891402 StageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
12901403 StageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
12911404 StageInfo.module = S.Shader ;
12921405 StageInfo.pName = S.Entry .c_str ();
12931406
1407+ llvm::SmallVector<VkSpecializationMapEntry> SpecEntries;
1408+ llvm::SmallVector<char > SpecData;
1409+ VkSpecializationInfo SpecInfo = {};
1410+ if (!Shader.SpecializationConstants .empty ()) {
1411+ llvm::DenseSet<uint32_t > SeenConstantIDs;
1412+ for (const auto &SpecConst : Shader.SpecializationConstants ) {
1413+ if (!SeenConstantIDs.insert (SpecConst.ConstantID ).second )
1414+ return llvm::createStringError (
1415+ std::errc::invalid_argument,
1416+ " Test configuration contains multiple entries for "
1417+ " specialization constant ID %u." ,
1418+ SpecConst.ConstantID );
1419+
1420+ VkSpecializationMapEntry Entry;
1421+ if (auto Err =
1422+ parseSpecializationConstant (SpecConst, Entry, SpecData))
1423+ return Err;
1424+ SpecEntries.push_back (Entry);
1425+ }
1426+
1427+ SpecInfo.mapEntryCount = SpecEntries.size ();
1428+ SpecInfo.pMapEntries = SpecEntries.data ();
1429+ SpecInfo.dataSize = SpecData.size ();
1430+ SpecInfo.pData = SpecData.data ();
1431+ StageInfo.pSpecializationInfo = &SpecInfo;
1432+ }
1433+
12941434 VkComputePipelineCreateInfo PipelineCreateInfo = {};
12951435 PipelineCreateInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
12961436 PipelineCreateInfo.stage = StageInfo;
0 commit comments