Skip to content

Commit 353a9d1

Browse files
authored
MP executor wrapper for streaming callbacks (#3822)
1 parent dfd7add commit 353a9d1

16 files changed

+130
-66
lines changed

src/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ ovms_cc_library(
477477
"capi_frontend/capi.cpp",
478478
"cleaner_utils.cpp",
479479
"cleaner_utils.hpp",
480+
"copyable_object_wrapper.hpp",
480481
"customloaderconfig.hpp",
481482
"customloaders.hpp",
482483
"customloaders.cpp",

src/copyable_object_wrapper.hpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//*****************************************************************************
2+
// Copyright 2025 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
#pragma once
17+
18+
#include <memory>
19+
#include <utility>
20+
21+
template <typename T>
22+
class UniqueObjectHolder {
23+
std::unique_ptr<T> object = nullptr;
24+
25+
public:
26+
std::unique_ptr<T>& get() {
27+
return object;
28+
}
29+
30+
void reset() {
31+
object.reset();
32+
}
33+
34+
bool valid() const {
35+
return object != nullptr;
36+
}
37+
};
38+
39+
template <typename T>
40+
class CopyableObjectWrapper {
41+
std::shared_ptr<UniqueObjectHolder<T>> objectHolder = nullptr;
42+
43+
public:
44+
explicit CopyableObjectWrapper() :
45+
objectHolder(std::make_shared<UniqueObjectHolder<T>>()) {}
46+
47+
explicit CopyableObjectWrapper(std::shared_ptr<UniqueObjectHolder<T>> objectHolder) :
48+
objectHolder(std::move(objectHolder)) {}
49+
50+
std::shared_ptr<UniqueObjectHolder<T>>& getObjectHolder() {
51+
return objectHolder;
52+
}
53+
};

src/http_rest_api_handler.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
#include "status.hpp"
6363
#include "stringutils.hpp"
6464
#include "timer.hpp"
65+
#include "copyable_object_wrapper.hpp"
6566

6667
#if (MEDIAPIPE_DISABLE == 0)
6768
#include "http_frontend/http_client_connection.hpp"
@@ -703,21 +704,28 @@ Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpReque
703704
return status;
704705
}
705706

706-
std::shared_ptr<MediapipeGraphExecutor> executor;
707+
CopyableObjectWrapper<MediapipeGraphExecutor> executorWrapper;
708+
auto& executor = executorWrapper.getObjectHolder()->get();
707709
status = this->modelManager.createPipeline(executor, modelName);
708710
if (!status.ok()) {
709711
return status;
710712
}
711713

714+
if (!executorWrapper.getObjectHolder()->valid()) {
715+
SPDLOG_ERROR("Failed to acquire MediaPipe graph executor for model: {}", modelName);
716+
return StatusCode::INTERNAL_ERROR;
717+
}
718+
712719
if (streamFieldVal == false) {
713720
ExecutionContext executionContext{ExecutionContext::Interface::REST, ExecutionContext::Method::V3Unary};
714721
return executor->infer(&request, &response, executionContext);
715722
} else {
716723
serverReaderWriter->OverwriteResponseHeader("Content-Type", "text/event-stream");
717724
serverReaderWriter->OverwriteResponseHeader("Cache-Control", "no-cache");
718725
serverReaderWriter->OverwriteResponseHeader("Connection", "keep-alive");
719-
serverReaderWriter->PartialReplyBegin([executor = std::move(executor), serverReaderWriter, request = std::move(request)] {
726+
serverReaderWriter->PartialReplyBegin([executorWrapper = executorWrapper, serverReaderWriter, request = std::move(request)]() mutable {
720727
ExecutionContext executionContext{ExecutionContext::Interface::REST, ExecutionContext::Method::V3Stream};
728+
auto& executor = executorWrapper.getObjectHolder()->get();
721729
auto status = executor->inferStream(request, *serverReaderWriter, executionContext);
722730
if (!status.ok()) {
723731
rapidjson::StringBuffer buffer;
@@ -728,6 +736,7 @@ Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpReque
728736
writer.EndObject();
729737
serverReaderWriter->PartialReplyWithStatus(buffer.GetString(), HTTPStatusCode::BAD_REQUEST);
730738
}
739+
executor.reset();
731740
serverReaderWriter->PartialReplyEnd();
732741
});
733742
return StatusCode::PARTIAL_END;

src/kfs_frontend/kfs_grpc_inference_service.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ Status KFSInferenceServiceImpl::ModelInferImpl(::grpc::ServerContext* context, c
303303
if (status == StatusCode::PIPELINE_DEFINITION_NAME_MISSING) {
304304
SPDLOG_DEBUG("Requested DAG: {} does not exist. Searching for mediapipe graph with that name...", request->model_name());
305305
#if (MEDIAPIPE_DISABLE == 0)
306-
std::shared_ptr<MediapipeGraphExecutor> executor;
306+
std::unique_ptr<MediapipeGraphExecutor> executor;
307307
status = this->modelManager.createPipeline(executor, request->model_name());
308308
if (!status.ok()) {
309309
return status;
@@ -346,7 +346,7 @@ Status KFSInferenceServiceImpl::ModelStreamInferImpl(::grpc::ServerContext* cont
346346
SPDLOG_DEBUG(status.string());
347347
return status;
348348
}
349-
std::shared_ptr<MediapipeGraphExecutor> executor;
349+
std::unique_ptr<MediapipeGraphExecutor> executor;
350350
auto status = this->modelManager.createPipeline(executor, firstRequest.model_name());
351351
if (!status.ok()) {
352352
return status;

src/mediapipe_internal/mediapipefactory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ Status MediapipeFactory::reloadDefinition(const std::string& name,
104104
return mgd->reload(manager, config);
105105
}
106106

107-
Status MediapipeFactory::create(std::shared_ptr<MediapipeGraphExecutor>& pipeline,
107+
Status MediapipeFactory::create(std::unique_ptr<MediapipeGraphExecutor>& pipeline,
108108
const std::string& name,
109109
ModelManager& manager) const {
110110
std::shared_lock lock(definitionsMtx);

src/mediapipe_internal/mediapipefactory.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class MediapipeFactory {
6666
ModelManager& manager) const;
6767

6868
public:
69-
Status create(std::shared_ptr<MediapipeGraphExecutor>& pipeline,
69+
Status create(std::unique_ptr<MediapipeGraphExecutor>& pipeline,
7070
const std::string& name,
7171
ModelManager& manager) const;
7272

src/mediapipe_internal/mediapipegraphdefinition.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ Status MediapipeGraphDefinition::createOutputsInfo() {
253253
return StatusCode::OK;
254254
}
255255

256-
Status MediapipeGraphDefinition::create(std::shared_ptr<MediapipeGraphExecutor>& pipeline) {
256+
Status MediapipeGraphDefinition::create(std::unique_ptr<MediapipeGraphExecutor>& pipeline) {
257257
std::unique_ptr<MediapipeGraphDefinitionUnloadGuard> unloadGuard;
258258
Status status = waitForLoaded(unloadGuard);
259259
if (!status.ok()) {
@@ -262,7 +262,7 @@ Status MediapipeGraphDefinition::create(std::shared_ptr<MediapipeGraphExecutor>&
262262
}
263263
SPDLOG_DEBUG("Creating Mediapipe graph executor: {}", getName());
264264

265-
pipeline = std::make_shared<MediapipeGraphExecutor>(getName(), std::to_string(getVersion()),
265+
pipeline = std::make_unique<MediapipeGraphExecutor>(getName(), std::to_string(getVersion()),
266266
this->config, this->inputTypes, this->outputTypes, this->inputNames, this->outputNames,
267267
this->sidePacketMaps,
268268
this->pythonBackend, this->reporter.get());

src/mediapipe_internal/mediapipegraphdefinition.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class MediapipeGraphDefinition {
119119
const tensor_map_t getOutputsInfo() const;
120120
const MediapipeGraphConfig& getMediapipeGraphConfig() const { return this->mgconfig; }
121121
MediapipeServableMetricReporter& getMetricReporter() const { return *this->reporter; }
122-
Status create(std::shared_ptr<MediapipeGraphExecutor>& pipeline);
122+
Status create(std::unique_ptr<MediapipeGraphExecutor>& pipeline);
123123

124124
Status reload(ModelManager& manager, const MediapipeGraphConfig& config);
125125
Status validate(ModelManager& manager);

src/mediapipe_internal/mediapipegraphexecutor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,4 +393,5 @@ class MediapipeGraphExecutor {
393393
}
394394
}
395395
};
396+
396397
} // namespace ovms

src/modelmanager.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,7 @@ const std::vector<std::string> ModelManager::getNamesOfAvailableModels() const {
16971697
return names;
16981698
}
16991699

1700-
Status ModelManager::createPipeline(std::shared_ptr<MediapipeGraphExecutor>& graph,
1700+
Status ModelManager::createPipeline(std::unique_ptr<MediapipeGraphExecutor>& graph,
17011701
const std::string& name) {
17021702
#if (MEDIAPIPE_DISABLE == 0)
17031703
return this->mediapipeFactory.create(graph, name, *this);

0 commit comments

Comments
 (0)