
Commit 777856b

Bug fixes and rewrite of opts.py (#592)
1 parent 6db4915 commit 777856b

File tree

3 files changed: +136 additions, -340 deletions

utils_cv/tracking/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ def __init__(
         self.keypoints = None
         self.mask_paths = None

-        # Init FairMOT opt object
+        # Init FairMOT opt object with all parameter settings
         opt = opts()

         # Read annotations
@@ -64,7 +64,7 @@ def __init__(
         # Create FairMOT dataset object
         transforms = T.Compose([T.ToTensor()])
         self.train_data = JointDataset(
-            opt.opt,
+            opt,
             self.root,
             {name: self.fairmot_imlist_path},
             (opt.input_w, opt.input_h),
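
The opt.opt to opt change reflects the opts.py rewrite: the opts() instance is now a flat parameter container, so it is passed to JointDataset directly rather than through a nested .opt namespace. A minimal sketch of the new access pattern (attribute names are taken from the diff above; the override is illustrative only):

opt = opts()                      # rewritten opts: flat attribute container
print(opt.input_w, opt.input_h)   # read directly; no opt.opt indirection
opt.conf_thres = 0.6              # settings can be overridden as plain attributes
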

utils_cv/tracking/model.py

Lines changed: 28 additions & 38 deletions
@@ -125,7 +125,7 @@ def __init__(
         dataset: Optional[TrackingDataset] = None,
         model_path: Optional[str] = None,
         arch: str = "dla_34",
-        head_conv: int = None,
+        head_conv: int = -1,
     ) -> None:
         """
         Initialize learner object.
@@ -142,10 +142,9 @@ def __init__(
         """
         self.opt = opts()
         self.opt.arch = arch
-        self.opt.head_conv = head_conv if head_conv else -1
-        self.opt.gpus = _get_gpu_str()
+        self.opt.set_head_conv(head_conv)
+        self.opt.set_gpus(_get_gpu_str())
         self.opt.device = torch_device()
-
         self.dataset = dataset
         self.model = None
         self._init_model(model_path)
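
The head_conv and GPU defaulting logic moves from inline expressions into setter methods on the rewritten opts class. Their bodies are not visible in this view; the sketch below is only a guess at what they might do, based on the old inline code and FairMOT's usual defaults (256-channel heads for DLA backbones, a comma-separated GPU string parsed into a list):

# Hypothetical sketch; the real setters live in the rewritten opts.py and may differ.
class OptsSketch:
    def __init__(self, arch: str = "dla_34"):
        self.arch = arch
        self.head_conv = -1
        self.gpus_str = "0"
        self.gpus = [0]

    def set_head_conv(self, value: int) -> None:
        # keep an explicit value, else fall back to an arch-dependent default
        self.head_conv = value if value > 0 else (256 if "dla" in self.arch else 64)

    def set_gpus(self, gpus_str: str) -> None:
        # e.g. "0,1" -> [0, 1]; "-1" means run on CPU
        self.gpus_str = gpus_str
        gpus = [int(g) for g in gpus_str.split(",")]
        self.gpus = gpus if gpus[0] >= 0 else [-1]
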
@@ -183,40 +182,42 @@ def fit(
         """
         if not self.dataset:
             raise Exception("No dataset provided")
-        lr_step = str(lr_step)
+        if type(lr_step) is not list:
+            lr_step = [lr_step]
+        lr_step = [int(x) for x in lr_step]

-        opt_fit = deepcopy(self.opt)  # copy opt to avoid bug
-        opt_fit.lr = lr
-        opt_fit.lr_step = lr_step
-        opt_fit.num_epochs = num_epochs
+        # update parameters
+        self.opt.lr = lr
+        self.opt.lr_step = lr_step
+        self.opt.num_epochs = num_epochs
+        opt = deepcopy(self.opt)  # to avoid fairMOT over-writing opt

         # update dataset options
-        opt_fit.update_dataset_info_and_set_heads(self.dataset.train_data)
+        opt.update_dataset_info_and_set_heads(self.dataset.train_data)

         # initialize dataloader
         train_loader = self.dataset.train_dl
-
         self.model = create_model(
-            self.opt.arch, self.opt.heads, self.opt.head_conv
+            opt.arch, opt.heads, opt.head_conv
         )
-        self.model = load_model(self.model, opt_fit.load_model)
-        self.optimizer = torch.optim.Adam(self.model.parameters(), opt_fit.lr)
+        self.model = load_model(self.model, opt.load_model)
+        self.optimizer = torch.optim.Adam(self.model.parameters(), opt.lr)
         start_epoch = 0

-        Trainer = train_factory[opt_fit.task]
-        trainer = Trainer(opt_fit.opt, self.model, self.optimizer)
-        trainer.set_device(opt_fit.gpus, opt_fit.chunk_sizes, opt_fit.device)
+        Trainer = train_factory[opt.task]
+        trainer = Trainer(opt, self.model, self.optimizer)
+        trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)

         # initialize loss vars
         self.losses_dict = defaultdict(list)

         # training loop
         for epoch in range(
-            start_epoch + 1, start_epoch + opt_fit.num_epochs + 1
+            start_epoch + 1, start_epoch + opt.num_epochs + 1
         ):
             print(
                 "=" * 5,
-                f" Epoch: {epoch}/{start_epoch + opt_fit.num_epochs} ",
+                f" Epoch: {epoch}/{start_epoch + opt.num_epochs} ",
                 "=" * 5,
             )
             self.epoch = epoch
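
The old code converted lr_step to a string; the new code normalizes it to a list of ints, so a single epoch or a list of epochs both work and the later membership check against lr_step behaves as intended. A stand-alone restatement of that normalization:

def normalize_lr_step(lr_step):
    # mirrors the logic above: accept a single value or a list
    if type(lr_step) is not list:
        lr_step = [lr_step]
    return [int(x) for x in lr_step]

assert normalize_lr_step(20) == [20]
assert normalize_lr_step([20, "27"]) == [20, 27]
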
@@ -226,8 +227,8 @@ def fit(
                     print(f"{k}:{v} min")
                 else:
                     print(f"{k}: {v}")
-            if epoch in opt_fit.lr_step:
-                lr = opt_fit.lr * (0.1 ** (opt_fit.lr_step.index(epoch) + 1))
+            if epoch in opt.lr_step:
+                lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
                 for param_group in self.optimizer.param_groups:
                     param_group["lr"] = lr
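
The decay rule itself is unchanged apart from the opt rename: when the current epoch appears in lr_step, the learning rate is scaled by 0.1 raised to (position in the list + 1). A small worked example:

base_lr, lr_step = 1e-4, [20, 27]
for epoch in (20, 27):
    lr = base_lr * (0.1 ** (lr_step.index(epoch) + 1))
    print(epoch, lr)   # epoch 20 -> 1e-05, epoch 27 -> 1e-06
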

@@ -369,8 +370,6 @@ def predict(
         self,
         im_or_video_path: str,
         conf_thres: float = 0.6,
-        det_thres: float = 0.3,
-        nms_thres: float = 0.4,
         track_buffer: int = 30,
         min_box_area: float = 200,
         frame_rate: int = 30,
@@ -382,8 +381,6 @@ def predict(
             im_or_video_path: path to image(s) or video. Supports jpg, jpeg, png, tif formats for images.
                 Supports mp4, avi formats for video.
             conf_thres: confidence thresh for tracking
-            det_thres: confidence thresh for detection
-            nms_thres: iou thresh for nms
             track_buffer: tracking buffer
             min_box_area: filter out tiny boxes
             frame_rate: frame rate
@@ -392,20 +389,13 @@ def predict(

         Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
         """
-        opt_pred = deepcopy(self.opt)  # copy opt to avoid bug
-        opt_pred.conf_thres = conf_thres
-        opt_pred.det_thres = det_thres
-        opt_pred.nms_thres = nms_thres
-        opt_pred.track_buffer = track_buffer
-        opt_pred.min_box_area = min_box_area
+        self.opt.conf_thres = conf_thres
+        self.opt.track_buffer = track_buffer
+        self.opt.min_box_area = min_box_area
+        opt = deepcopy(self.opt)  # to avoid fairMOT over-writing opt

         # initialize tracker
-        if self.model:
-            tracker = JDETracker(
-                opt_pred.opt, frame_rate=frame_rate, model=self.model
-            )
-        else:
-            tracker = JDETracker(opt_pred.opt, frame_rate=frame_rate)
+        tracker = JDETracker(opt, frame_rate=frame_rate, model=self.model)

         # initialize dataloader
         dataloader = self._get_dataloader(im_or_video_path)
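
After this change predict() no longer takes det_thres or nms_thres, and the tracker is always built with the learner's own model. An illustrative call with the remaining tracking parameters (the learner object and the video path are placeholders, not part of this diff):

out = learner.predict(    # 'learner' is an already-initialized learner from this file
    "./data/demo.mp4",    # placeholder video path
    conf_thres=0.6,
    track_buffer=30,
    min_box_area=200,
    frame_rate=30,
)
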
@@ -422,7 +412,7 @@ def predict(
                 tlbr = t.tlbr
                 tid = t.track_id
                 vertical = tlwh[2] / tlwh[3] > 1.6
-                if tlwh[2] * tlwh[3] > opt_pred.min_box_area and not vertical:
+                if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
                     bb = TrackingBbox(
                         tlbr[0], tlbr[1], tlbr[2], tlbr[3], frame_id, tid
                     )
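
The box filter is untouched apart from the opt rename: a track's box is kept only if its area exceeds min_box_area and its width-to-height ratio does not exceed 1.6 (flagged as "vertical" in the source). The same test as a small stand-alone function:

def keep_box(tlwh, min_box_area=200.0):
    # tlwh = (top-left x, top-left y, width, height), as in the loop above
    vertical = tlwh[2] / tlwh[3] > 1.6
    return tlwh[2] * tlwh[3] > min_box_area and not vertical

assert keep_box((0, 0, 20, 40))       # area 800, ratio 0.5 -> kept
assert not keep_box((0, 0, 10, 10))   # area 100 -> filtered out
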
