Skip to content

Commit 68717ea

Browse files
committed
unify SRV, first attempt
1 parent 1c9752f commit 68717ea

File tree

1 file changed

+47
-109
lines changed

1 file changed

+47
-109
lines changed

lib/API/DX/Device.cpp

Lines changed: 47 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ static D3D12_RESOURCE_DESC getResourceDescription(const Resource &R) {
151151
const uint32_t Height = R.isTexture() ? B.OutputProps.Height : 1;
152152
D3D12_TEXTURE_LAYOUT Layout;
153153
if (R.isTexture())
154-
Layout = getDXKind(R.Kind) == SRV
154+
Layout = getDXKind(R.Kind) == SRV || getDXKind(R.Kind) == UAV
155155
? D3D12_TEXTURE_LAYOUT_64KB_UNDEFINED_SWIZZLE
156156
: D3D12_TEXTURE_LAYOUT_UNKNOWN;
157157
else
@@ -528,6 +528,19 @@ class DXDevice : public offloadtest::Device {
528528
addUploadEndBarrier(IS, Destination, R.isReadWrite());
529529
}
530530

531+
/// Returns how many 64KB tiles should be mapped for a reserved resource.
///
/// \param numTiles Explicit tile count requested for the resource, if any.
///        When present it is used as-is (converted to UINT), so an explicit
///        0 yields 0.
/// \param width Size in bytes to cover (callers pass ResDesc.Width). Only
///        consulted when \p numTiles is empty.
/// \returns The explicit count, or the number of
///          D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT-sized tiles needed to
///          cover \p width bytes, rounded up.
UINT GetNumTiles(std::optional<int> numTiles, UINT64 width) {
  // NOTE(review): a negative explicit count converts to a very large UINT
  // here; presumably the input is validated before this point -- confirm.
  if (numTiles.has_value())
    return static_cast<UINT>(*numTiles);

  // Map the entire buffer by computing how many 64KB tiles cover it.
  // Divide first and add one for a partial trailing tile, so the round-up
  // cannot wrap the way `width + alignment - 1` can for very large widths.
  const UINT64 TileSize = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
  const UINT64 WholeTiles = width / TileSize;
  const UINT64 PartialTile = (width % TileSize != 0) ? 1 : 0;
  return static_cast<UINT>(WholeTiles + PartialTile);
}
543+
531544
llvm::Expected<ResourceBundle> createSRV(Resource &R, InvocationState &IS) {
532545
ResourceBundle Bundle;
533546
const D3D12_RESOURCE_DESC ResDesc = getResourceDescription(R);
@@ -566,15 +579,8 @@ class DXDevice : public offloadtest::Device {
566579
return Err;
567580

568581
// Tile mapping setup (optional if NumTiles > 0)
569-
UINT NumTiles = 0;
570-
if (R.TilesMapped.has_value()) {
571-
NumTiles = static_cast<UINT>(*R.TilesMapped);
572-
} else {
573-
// Map the entire buffer by computing how many 64KB tiles cover it
574-
NumTiles = static_cast<UINT>(
575-
(ResDesc.Width + D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT - 1) /
576-
D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT);
577-
}
582+
UINT NumTiles = GetNumTiles(R.TilesMapped, ResDesc.Width);
583+
578584
ComPtr<ID3D12Heap> Heap; // optional, only created if NumTiles > 0
579585

580586
if (NumTiles > 0) {
@@ -660,10 +666,10 @@ class DXDevice : public offloadtest::Device {
660666
return HeapIdx;
661667
}
662668

663-
llvm::Expected<ResourceBundle> createReservedUAV(Resource &R,
664-
InvocationState &IS) {
669+
llvm::Expected<ResourceBundle> createUAV(Resource &R, InvocationState &IS) {
665670
ResourceBundle Bundle;
666671
const uint32_t BufferSize = getUAVBufferSize(R);
672+
667673
const D3D12_RESOURCE_DESC ResDesc = getResourceDescription(R);
668674

669675
const D3D12_HEAP_PROPERTIES ReadBackHeapProp =
@@ -680,6 +686,9 @@ class DXDevice : public offloadtest::Device {
680686
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
681687
D3D12_RESOURCE_FLAG_NONE};
682688

689+
const D3D12_HEAP_PROPERTIES UploadHeapProps =
690+
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
691+
683692
const D3D12_RESOURCE_DESC UploadResDesc =
684693
CD3DX12_RESOURCE_DESC::Buffer(BufferSize);
685694

@@ -694,18 +703,17 @@ class DXDevice : public offloadtest::Device {
694703
llvm::outs() << " }\n";
695704

696705
// Reserved UAV resource
706+
697707
ComPtr<ID3D12Resource> Buffer;
698708
if (auto Err =
699709
HR::toError(Device->CreateReservedResource(
700-
&ResDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr,
710+
&ResDesc, D3D12_RESOURCE_STATE_COMMON, nullptr,
701711
IID_PPV_ARGS(&Buffer)),
702712
"Failed to create reserved resource (buffer)."))
703713
return Err;
704714

705715
// Committed Upload Buffer (CPU visible)
706716
ComPtr<ID3D12Resource> UploadBuffer;
707-
const D3D12_HEAP_PROPERTIES UploadHeapProps =
708-
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
709717

710718
if (auto Err = HR::toError(
711719
Device->CreateCommittedResource(
@@ -726,7 +734,8 @@ class DXDevice : public offloadtest::Device {
726734
return Err;
727735

728736
// Tile mapping setup (optional if NumTiles > 0)
729-
const UINT NumTiles = static_cast<UINT>(*R.TilesMapped);
737+
const UINT NumTiles = GetNumTiles(R.TilesMapped, ResDesc.Width);
738+
730739
ComPtr<ID3D12Heap> Heap; // optional, only created if NumTiles > 0
731740

732741
if (NumTiles > 0) {
@@ -743,7 +752,7 @@ class DXDevice : public offloadtest::Device {
743752
HeapDesc.Alignment = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
744753
HeapDesc.SizeInBytes = static_cast<UINT64>(NumTiles) *
745754
D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT;
746-
HeapDesc.Flags = D3D12_HEAP_FLAG_ALLOW_ONLY_BUFFERS;
755+
HeapDesc.Flags = D3D12_HEAP_FLAG_ALLOW_ALL_BUFFERS_AND_TEXTURES;
747756

748757
if (auto Err =
749758
HR::toError(Device->CreateHeap(&HeapDesc, IID_PPV_ARGS(&Heap)),
@@ -796,89 +805,6 @@ class DXDevice : public offloadtest::Device {
796805
return Bundle;
797806
}
798807

799-
llvm::Expected<ResourceBundle> createCommittedUAV(Resource &R,
800-
InvocationState &IS) {
801-
ResourceBundle Bundle;
802-
const uint32_t BufferSize = getUAVBufferSize(R);
803-
804-
const D3D12_HEAP_PROPERTIES HeapProp =
805-
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
806-
const D3D12_RESOURCE_DESC ResDesc = getResourceDescription(R);
807-
808-
const D3D12_HEAP_PROPERTIES ReadBackHeapProp =
809-
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_READBACK);
810-
const D3D12_RESOURCE_DESC ReadBackResDesc = {
811-
D3D12_RESOURCE_DIMENSION_BUFFER,
812-
0,
813-
BufferSize,
814-
1,
815-
1,
816-
1,
817-
DXGI_FORMAT_UNKNOWN,
818-
{1, 0},
819-
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
820-
D3D12_RESOURCE_FLAG_NONE};
821-
822-
const D3D12_HEAP_PROPERTIES UploadHeapProp =
823-
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
824-
const D3D12_RESOURCE_DESC UploadResDesc =
825-
CD3DX12_RESOURCE_DESC::Buffer(BufferSize);
826-
827-
uint32_t RegOffset = 0;
828-
for (const auto &ResData : R.BufferPtr->Data) {
829-
llvm::outs() << "Creating UAV: { Size = " << BufferSize
830-
<< ", Register = u" << R.DXBinding.Register + RegOffset
831-
<< ", Space = " << R.DXBinding.Space
832-
<< ", HasCounter = " << R.HasCounter << " }\n";
833-
834-
ComPtr<ID3D12Resource> Buffer;
835-
if (auto Err = HR::toError(
836-
Device->CreateCommittedResource(
837-
&HeapProp, D3D12_HEAP_FLAG_NONE, &ResDesc,
838-
D3D12_RESOURCE_STATE_COMMON, nullptr, IID_PPV_ARGS(&Buffer)),
839-
"Failed to create committed resource (buffer)."))
840-
return Err;
841-
842-
ComPtr<ID3D12Resource> UploadBuffer;
843-
if (auto Err = HR::toError(
844-
Device->CreateCommittedResource(
845-
&UploadHeapProp, D3D12_HEAP_FLAG_NONE, &UploadResDesc,
846-
D3D12_RESOURCE_STATE_GENERIC_READ, nullptr,
847-
IID_PPV_ARGS(&UploadBuffer)),
848-
"Failed to create committed resource (upload buffer)."))
849-
return Err;
850-
851-
ComPtr<ID3D12Resource> ReadBackBuffer;
852-
if (auto Err = HR::toError(
853-
Device->CreateCommittedResource(
854-
&ReadBackHeapProp, D3D12_HEAP_FLAG_NONE, &ReadBackResDesc,
855-
D3D12_RESOURCE_STATE_COPY_DEST, nullptr,
856-
IID_PPV_ARGS(&ReadBackBuffer)),
857-
"Failed to create committed resource (readback buffer)."))
858-
return Err;
859-
860-
// Initialize the UAV data
861-
void *ResDataPtr = nullptr;
862-
if (auto Err = HR::toError(UploadBuffer->Map(0, nullptr, &ResDataPtr),
863-
"Failed to acquire UAV data pointer."))
864-
return Err;
865-
memcpy(ResDataPtr, ResData.get(), R.size());
866-
UploadBuffer->Unmap(0, nullptr);
867-
868-
addResourceUploadCommands(R, IS, Buffer, UploadBuffer);
869-
870-
Bundle.emplace_back(UploadBuffer, Buffer, ReadBackBuffer);
871-
RegOffset++;
872-
}
873-
return Bundle;
874-
}
875-
876-
llvm::Expected<ResourceBundle> createUAV(Resource &R, InvocationState &IS) {
877-
if (R.TilesMapped)
878-
return createReservedUAV(R, IS);
879-
return createCommittedUAV(R, IS);
880-
}
881-
882808
// returns the next available HeapIdx
883809
uint32_t bindUAV(Resource &R, InvocationState &IS, uint32_t HeapIdx,
884810
ResourceBundle ResBundle) {
@@ -1703,6 +1629,18 @@ class DXDevice : public offloadtest::Device {
17031629
return llvm::Error::success();
17041630
}
17051631

1632+
/// Waits for the GPU signal on \p IS, then hands back \p Err so the caller
/// can return it.
///
/// \param Err The error that triggered the early exit; ownership is taken.
/// \param IS Invocation state passed to waitForSignal().
/// \returns \p Err unchanged when the wait succeeds; otherwise the wait
///          failure joined with \p Err (wait failure first).
llvm::Error waitThenReturnErr(llvm::Error Err, InvocationState &IS) {
  // Wait on the GPU before reporting anything. If the wait itself fails,
  // report both failures together; llvm::Error is move-only, so both are
  // moved into the joined result.
  if (llvm::Error WaitFailure = waitForSignal(IS))
    return llvm::joinErrors(std::move(WaitFailure), std::move(Err));

  // The wait succeeded: give back the original error as-is.
  return Err;
}
1643+
17061644
llvm::Error executeProgram(Pipeline &P) override {
17071645
llvm::sys::AddSignalHandler(
17081646
[](void *Cookie) {
@@ -1746,7 +1684,7 @@ class DXDevice : public offloadtest::Device {
17461684
return Err;
17471685
llvm::outs() << "Buffers created.\n";
17481686
if (auto Err = createEvent(State))
1749-
return Err;
1687+
return waitThenReturnErr(std::move(Err), State);
17501688
llvm::outs() << "Event prepared.\n";
17511689

17521690
if (P.isCompute()) {
@@ -1756,33 +1694,33 @@ class DXDevice : public offloadtest::Device {
17561694
std::errc::invalid_argument,
17571695
"Compute pipeline must have exactly one compute shader.");
17581696
if (auto Err = createComputePSO(P.Shaders[0].Shader->getBuffer(), State))
1759-
return Err;
1697+
return waitThenReturnErr(std::move(Err), State);
17601698
llvm::outs() << "PSO created.\n";
17611699
if (auto Err = createComputeCommands(P, State))
1762-
return Err;
1700+
return waitThenReturnErr(std::move(Err), State);
17631701
llvm::outs() << "Compute command list created.\n";
17641702

17651703
} else {
17661704
// Create render target, readback and vertex buffer and PSO.
17671705
if (auto Err = createRenderTarget(P, State))
1768-
return Err;
1706+
return waitThenReturnErr(std::move(Err), State);
17691707
llvm::outs() << "Render target created.\n";
17701708
if (auto Err = createVertexBuffer(P, State))
1771-
return Err;
1709+
return waitThenReturnErr(std::move(Err), State);
17721710
llvm::outs() << "Vertex buffer created.\n";
17731711
if (auto Err = createGraphicsPSO(P, State))
1774-
return Err;
1712+
return waitThenReturnErr(std::move(Err), State);
17751713
llvm::outs() << "Graphics PSO created.\n";
17761714
if (auto Err = createGraphicsCommands(P, State))
1777-
return Err;
1715+
return waitThenReturnErr(std::move(Err), State);
17781716
llvm::outs() << "Graphics command list created complete.\n";
17791717
}
17801718

17811719
if (auto Err = executeCommandList(State))
1782-
return Err;
1720+
return waitThenReturnErr(std::move(Err), State);
17831721
llvm::outs() << "Compute commands executed.\n";
17841722
if (auto Err = readBack(P, State))
1785-
return Err;
1723+
return waitThenReturnErr(std::move(Err), State);
17861724
llvm::outs() << "Read data back.\n";
17871725

17881726
return llvm::Error::success();

0 commit comments

Comments
 (0)