 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
-
 #include <mutex>
 #include <vector>
 
@@ -107,10 +106,10 @@ class ModelState : public BackendModel {
   // onnx file, return in 'session' and 'allocator' the ORT session
   // and allocator.
   TRITONSERVER_Error* LoadModel(
-      const std::string& artifact_name,
+      const std::string& artifact_name, const std::string& instance_name,
       const TRITONSERVER_InstanceGroupKind instance_group_kind,
       const int32_t instance_group_device_id, std::string* model_path,
-      OrtSession** session, OrtAllocator** default_allocator,
+      std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
       cudaStream_t stream);
 
   const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -127,6 +126,11 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* AutoCompleteIO(
       const char* key, const OnnxTensorInfoMap& io_infos);
 
+  TRITONSERVER_Error* GetSessionForGroup(
+      const std::string& group_name, std::shared_ptr<OrtSession>& session);
+  TRITONSERVER_Error* SetSessionForGroup(
+      const std::string& group_name, const std::shared_ptr<OrtSession>& session);
+
   // Session options used when creating a ORT session.
   std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;
 
@@ -136,6 +140,17 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+  // Indicates whether the ORT session should be shared between instances.
+  // This is a model-global setting that applies to all instances, so it is
+  // stored in the model state.
+  bool share_session_between_instances_;
+
+  // Map from instance group name to ORT session. This is only used when
+  // share_session_between_instances is set to true in the model parameters.
+  // Note that this is a model-global setting with no per-instance-group
+  // override, so it should be enabled with care.
+  std::unordered_map<std::string, std::shared_ptr<OrtSession>>
+      groupInstanceSessionMap_;
 };
 
 TRITONSERVER_Error*
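The flag and map above are driven by a model-configuration parameter (parsed in the constructor hunk further below). For reference, a minimal sketch of how a user might enable it in a model's config.pbtxt, assuming Triton's usual string-valued parameter syntax:

parameters: {
  key: "share_session_between_instances"
  value: { string_value: "true" }
}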
@@ -206,7 +221,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 }
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model, true /* allow_optional */)
+    : BackendModel(triton_model, true /* allow_optional */),
+      share_session_between_instances_(false)
 {
   // Create session options that will be cloned and used for each
   // instance when creating that instance's session.
@@ -358,20 +373,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       }
     }
   }
-
-  // FIXME. Is it possible to share a single OrtSession across
-  // multiple instances? If so then should move loading and validation
-  // of the session to here instead of creating a session for each
-  // instance in ModelStateInstance::Create().
+
+  // This setting applies across all instance groups: when it is enabled,
+  // every instance within an instance group shares a single ORT session.
+  {
+    bool share_session_between_instances = false;
+    triton::common::TritonJson::Value params;
+    if (ModelConfig().Find("parameters", &params)) {
+      THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
+          params, "share_session_between_instances",
+          &share_session_between_instances, false));
+    }
+    share_session_between_instances_ = share_session_between_instances;
+  }
 }
 
 TRITONSERVER_Error*
 ModelState::LoadModel(
-    const std::string& artifact_name,
+    const std::string& artifact_name, const std::string& instance_name,
     const TRITONSERVER_InstanceGroupKind instance_group_kind,
     const int32_t instance_group_device_id, std::string* model_path,
-    OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
+    std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
+    cudaStream_t stream)
 {
+  // Get the group name for this instance
+  std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
   // Find the ONNX file that describes the model itself. If the model
   // configuration doesn't have an explicit model file specified then
   // use the default name ("model.onnx").
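The GetInstanceGroupName helper called above is not part of this excerpt. A plausible standalone sketch of its behavior, assuming Triton instance names follow the "<group name>_<instance index>" convention; the body is an assumption, not the backend's actual implementation:

#include <string>

// Hypothetical sketch: recover the group name by stripping the trailing
// "_<index>" suffix from the instance name. An empty instance name (as
// passed from AutoCompleteConfig below) yields an empty group name, which
// GetSessionForGroup rejects, so sharing is skipped on that path.
std::string
GetInstanceGroupName(
    const std::string& model_name, const std::string& instance_name)
{
  if (instance_name.empty()) {
    return std::string();
  }
  const size_t pos = instance_name.rfind('_');
  if (pos == std::string::npos) {
    return model_name;  // no suffix; fall back to the model name
  }
  return instance_name.substr(0, pos);
}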
@@ -383,6 +409,10 @@ ModelState::LoadModel(
   *model_path = JoinPath(
       {RepositoryPath(), std::to_string(Version()), cc_model_filename});
 
+  // Get the default CPU allocator
+  RETURN_IF_ORT_ERROR(
+      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+
   // If the model path is a directory then the actual model is
   // <dir>/model.onnx.
   {
@@ -393,6 +423,20 @@ ModelState::LoadModel(
     }
   }
 
+  // Check if we are sharing the session. If the lookup succeeds, reuse the
+  // cached session and return.
+  if (share_session_between_instances_) {
+    if (GetSessionForGroup(instance_group_name, session) == nullptr) {
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("Reusing session for group: ") + instance_group_name)
+              .c_str());
+      // Return the shared session
+      return nullptr;
+    }
+    // On lookup failure, fall through and create a new session
+  }
+
   {
     bool exists;
     RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -656,12 +700,22 @@ ModelState::LoadModel(
     glock.lock();
   }
 
-  RETURN_IF_ERROR(OnnxLoader::LoadSession(
-      true /* is_path */, *model_path, soptions, session));
+  {
+    // The raw session is allocated by ONNX Runtime here, but it is freed
+    // when the last shared_ptr owner releases it
+    OrtSession* session_ptr;
+    RETURN_IF_ERROR(OnnxLoader::LoadSession(
+        true /* is_path */, *model_path, soptions, &session_ptr));
 
-  // get default cpu allocator
-  RETURN_IF_ORT_ERROR(
-      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+    session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
+
+    if (share_session_between_instances_) {
+      // The session itself was created successfully, so failing to cache it
+      // for sharing is not a critical error
+      LOG_IF_ERROR(
+          SetSessionForGroup(instance_group_name, session),
+          "Failed to map ORT session to its instance group for sharing");
+    }
+  }
 
   return nullptr;  // success
 }
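The SessionDeleter attached above is defined elsewhere in this backend; it is what makes removing the explicit UnloadSession call from the instance destructor (in a later hunk) safe. A sketch of its presumed shape, based on the cleanup code this change deletes:

// Presumed deleter shape (an assumption; the real definition lives in the
// backend's utility header): hand the session back to OnnxLoader once the
// last shared_ptr owner releases it.
struct SessionDeleter {
  void operator()(OrtSession* session)
  {
    if (session != nullptr) {
      LOG_IF_ERROR(
          OnnxLoader::UnloadSession(session), "failed to unload ORT session");
    }
  }
};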
@@ -705,7 +759,7 @@ ModelState::AutoCompleteConfig()
 
   // Must cleanup 'session'. 'allocator' is default allocator which
   // is managed by ONNX Runtime so don't need to free/release
-  std::unique_ptr<OrtSession, SessionDeleter> session;
+  std::shared_ptr<OrtSession> session;
   OrtAllocator* default_allocator;
   std::string model_path;
   {
@@ -734,12 +788,9 @@ ModelState::AutoCompleteConfig()
       }
     }
 #endif  // TRITON_ENABLE_GPU
-
-    OrtSession* sptr = nullptr;
     RETURN_IF_ERROR(LoadModel(
-        artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
-        nullptr));
-    session.reset(sptr);
+        artifact_name, "", kind, 0, &model_path, session, &default_allocator,
+        nullptr));
   }
   OnnxTensorInfoMap input_tensor_infos;
   RETURN_IF_ERROR(
@@ -906,6 +957,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
   return nullptr;  // success
 }
 
+TRITONSERVER_Error*
+ModelState::GetSessionForGroup(
+    const std::string& group_name, std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name: ") + group_name);
+  {
+    std::unordered_map<std::string, std::shared_ptr<OrtSession>>::iterator
+        sessionEntry;
+    sessionEntry = groupInstanceSessionMap_.find(group_name);
+    RETURN_ERROR_IF_TRUE(
+        (sessionEntry == groupInstanceSessionMap_.end()),
+        TRITONSERVER_ERROR_NOT_FOUND,
+        std::string("No such group: ") + group_name);
+
+    session = sessionEntry->second;
+  }
+  return nullptr;
+}
+
+TRITONSERVER_Error*
+ModelState::SetSessionForGroup(
+    const std::string& group_name, const std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name: ") + group_name);
+
+  groupInstanceSessionMap_[group_name] = session;
+  return nullptr;
+}
+
 //
 // ModelInstanceState
 //
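The two helpers above implement a simple name-keyed session cache. A minimal standalone sketch of the pattern (plain C++ with a stand-in Session type instead of OrtSession), showing how shared_ptr ownership keeps a cached session alive across instances:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Session {};  // stand-in for OrtSession

int main()
{
  std::unordered_map<std::string, std::shared_ptr<Session>> cache;

  // First instance in a group: the lookup misses, so create and publish.
  if (cache.find("group_0") == cache.end()) {
    cache["group_0"] = std::make_shared<Session>();
  }

  // A later instance in the same group reuses the cached session; the cache
  // and the instance now share ownership, so the session outlives either one.
  std::shared_ptr<Session> reused = cache.at("group_0");
  std::cout << "use_count: " << reused.use_count() << "\n";  // prints 2
  return 0;
}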
@@ -992,7 +1075,7 @@ class ModelInstanceState : public BackendModelInstance {
 
   // Onnx Runtime variables that are used across runs on this
   // instance.
-  OrtSession* session_;
+  std::shared_ptr<OrtSession> session_;
   OrtAllocator* default_allocator_;
   OrtMemoryInfo* cuda_allocator_info_;
   const OrtMemoryInfo* cpu_allocator_info_;
@@ -1044,7 +1127,7 @@ ModelInstanceState::ModelInstanceState(
       io_binding_(nullptr), output_buffer_(nullptr)
 {
   THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
+      ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
       &default_allocator_, CudaStream()));
 
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1057,7 +1140,7 @@ ModelInstanceState::ModelInstanceState(
       ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));
 
   THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
-      ort_api->CreateIoBinding(session_, &io_binding_));
+      ort_api->CreateIoBinding(session_.get(), &io_binding_));
 
   THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));
 
@@ -1156,9 +1239,6 @@ ModelInstanceState::~ModelInstanceState()
   ort_api->ReleaseRunOptions(runOptions_);
   ort_api->ReleaseIoBinding(io_binding_);
   ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
-  if (session_ != nullptr) {
-    OnnxLoader::UnloadSession(session_);
-  }
   // 'default_allocator_' is default allocator which is managed by ONNX
   // Runtime
 }
@@ -1220,7 +1300,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
   if (*have_control) {
     OnnxTensorInfoMap input_tensor_infos;
     RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
     const auto& iit = input_tensor_infos.find(tensor_name);
     if (iit == input_tensor_infos.end()) {
       return TRITONSERVER_ErrorNew(
@@ -1277,7 +1357,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
   if (*have_control) {
     OnnxTensorInfoMap input_tensor_infos;
     RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
     const auto& iit = input_tensor_infos.find(tensor_name);
     if (iit == input_tensor_infos.end()) {
       return TRITONSERVER_ErrorNew(
@@ -1324,17 +1404,17 @@ TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
   std::set<std::string> input_tensor_names;
-  RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
+  RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));
   RETURN_IF_ERROR(
-      InputInfos(session_, default_allocator_, input_tensor_infos_));
+      InputInfos(session_.get(), default_allocator_, input_tensor_infos_));
 
   std::set<std::string> overridable_initializer_tensor_names;
   RETURN_IF_ERROR(OverridableInitializerNames(
-      session_, overridable_initializer_tensor_names));
+      session_.get(), overridable_initializer_tensor_names));
 
   OnnxTensorInfoMap overridable_initializer_tensor_infos;
   RETURN_IF_ERROR(OverridableInitializerInfos(
-      session_, default_allocator_, overridable_initializer_tensor_infos));
+      session_.get(), default_allocator_,
+      overridable_initializer_tensor_infos));
 
   if (input_tensor_infos_.size() != expected_input_cnt) {
     return TRITONSERVER_ErrorNew(
@@ -1471,10 +1551,10 @@ TRITONSERVER_Error*
 ModelInstanceState::ValidateOutputs()
 {
   std::set<std::string> output_tensor_names;
-  RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
+  RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));
 
   RETURN_IF_ERROR(
-      OutputInfos(session_, default_allocator_, output_tensor_infos_));
+      OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));
 
   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1871,7 +1951,7 @@ ModelInstanceState::OrtRun(
     const uint32_t response_count)
 {
   RETURN_IF_ORT_ERROR(
-      ort_api->RunWithBinding(session_, runOptions_, io_binding_));
+      ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
   return nullptr;
 }
 
@@ -2411,7 +2491,6 @@ ModelInstanceState::ReadOutputTensors(
       }
     }
 
-
   } else {
     char* output_buffer = nullptr;
     RETURN_IF_ORT_ERROR(