2525// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
2727#include < stdint.h>
28-
2928#include < mutex>
3029#include < vector>
3130
@@ -81,10 +80,10 @@ class ModelState : public BackendModel {
8180 // onnx file, return in 'session' and 'allocator' the ORT session
8281 // and allocator.
8382 TRITONSERVER_Error* LoadModel (
84- const std::string& artifact_name,
83+ const std::string& artifact_name, const std::string& instance_name,
8584 const TRITONSERVER_InstanceGroupKind instance_group_kind,
8685 const int32_t instance_group_device_id, std::string* model_path,
87- OrtSession** session, OrtAllocator** default_allocator,
86+ std::shared_ptr< OrtSession>& session, OrtAllocator** default_allocator,
8887 cudaStream_t stream);
8988
9089 const std::map<std::string, std::pair<int64_t , int64_t >>& ModelOutputs ()
@@ -101,6 +100,11 @@ class ModelState : public BackendModel {
101100 TRITONSERVER_Error* AutoCompleteIO (
102101 const char * key, const OnnxTensorInfoMap& io_infos);
103102
103+ TRITONSERVER_Error* GetSessionForGroup (
104+ const std::string& group_name, std::shared_ptr<OrtSession>& session);
105+ TRITONSERVER_Error* SetSessionForGroup (
106+ const std::string& group_name, const std::shared_ptr<OrtSession>& session);
107+
104108 // Session options used when creating a ORT session.
105109 std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;
106110
@@ -110,6 +114,17 @@ class ModelState : public BackendModel {
110114 // is specified both in the output section and state section, it indicates
111115 // that the backend must return the output state to the client too.
112116 std::map<std::string, std::pair<int64_t , int64_t >> model_outputs_;
117+
118+ // Indicate if an onnxrt session should be shared or not. This is a model
119+ // global and applies to all instances. So, storing it in the model state
120+ bool share_session_;
121+
122+ // maintain a map of group id to onnx_rt session. This is only useful if
123+ // share_session is set to true in parameters. share_session is a global model
124+ // config and the user should be careful when setting this. There is no way to
125+ // set this per instance group.
126+ std::unordered_map<std::string, std::shared_ptr<OrtSession>>
127+ groupInstanceSessionMap_;
113128};
114129
115130TRITONSERVER_Error*
@@ -188,7 +203,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
188203}
189204
190205ModelState::ModelState (TRITONBACKEND_Model* triton_model)
191- : BackendModel(triton_model)
206+ : BackendModel(triton_model), share_session_( false )
192207{
193208 // Create session options that will be cloned and used for each
194209 // instance when creating that instance's session.
@@ -338,20 +353,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
338353 }
339354 }
340355 }
341-
342- // FIXME. Is it possible to share a single OrtSession across
343- // multiple instances? If so then should move loading and validation
344- // of the session to here instead of creating a session for each
345- // instance in ModelStateInstance::Create().
356+
357+ // This setting will apply across multiple instance groups.
358+ // If this value is set all instances within an instance group will share
359+ // the ort session
360+ {
361+ bool share_session;
362+ triton::common::TritonJson::Value params;
363+ if (ModelConfig ().Find (" parameters" , ¶ms)) {
364+ THROW_IF_BACKEND_MODEL_ERROR (TryParseModelStringParameter (
365+ params, " share_session" , &share_session, false ));
366+ }
367+ share_session_ = share_session;
368+ }
346369}
347370
348371TRITONSERVER_Error*
349372ModelState::LoadModel (
350- const std::string& artifact_name,
373+ const std::string& artifact_name, const std::string& instance_name,
351374 const TRITONSERVER_InstanceGroupKind instance_group_kind,
352375 const int32_t instance_group_device_id, std::string* model_path,
353- OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
376+ std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
377+ cudaStream_t stream)
354378{
379+ // Get the group name for the instance
380+ std::string instance_group_name (GetInstanceGroupName (Name (), instance_name));
355381 // Find the ONNX file that describes the model itself. If the model
356382 // configuration doesn't have an explicit model file specified then
357383 // use the default name ("model.onnx").
@@ -363,6 +389,10 @@ ModelState::LoadModel(
363389 *model_path = JoinPath (
364390 {RepositoryPath (), std::to_string (Version ()), cc_model_filename});
365391
392+ // get default cpu allocator
393+ RETURN_IF_ORT_ERROR (
394+ ort_api->GetAllocatorWithDefaultOptions (default_allocator));
395+
366396 // If the model path is a directory then the actual model is
367397 // <dir>/model.onnx.
368398 {
@@ -373,6 +403,20 @@ ModelState::LoadModel(
373403 }
374404 }
375405
406+ // Check is we are sharing the session. If so get the session pointer and
407+ // return
408+ if (share_session_) {
409+ if (GetSessionForGroup (instance_group_name, session) == nullptr ) {
410+ LOG_MESSAGE (
411+ TRITONSERVER_LOG_INFO,
412+ (std::string (" Reusing session for group: " ) + instance_group_name)
413+ .c_str ());
414+ // Return the session
415+ return nullptr ;
416+ }
417+ // In case of error carry on with the code
418+ }
419+
376420 {
377421 bool exists;
378422 RETURN_IF_ERROR (FileExists (*model_path, &exists));
@@ -636,12 +680,22 @@ ModelState::LoadModel(
636680 glock.lock ();
637681 }
638682
639- RETURN_IF_ERROR (OnnxLoader::LoadSession (
640- true /* is_path */ , *model_path, soptions, session));
683+ {
684+ // This will be allocated by OnnxRT here but will be freed when the last
685+ // instance of shared_ptr is released
686+ OrtSession* session_ptr;
687+ RETURN_IF_ERROR (OnnxLoader::LoadSession (
688+ true /* is_path */ , *model_path, soptions, &session_ptr));
641689
642- // get default cpu allocator
643- RETURN_IF_ORT_ERROR (
644- ort_api->GetAllocatorWithDefaultOptions (default_allocator));
690+ session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter ());
691+
692+ if (share_session_) {
693+ // The session was created fine this is not a critical error
694+ LOG_IF_ERROR (
695+ SetSessionForGroup (instance_group_name, session),
696+ " Failed to map ort session to the group for sharing" );
697+ }
698+ }
645699
646700 return nullptr ; // success
647701}
@@ -685,7 +739,7 @@ ModelState::AutoCompleteConfig()
685739
686740 // Must cleanup 'session'. 'allocator' is default allocator which
687741 // is managed by ONNX Runtime so don't need to free/release
688- std::unique_ptr <OrtSession, SessionDeleter > session;
742+ std::shared_ptr <OrtSession> session;
689743 OrtAllocator* default_allocator;
690744 std::string model_path;
691745 {
@@ -714,12 +768,9 @@ ModelState::AutoCompleteConfig()
714768 }
715769 }
716770#endif // TRITON_ENABLE_GPU
717-
718- OrtSession* sptr = nullptr ;
719771 RETURN_IF_ERROR (LoadModel (
720- artifact_name, kind, 0 , &model_path, &sptr, &default_allocator,
721- nullptr ));
722- session.reset (sptr);
772+ artifact_name, " " , kind, 0 , &model_path,
773+ session, &default_allocator, nullptr ));
723774 }
724775 OnnxTensorInfoMap input_tensor_infos;
725776 RETURN_IF_ERROR (
@@ -881,6 +932,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
881932 return nullptr ; // success
882933}
883934
935+ TRITONSERVER_Error*
936+ ModelState::GetSessionForGroup (
937+ const std::string& group_name, std::shared_ptr<OrtSession>& session)
938+ {
939+ RETURN_ERROR_IF_TRUE (
940+ group_name.empty (), TRITONSERVER_ERROR_INVALID_ARG,
941+ std::string (" Invalid group name" ));
942+ {
943+ std::unordered_map<std::string, std::shared_ptr<OrtSession>>::iterator
944+ sessionEntry;
945+ sessionEntry = groupInstanceSessionMap_.find (group_name);
946+ RETURN_ERROR_IF_TRUE (
947+ (sessionEntry == groupInstanceSessionMap_.end ()),
948+ TRITONSERVER_ERROR_NOT_FOUND, std::string (" No such group" ));
949+
950+ session = sessionEntry->second ;
951+ }
952+ return nullptr ;
953+ }
954+
955+ TRITONSERVER_Error*
956+ ModelState::SetSessionForGroup (
957+ const std::string& group_name, const std::shared_ptr<OrtSession>& session)
958+ {
959+ RETURN_ERROR_IF_TRUE (
960+ group_name.empty (), TRITONSERVER_ERROR_INVALID_ARG,
961+ std::string (" Invalid group name" ));
962+
963+ groupInstanceSessionMap_[group_name] = session;
964+ return nullptr ;
965+ }
966+
884967//
885968// ModelInstanceState
886969//
@@ -967,7 +1050,7 @@ class ModelInstanceState : public BackendModelInstance {
9671050
9681051 // Onnx Runtime variables that are used across runs on this
9691052 // instance.
970- OrtSession* session_;
1053+ std::shared_ptr< OrtSession> session_;
9711054 OrtAllocator* default_allocator_;
9721055 OrtMemoryInfo* cuda_allocator_info_;
9731056 const OrtMemoryInfo* cpu_allocator_info_;
@@ -1013,7 +1096,7 @@ ModelInstanceState::ModelInstanceState(
10131096 io_binding_(nullptr ), output_buffer_(nullptr )
10141097{
10151098 THROW_IF_BACKEND_INSTANCE_ERROR (model_state->LoadModel (
1016- ArtifactFilename (), Kind (), DeviceId (), &model_path_, & session_,
1099+ ArtifactFilename (), Name (), Kind (), DeviceId (), &model_path_, session_,
10171100 &default_allocator_, CudaStream ()));
10181101
10191102 if (Kind () == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1026,7 +1109,7 @@ ModelInstanceState::ModelInstanceState(
10261109 ort_api->AllocatorGetInfo (default_allocator_, &cpu_allocator_info_));
10271110
10281111 THROW_IF_BACKEND_INSTANCE_ORT_ERROR (
1029- ort_api->CreateIoBinding (session_, &io_binding_));
1112+ ort_api->CreateIoBinding (session_. get () , &io_binding_));
10301113
10311114 THROW_IF_BACKEND_INSTANCE_ORT_ERROR (ort_api->CreateRunOptions (&runOptions_));
10321115
@@ -1114,9 +1197,6 @@ ModelInstanceState::~ModelInstanceState()
11141197 ort_api->ReleaseRunOptions (runOptions_);
11151198 ort_api->ReleaseIoBinding (io_binding_);
11161199 ort_api->ReleaseMemoryInfo (cuda_allocator_info_);
1117- if (session_ != nullptr ) {
1118- OnnxLoader::UnloadSession (session_);
1119- }
11201200 // 'default_allocator_' is default allocator which is managed by ONNX
11211201 // Runtime
11221202}
@@ -1176,7 +1256,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
11761256 if (*have_control) {
11771257 OnnxTensorInfoMap input_tensor_infos;
11781258 RETURN_IF_ERROR (
1179- InputInfos (session_, default_allocator_, input_tensor_infos));
1259+ InputInfos (session_. get () , default_allocator_, input_tensor_infos));
11801260 const auto & iit = input_tensor_infos.find (tensor_name);
11811261 if (iit == input_tensor_infos.end ()) {
11821262 return TRITONSERVER_ErrorNew (
@@ -1233,7 +1313,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
12331313 if (*have_control) {
12341314 OnnxTensorInfoMap input_tensor_infos;
12351315 RETURN_IF_ERROR (
1236- InputInfos (session_, default_allocator_, input_tensor_infos));
1316+ InputInfos (session_. get () , default_allocator_, input_tensor_infos));
12371317 const auto & iit = input_tensor_infos.find (tensor_name);
12381318 if (iit == input_tensor_infos.end ()) {
12391319 return TRITONSERVER_ErrorNew (
@@ -1280,10 +1360,11 @@ TRITONSERVER_Error*
12801360ModelInstanceState::ValidateInputs (const size_t expected_input_cnt)
12811361{
12821362 std::set<std::string> input_tensor_names;
1283- RETURN_IF_ERROR (InputNames (session_, input_tensor_names));
1363+ RETURN_IF_ERROR (InputNames (session_. get () , input_tensor_names));
12841364
12851365 OnnxTensorInfoMap input_tensor_infos;
1286- RETURN_IF_ERROR (InputInfos (session_, default_allocator_, input_tensor_infos));
1366+ RETURN_IF_ERROR (
1367+ InputInfos (session_.get (), default_allocator_, input_tensor_infos));
12871368
12881369 if (input_tensor_infos.size () != expected_input_cnt) {
12891370 return TRITONSERVER_ErrorNew (
@@ -1368,10 +1449,10 @@ TRITONSERVER_Error*
13681449ModelInstanceState::ValidateOutputs ()
13691450{
13701451 std::set<std::string> output_tensor_names;
1371- RETURN_IF_ERROR (OutputNames (session_, output_tensor_names));
1452+ RETURN_IF_ERROR (OutputNames (session_. get () , output_tensor_names));
13721453
13731454 RETURN_IF_ERROR (
1374- OutputInfos (session_, default_allocator_, output_tensor_infos_));
1455+ OutputInfos (session_. get () , default_allocator_, output_tensor_infos_));
13751456
13761457 triton::common::TritonJson::Value ios;
13771458 RETURN_IF_ERROR (model_state_->ModelConfig ().MemberAsArray (" output" , &ios));
@@ -1765,7 +1846,7 @@ ModelInstanceState::OrtRun(
17651846 const uint32_t response_count)
17661847{
17671848 RETURN_IF_ORT_ERROR (
1768- ort_api->RunWithBinding (session_, runOptions_, io_binding_));
1849+ ort_api->RunWithBinding (session_. get () , runOptions_, io_binding_));
17691850 return nullptr ;
17701851}
17711852
@@ -2267,7 +2348,6 @@ ModelInstanceState::ReadOutputTensors(
22672348 }
22682349 }
22692350
2270-
22712351 } else {
22722352 char * output_buffer = nullptr ;
22732353 RETURN_IF_ORT_ERROR (
0 commit comments