2222#include " outputconnectorstrategy.h"
2323#include < thread>
2424#include < algorithm>
25- #include " utils/utils.hpp"
2625
2726// NCNN
2827#include " ncnnlib.h"
@@ -53,10 +52,10 @@ namespace dd
5352 {
5453 this ->_libname = " ncnn" ;
5554 _net = new ncnn::Net ();
56- _net->opt .num_threads = _threads ;
55+ _net->opt .num_threads = 1 ;
5756 _net->opt .blob_allocator = &_blob_pool_allocator;
5857 _net->opt .workspace_allocator = &_workspace_pool_allocator;
59- _net->opt .lightmode = _lightmode ;
58+ _net->opt .lightmode = true ;
6059 }
6160
6261 template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -69,12 +68,9 @@ namespace dd
6968 this ->_libname = " ncnn" ;
7069 _net = tl._net ;
7170 tl._net = nullptr ;
72- _nclasses = tl._nclasses ;
73- _threads = tl._threads ;
7471 _timeserie = tl._timeserie ;
7572 _old_height = tl._old_height ;
76- _inputBlob = tl._inputBlob ;
77- _outputBlob = tl._outputBlob ;
73+ _init_dto = tl._init_dto ;
7874 }
7975
8076 template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -94,6 +90,8 @@ namespace dd
9490 void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
9591 TMLModel>::init_mllib(const APIData &ad)
9692 {
93+ _init_dto = ad.createSharedDTO <NcnnInitDto>();
94+
9795 bool use_fp32 = (ad.has (" datatype" )
9896 && ad.get (" datatype" ).get <std::string>()
9997 == " fp32" ); // default is fp16
@@ -124,35 +122,11 @@ namespace dd
124122 _old_height = this ->_inputc .height ();
125123 _net->set_input_h (_old_height);
126124
127- if (ad.has (" nclasses" ))
128- _nclasses = ad.get (" nclasses" ).get <int >();
129-
130- if (ad.has (" threads" ))
131- _threads = ad.get (" threads" ).get <int >();
132- else
133- _threads = dd_utils::my_hardware_concurrency ();
134-
135125 _timeserie = this ->_inputc ._timeserie ;
136126 if (_timeserie)
137127 this ->_mltype = " timeserie" ;
138128
139- if (ad.has (" lightmode" ))
140- {
141- _lightmode = ad.get (" lightmode" ).get <bool >();
142- _net->opt .lightmode = _lightmode;
143- }
144-
145- // setting the value of Input Layer
146- if (ad.has (" inputblob" ))
147- {
148- _inputBlob = ad.get (" inputblob" ).get <std::string>();
149- }
150- // setting the final Output Layer
151- if (ad.has (" outputblob" ))
152- {
153- _outputBlob = ad.get (" outputblob" ).get <std::string>();
154- }
155-
129+ _net->opt .lightmode = _init_dto->lightmode ;
156130 _blob_pool_allocator.set_size_compare_ratio (0 .0f );
157131 _workspace_pool_allocator.set_size_compare_ratio (0 .5f );
158132 model_type (this ->_mlmodel ._params , this ->_mltype );
@@ -233,7 +207,10 @@ namespace dd
233207
234208 // Extract detection or classification
235209 int ret = 0 ;
236- std::string out_blob = _outputBlob;
210+ std::string out_blob;
211+ if (_init_dto->outputBlob != nullptr )
212+ out_blob = _init_dto->outputBlob ->std_str ();
213+
237214 if (out_blob.empty ())
238215 {
239216 if (bbox == true )
@@ -262,11 +239,11 @@ namespace dd
262239 {
263240 best = ad_output.get (" best" ).get <int >();
264241 }
265- if (best == -1 || best > _nclasses )
266- best = _nclasses ;
242+ if (best == -1 || best > _init_dto-> nclasses )
243+ best = _init_dto-> nclasses ;
267244
268245 // for loop around batch size
269- #pragma omp parallel for num_threads(_threads )
246+ #pragma omp parallel for num_threads(*_init_dto->threads )
270247 for (size_t b = 0 ; b < inputc._ids .size (); b++)
271248 {
272249 std::vector<double > probs;
@@ -276,8 +253,8 @@ namespace dd
276253 APIData rad;
277254
278255 ncnn::Extractor ex = _net->create_extractor ();
279- ex.set_num_threads (_threads );
280- ex.input (_inputBlob. c_str (), inputc._in .at (b));
256+ ex.set_num_threads (_init_dto-> threads );
257+ ex.input (_init_dto-> inputBlob -> c_str (), inputc._in .at (b));
281258
282259 ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
283260 if (ret == -1 )
@@ -423,7 +400,8 @@ namespace dd
423400 } // end for batch_size
424401
425402 tout.add_results (vrad);
426- out.add (" nclasses" , this ->_nclasses );
403+ int nclasses = this ->_init_dto ->nclasses ;
404+ out.add (" nclasses" , nclasses);
427405 if (bbox == true )
428406 out.add (" bbox" , true );
429407 out.add (" roi" , false );
0 commit comments