Skip to content

Commit b838811

Browse files
authored
First fixes to 01 notebook (#590)
1 parent b6ec66a commit b838811

File tree

4 files changed

+147
-300
lines changed

4 files changed

+147
-300
lines changed

scenarios/tracking/01_training_introduction.ipynb

Lines changed: 115 additions & 267 deletions
Large diffs are not rendered by default.

utils_cv/tracking/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def _write_fairMOT_format(self) -> None:
172172
self.fairmot_imlist_path = osp.join(
173173
self.root, "{}.train".format(self.name)
174174
)
175-
with open(self.fairmot_imlist_path, "a") as f:
175+
with open(self.fairmot_imlist_path, "w") as f:
176176
for im_filename in sorted(self.im_filenames):
177177
f.write(osp.join(self.im_dir, im_filename) + "\n")
178178

utils_cv/tracking/model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _get_gpu_str():
4343

4444
def _get_frame(input_video: str, frame_id: int):
4545
video = cv2.VideoCapture()
46-
video.open(input_video)
46+
video.open(input_video)
4747
video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
4848
_, im = video.read()
4949
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
@@ -178,7 +178,7 @@ def fit(
178178
179179
Raise:
180180
Exception if dataset is undefined
181-
181+
182182
Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/train.py
183183
"""
184184
if not self.dataset:
@@ -227,7 +227,7 @@ def fit(
227227
print(f"{k}: {v}")
228228
if epoch in opt_fit.lr_step:
229229
lr = opt_fit.lr * (0.1 ** (opt_fit.lr_step.index(epoch) + 1))
230-
for param_group in optimizer.param_groups:
230+
for param_group in self.optimizer.param_groups:
231231
param_group["lr"] = lr
232232

233233
# store losses in each epoch
@@ -237,11 +237,11 @@ def fit(
237237

238238
def plot_training_losses(self, figsize: Tuple[int, int] = (10, 5)) -> None:
239239
"""
240-
Plot training loss.
241-
240+
Plot training loss.
241+
242242
Args:
243243
figsize (optional): width and height wanted for figure of training-loss plot
244-
244+
245245
"""
246246
fig = plt.figure(figsize=figsize)
247247
ax1 = fig.add_subplot(1, 1, 1)
@@ -274,15 +274,15 @@ def evaluate(
274274
self, results: Dict[int, List[TrackingBbox]], gt_root_path: str
275275
) -> str:
276276

277-
"""
277+
"""
278278
Evaluate performance wrt MOTA, MOTP, track quality measures, global ID measures, and more,
279279
as computed by py-motmetrics on a single experiment. By default, use 'single_vid' as exp_name.
280280
281281
Args:
282-
results: prediction results from predict() function, i.e. Dict[int, List[TrackingBbox]]
282+
results: prediction results from predict() function, i.e. Dict[int, List[TrackingBbox]]
283283
gt_root_path: path of dataset containing GT annotations in MOTchallenge format (xywh)
284284
Returns:
285-
strsummary: str output by method in 'motmetrics' package, containing metrics scores
285+
strsummary: str output by method in 'motmetrics' package, containing metrics scores
286286
"""
287287

288288
# Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
@@ -371,7 +371,7 @@ def predict(
371371
372372
Args:
373373
im_or_video_path: path to image(s) or video. Supports jpg, jpeg, png, tif formats for images.
374-
Supports mp4, avi formats for video.
374+
Supports mp4, avi formats for video.
375375
conf_thres: confidence thresh for tracking
376376
det_thres: confidence thresh for detection
377377
nms_thres: iou thresh for nms

utils_cv/tracking/plot.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,24 @@
1717

1818

1919
def plot_single_frame(
20-
results: Dict[int, List[TrackingBbox]], input_video: str, frame_id: int
20+
input_video: str,
21+
frame_id: int,
22+
results: Dict[int, List[TrackingBbox]] = None
2123
) -> None:
22-
"""
23-
Plot the bounding box and id on a wanted frame. Display as image to front end.
24+
"""
25+
Plot the bounding box and id on a wanted frame. Display as image to front end.
2426
2527
Args:
26-
results: dictionary mapping frame id to a list of predicted TrackingBboxes
2728
input_video: path to the input video
2829
frame_id: frame_id for frame to show tracking result
30+
results: dictionary mapping frame id to a list of predicted TrackingBboxes
2931
"""
3032

31-
if results is None: # if no tracking bboxes, only plot image
32-
# Get frame from video
33-
im = Image.fromarray(_get_frame(input_video, frame_id))
34-
# Display image
35-
IPython.display.display(im)
33+
# Extract frame
34+
im = _get_frame(input_video, frame_id)
3635

37-
else:
36+
# Overlay results
37+
if results:
3838
results = OrderedDict(sorted(results.items()))
3939

4040
# Assign bbox color per id
@@ -43,27 +43,26 @@ def plot_single_frame(
4343
)
4444
color_map = assign_colors(unique_ids)
4545

46-
# Get frame from video
47-
im = _get_frame(input_video, frame_id)
48-
4946
# Extract tracking results for wanted frame, and draw bboxes+tracking id, display frame
5047
cur_tracks = results[frame_id]
5148

5249
if len(cur_tracks) > 0:
5350
im = draw_boxes(im, cur_tracks, color_map)
54-
im = Image.fromarray(im)
55-
IPython.display.display(im)
51+
52+
# Display image
53+
im = Image.fromarray(im)
54+
IPython.display.display(im)
5655

5756

5857
def play_video(
5958
results: Dict[int, List[TrackingBbox]], input_video: str
6059
) -> None:
61-
"""
60+
"""
6261
Plot the predicted tracks on the input video. Displays to front-end as sequence of images stringed together in a video.
6362
6463
Args:
6564
results: dictionary mapping frame id to a list of predicted TrackingBboxes
66-
input_video: path to the input video
65+
input_video: path to the input video
6766
"""
6867

6968
results = OrderedDict(sorted(results.items()))
@@ -98,7 +97,7 @@ def play_video(
9897
def write_video(
9998
results: Dict[int, List[TrackingBbox]], input_video: str, output_video: str
10099
) -> None:
101-
"""
100+
"""
102101
Plot the predicted tracks on the input video. Write the output to {output_path}.
103102
104103
Args:
@@ -143,7 +142,7 @@ def draw_boxes(
143142
cur_tracks: List[TrackingBbox],
144143
color_map: Dict[int, Tuple[int, int, int]],
145144
) -> np.ndarray:
146-
"""
145+
"""
147146
Overlay bbox and id labels onto the frame
148147
149148
Args:
@@ -181,11 +180,11 @@ def draw_boxes(
181180

182181

183182
def assign_colors(id_list: List[int],) -> Dict[int, Tuple[int, int, int]]:
184-
"""
183+
"""
185184
Produce corresponding unique color palettes for unique ids
186-
185+
187186
Args:
188-
id_list: list of track ids
187+
id_list: list of track ids
189188
"""
190189
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
191190

0 commit comments

Comments (0)