2222#include " outputconnectorstrategy.h"
2323#include < thread>
2424#include < algorithm>
25- #include " utils/utils.hpp"
2625
2726// NCNN
2827#include " ncnnlib.h"
@@ -53,10 +52,10 @@ namespace dd
5352 {
5453 this ->_libname = " ncnn" ;
5554 _net = new ncnn::Net ();
56- _net->opt .num_threads = _threads ;
55+ _net->opt .num_threads = 1 ;
5756 _net->opt .blob_allocator = &_blob_pool_allocator;
5857 _net->opt .workspace_allocator = &_workspace_pool_allocator;
59- _net->opt .lightmode = _lightmode ;
58+ _net->opt .lightmode = true ;
6059 }
6160
6261 template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -69,12 +68,9 @@ namespace dd
6968 this ->_libname = " ncnn" ;
7069 _net = tl._net ;
7170 tl._net = nullptr ;
72- _nclasses = tl._nclasses ;
73- _threads = tl._threads ;
7471 _timeserie = tl._timeserie ;
7572 _old_height = tl._old_height ;
76- _inputBlob = tl._inputBlob ;
77- _outputBlob = tl._outputBlob ;
73+ _init_dto = tl._init_dto ;
7874 }
7975
8076 template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -94,6 +90,8 @@ namespace dd
9490 void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
9591 TMLModel>::init_mllib(const APIData &ad)
9692 {
93+ _init_dto = ad.createSharedDTO <NcnnInitDto>();
94+
9795 bool use_fp32 = (ad.has (" datatype" )
9896 && ad.get (" datatype" ).get <std::string>()
9997 == " fp32" ); // default is fp16
@@ -124,35 +122,11 @@ namespace dd
124122 _old_height = this ->_inputc .height ();
125123 _net->set_input_h (_old_height);
126124
127- if (ad.has (" nclasses" ))
128- _nclasses = ad.get (" nclasses" ).get <int >();
129-
130- if (ad.has (" threads" ))
131- _threads = ad.get (" threads" ).get <int >();
132- else
133- _threads = dd_utils::my_hardware_concurrency ();
134-
135125 _timeserie = this ->_inputc ._timeserie ;
136126 if (_timeserie)
137127 this ->_mltype = " timeserie" ;
138128
139- if (ad.has (" lightmode" ))
140- {
141- _lightmode = ad.get (" lightmode" ).get <bool >();
142- _net->opt .lightmode = _lightmode;
143- }
144-
145- // setting the value of Input Layer
146- if (ad.has (" inputblob" ))
147- {
148- _inputBlob = ad.get (" inputblob" ).get <std::string>();
149- }
150- // setting the final Output Layer
151- if (ad.has (" outputblob" ))
152- {
153- _outputBlob = ad.get (" outputblob" ).get <std::string>();
154- }
155-
129+ _net->opt .lightmode = _init_dto->lightmode ;
156130 _blob_pool_allocator.set_size_compare_ratio (0 .0f );
157131 _workspace_pool_allocator.set_size_compare_ratio (0 .5f );
158132 model_type (this ->_mlmodel ._params , this ->_mltype );
@@ -232,8 +206,7 @@ namespace dd
232206 }
233207
234208 // Extract detection or classification
235- int ret = 0 ;
236- std::string out_blob = _outputBlob;
209+ std::string out_blob = _init_dto->outputBlob ->std_str ();
237210 if (out_blob.empty ())
238211 {
239212 if (bbox == true )
@@ -245,6 +218,14 @@ namespace dd
245218 else
246219 out_blob = " prob" ;
247220 }
221+ <<<<<<< HEAD
222+ =======
223+ int ret = ex.extract (out_blob.c_str (), inputc._out );
224+ if (ret == -1 )
225+ {
226+ throw MLLibInternalException (" NCNN internal error" );
227+ }
228+ >>>>>>> 55bb8639 (feat: use DTO for NCNN init parameters)
248229
249230 std::vector<APIData> vrad;
250231
@@ -262,8 +243,8 @@ namespace dd
262243 {
263244 best = ad_output.get (" best" ).get <int >();
264245 }
265- if (best == -1 || best > _nclasses )
266- best = _nclasses ;
246+ if (best == -1 || best > _init_dto-> nclasses )
247+ best = _init_dto-> nclasses ;
267248
268249 // for loop around batch size
269250#pragma omp parallel for num_threads(_threads)
@@ -276,8 +257,8 @@ namespace dd
276257 APIData rad;
277258
278259 ncnn::Extractor ex = _net->create_extractor ();
279- ex.set_num_threads (_threads );
280- ex.input (_inputBlob. c_str (), inputc._in .at (b));
260+ ex.set_num_threads (_init_dto-> threads );
261+ ex.input (_init_dto-> inputBlob -> c_str (), inputc._in .at (b));
281262
282263 ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
283264 if (ret == -1 )
@@ -423,7 +404,8 @@ namespace dd
423404 } // end for batch_size
424405
425406 tout.add_results (vrad);
426- out.add (" nclasses" , this ->_nclasses );
407+ int nclasses = this ->_init_dto ->nclasses ;
408+ out.add (" nclasses" , nclasses);
427409 if (bbox == true )
428410 out.add (" bbox" , true );
429411 out.add (" roi" , false );
0 commit comments