@@ -125,7 +125,7 @@ def __init__(
125125 dataset : Optional [TrackingDataset ] = None ,
126126 model_path : Optional [str ] = None ,
127127 arch : str = "dla_34" ,
128- head_conv : int = None ,
128+ head_conv : int = - 1 ,
129129 ) -> None :
130130 """
131131 Initialize learner object.
@@ -142,10 +142,9 @@ def __init__(
142142 """
143143 self .opt = opts ()
144144 self .opt .arch = arch
145- self .opt .head_conv = head_conv if head_conv else - 1
146- self .opt .gpus = _get_gpu_str ()
145+ self .opt .set_head_conv ( head_conv )
146+ self .opt .set_gpus ( _get_gpu_str () )
147147 self .opt .device = torch_device ()
148-
149148 self .dataset = dataset
150149 self .model = None
151150 self ._init_model (model_path )
@@ -183,40 +182,42 @@ def fit(
183182 """
184183 if not self .dataset :
185184 raise Exception ("No dataset provided" )
186- lr_step = str (lr_step )
185+ if type (lr_step ) is not list :
186+ lr_step = [lr_step ]
187+ lr_step = [int (x ) for x in lr_step ]
187188
188- opt_fit = deepcopy (self .opt ) # copy opt to avoid bug
189- opt_fit .lr = lr
190- opt_fit .lr_step = lr_step
191- opt_fit .num_epochs = num_epochs
189+ # update parameters
190+ self .opt .lr = lr
191+ self .opt .lr_step = lr_step
192+ self .opt .num_epochs = num_epochs
193+ opt = deepcopy (self .opt ) #to avoid fairMOT over-writing opt
192194
193195 # update dataset options
194- opt_fit .update_dataset_info_and_set_heads (self .dataset .train_data )
196+ opt .update_dataset_info_and_set_heads (self .dataset .train_data )
195197
196198 # initialize dataloader
197199 train_loader = self .dataset .train_dl
198-
199200 self .model = create_model (
200- self . opt .arch , self . opt .heads , self . opt .head_conv
201+ opt .arch , opt .heads , opt .head_conv
201202 )
202- self .model = load_model (self .model , opt_fit .load_model )
203- self .optimizer = torch .optim .Adam (self .model .parameters (), opt_fit .lr )
203+ self .model = load_model (self .model , opt .load_model )
204+ self .optimizer = torch .optim .Adam (self .model .parameters (), opt .lr )
204205 start_epoch = 0
205206
206- Trainer = train_factory [opt_fit .task ]
207- trainer = Trainer (opt_fit . opt , self .model , self .optimizer )
208- trainer .set_device (opt_fit .gpus , opt_fit .chunk_sizes , opt_fit .device )
207+ Trainer = train_factory [opt .task ]
208+ trainer = Trainer (opt , self .model , self .optimizer )
209+ trainer .set_device (opt .gpus , opt .chunk_sizes , opt .device )
209210
210211 # initialize loss vars
211212 self .losses_dict = defaultdict (list )
212213
213214 # training loop
214215 for epoch in range (
215- start_epoch + 1 , start_epoch + opt_fit .num_epochs + 1
216+ start_epoch + 1 , start_epoch + opt .num_epochs + 1
216217 ):
217218 print (
218219 "=" * 5 ,
219- f" Epoch: { epoch } /{ start_epoch + opt_fit .num_epochs } " ,
220+ f" Epoch: { epoch } /{ start_epoch + opt .num_epochs } " ,
220221 "=" * 5 ,
221222 )
222223 self .epoch = epoch
@@ -226,8 +227,8 @@ def fit(
226227 print (f"{ k } :{ v } min" )
227228 else :
228229 print (f"{ k } : { v } " )
229- if epoch in opt_fit .lr_step :
230- lr = opt_fit .lr * (0.1 ** (opt_fit .lr_step .index (epoch ) + 1 ))
230+ if epoch in opt .lr_step :
231+ lr = opt .lr * (0.1 ** (opt .lr_step .index (epoch ) + 1 ))
231232 for param_group in self .optimizer .param_groups :
232233 param_group ["lr" ] = lr
233234
@@ -369,8 +370,6 @@ def predict(
369370 self ,
370371 im_or_video_path : str ,
371372 conf_thres : float = 0.6 ,
372- det_thres : float = 0.3 ,
373- nms_thres : float = 0.4 ,
374373 track_buffer : int = 30 ,
375374 min_box_area : float = 200 ,
376375 frame_rate : int = 30 ,
@@ -382,8 +381,6 @@ def predict(
382381 im_or_video_path: path to image(s) or video. Supports jpg, jpeg, png, tif formats for images.
383382 Supports mp4, avi formats for video.
384383 conf_thres: confidence thresh for tracking
385- det_thres: confidence thresh for detection
386- nms_thres: iou thresh for nms
387384 track_buffer: tracking buffer
388385 min_box_area: filter out tiny boxes
389386 frame_rate: frame rate
@@ -392,20 +389,13 @@ def predict(
392389
393390 Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
394391 """
395- opt_pred = deepcopy (self .opt ) # copy opt to avoid bug
396- opt_pred .conf_thres = conf_thres
397- opt_pred .det_thres = det_thres
398- opt_pred .nms_thres = nms_thres
399- opt_pred .track_buffer = track_buffer
400- opt_pred .min_box_area = min_box_area
392+ self .opt .conf_thres = conf_thres
393+ self .opt .track_buffer = track_buffer
394+ self .opt .min_box_area = min_box_area
395+ opt = deepcopy (self .opt ) #to avoid fairMOT over-writing opt
401396
402397 # initialize tracker
403- if self .model :
404- tracker = JDETracker (
405- opt_pred .opt , frame_rate = frame_rate , model = self .model
406- )
407- else :
408- tracker = JDETracker (opt_pred .opt , frame_rate = frame_rate )
398+ tracker = JDETracker (opt , frame_rate = frame_rate , model = self .model )
409399
410400 # initialize dataloader
411401 dataloader = self ._get_dataloader (im_or_video_path )
@@ -422,7 +412,7 @@ def predict(
422412 tlbr = t .tlbr
423413 tid = t .track_id
424414 vertical = tlwh [2 ] / tlwh [3 ] > 1.6
425- if tlwh [2 ] * tlwh [3 ] > opt_pred .min_box_area and not vertical :
415+ if tlwh [2 ] * tlwh [3 ] > opt .min_box_area and not vertical :
426416 bb = TrackingBbox (
427417 tlbr [0 ], tlbr [1 ], tlbr [2 ], tlbr [3 ], frame_id , tid
428418 )
0 commit comments