|
1 | 1 | import multiprocessing |
2 | 2 | import os |
3 | 3 | import random |
4 | | -from typing import Any, List, Tuple, Union |
| 4 | +from typing import Any, List, Optional, Tuple, Union |
5 | 5 |
|
6 | 6 | import aim |
7 | 7 | import albumentations as A |
|
19 | 19 | from torchvision.datasets import VOCDetection |
20 | 20 | from torchvision.models import detection |
21 | 21 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor |
| 22 | +from torchvision.models.detection.retinanet import RetinaNetClassificationHead |
22 | 23 | from torchvision.utils import draw_bounding_boxes |
23 | 24 |
|
24 | 25 | import ignite.distributed as idist |
|
39 | 40 | "fasterrcnn_resnet50_fpn", |
40 | 41 | "fasterrcnn_mobilenet_v3_large_fpn", |
41 | 42 | "fasterrcnn_mobilenet_v3_large_320_fpn", |
| 43 | + "retinanet_resnet50_fpn", |
42 | 44 | ] |
43 | 45 |
|
44 | 46 |
|
@@ -108,33 +110,29 @@ def run( |
108 | 110 | local_rank: int, |
109 | 111 | device: str, |
110 | 112 | experiment_name: str, |
111 | | - gpus: Union[int, List[int], str] = None, |
| 113 | + gpus: Optional[Union[int, List[int], str]] = None, |
112 | 114 | dataset_root: str = "./dataset", |
113 | 115 | log_dir: str = "./log", |
114 | 116 | model: str = "fasterrcnn_resnet50_fpn", |
115 | 117 | epochs: int = 13, |
116 | 118 | batch_size: int = 4, |
117 | | - lr: int = 0.01, |
| 119 | + lr: float = 0.01, |
118 | 120 | download: bool = False, |
119 | 121 | image_size: int = 256, |
120 | | - resume_from: dict = None, |
| 122 | + resume_from: Optional[dict] = None, |
121 | 123 | ) -> None: |
122 | 124 | bbox_params = A.BboxParams(format="pascal_voc") |
123 | 125 | train_transform = A.Compose( |
124 | | - [A.Resize(image_size, image_size), A.HorizontalFlip(p=0.5), ToTensorV2()], |
| 126 | + [A.HorizontalFlip(p=0.5), ToTensorV2()], |
125 | 127 | bbox_params=bbox_params, |
126 | 128 | ) |
127 | | - val_transform = A.Compose([A.Resize(image_size, image_size), ToTensorV2()], bbox_params=bbox_params) |
| 129 | + val_transform = A.Compose([ToTensorV2()], bbox_params=bbox_params) |
128 | 130 |
|
129 | 131 | download = local_rank == 0 and download |
130 | 132 | train_dataset = Dataset(root=dataset_root, download=download, image_set="train", transforms=train_transform) |
131 | 133 | val_dataset = Dataset(root=dataset_root, download=download, image_set="val", transforms=val_transform) |
132 | 134 | vis_dataset = Subset(val_dataset, random.sample(range(len(val_dataset)), k=16)) |
133 | 135 |
|
134 | | - # for testing |
135 | | - train_dataset = Subset(train_dataset, range(100)) |
136 | | - val_dataset = Subset(train_dataset, range(100)) |
137 | | - |
138 | 136 | train_dataloader = idist.auto_dataloader( |
139 | 137 | train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4 |
140 | 138 | ) |
@@ -322,8 +320,15 @@ def main( |
322 | 320 |
|
323 | 321 | # to precent multiple download for preatrined checkpoint, create model in the main process |
324 | 322 | model = getattr(detection, model)(pretrained=True) |
325 | | - in_features = model.roi_heads.box_predictor.cls_score.in_features |
326 | | - model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 21) |
| 323 | + |
| 324 | + if model.__class__.__name__ == "FasterRCNN": |
| 325 | + in_features = model.roi_heads.box_predictor.cls_score.in_features |
| 326 | + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 21) |
| 327 | + elif model.__class__.__name__ == "RetinaNet": |
| 328 | + head = RetinaNetClassificationHead( |
| 329 | + model.backbone.out_channels, model.anchor_generator.num_anchors_per_location()[0], num_classes=21 |
| 330 | + ) |
| 331 | + model.head.classification_head = head |
327 | 332 |
|
328 | 333 | with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel: |
329 | 334 | parallel.run( |
|
0 commit comments