Skip to content

Commit 90c0965

Browse files
authored
Fix L0_infer_cudashm failure (#108)
* Fix L0_infer_cudashm failure * Review edit
1 parent 2e7af66 commit 90c0965

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/onnxruntime.cc

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,8 +1576,7 @@ ModelInstanceState::ProcessRequests(
15761576
.c_str()));
15771577
} else if (iit->second.type_ != ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
15781578
// Query the memory type of destination output buffer. Bind the
1579-
// output
1580-
// to this destination memory type. The destination memory type
1579+
// output to this destination memory type. The destination memory type
15811580
// for an output for all requests should be same. So use any request
15821581
// for this query.
15831582
memory_type = preferred_memory_type;
@@ -1598,6 +1597,13 @@ ModelInstanceState::ProcessRequests(
15981597
memory_type_id = 0;
15991598
}
16001599
}
1600+
1601+
// If the cuda allocator is not set, bind the output to CPU.
1602+
if (cuda_allocator_info_ == nullptr) {
1603+
memory_type = TRITONSERVER_MEMORY_CPU;
1604+
memory_type_id = 0;
1605+
}
1606+
16011607
// finally save the derived mem type and device id as we need it for
16021608
// reading the outputs.
16031609
output_device_info_.insert(
@@ -1775,13 +1781,12 @@ ModelInstanceState::SetInputTensors(
17751781
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>>
17761782
allowed_input_types;
17771783
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
1778-
allowed_input_types = {
1779-
{TRITONSERVER_MEMORY_GPU, DeviceId()},
1780-
{TRITONSERVER_MEMORY_CPU_PINNED, 0},
1781-
{TRITONSERVER_MEMORY_CPU, 0}};
1784+
allowed_input_types = {{TRITONSERVER_MEMORY_GPU, DeviceId()},
1785+
{TRITONSERVER_MEMORY_CPU_PINNED, 0},
1786+
{TRITONSERVER_MEMORY_CPU, 0}};
17821787
} else {
1783-
allowed_input_types = {
1784-
{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
1788+
allowed_input_types = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
1789+
{TRITONSERVER_MEMORY_CPU, 0}};
17851790
}
17861791

17871792
RETURN_IF_ERROR(collector->ProcessTensor(

0 commit comments

Comments
 (0)