Skip to content

Commit 862bc84

Browse files
committed
add retina net
1 parent dd6fe53 commit 862bc84

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

examples/contrib/detection/main.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import multiprocessing
22
import os
33
import random
4-
from typing import Any, List, Tuple, Union
4+
from typing import Any, List, Optional, Tuple, Union
55

66
import aim
77
import albumentations as A
@@ -19,6 +19,7 @@
1919
from torchvision.datasets import VOCDetection
2020
from torchvision.models import detection
2121
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
22+
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
2223
from torchvision.utils import draw_bounding_boxes
2324

2425
import ignite.distributed as idist
@@ -39,6 +40,7 @@
3940
"fasterrcnn_resnet50_fpn",
4041
"fasterrcnn_mobilenet_v3_large_fpn",
4142
"fasterrcnn_mobilenet_v3_large_320_fpn",
43+
"retinanet_resnet50_fpn",
4244
]
4345

4446

@@ -108,33 +110,29 @@ def run(
108110
local_rank: int,
109111
device: str,
110112
experiment_name: str,
111-
gpus: Union[int, List[int], str] = None,
113+
gpus: Optional[Union[int, List[int], str]] = None,
112114
dataset_root: str = "./dataset",
113115
log_dir: str = "./log",
114116
model: str = "fasterrcnn_resnet50_fpn",
115117
epochs: int = 13,
116118
batch_size: int = 4,
117-
lr: int = 0.01,
119+
lr: float = 0.01,
118120
download: bool = False,
119121
image_size: int = 256,
120-
resume_from: dict = None,
122+
resume_from: Optional[dict] = None,
121123
) -> None:
122124
bbox_params = A.BboxParams(format="pascal_voc")
123125
train_transform = A.Compose(
124-
[A.Resize(image_size, image_size), A.HorizontalFlip(p=0.5), ToTensorV2()],
126+
[A.HorizontalFlip(p=0.5), ToTensorV2()],
125127
bbox_params=bbox_params,
126128
)
127-
val_transform = A.Compose([A.Resize(image_size, image_size), ToTensorV2()], bbox_params=bbox_params)
129+
val_transform = A.Compose([ToTensorV2()], bbox_params=bbox_params)
128130

129131
download = local_rank == 0 and download
130132
train_dataset = Dataset(root=dataset_root, download=download, image_set="train", transforms=train_transform)
131133
val_dataset = Dataset(root=dataset_root, download=download, image_set="val", transforms=val_transform)
132134
vis_dataset = Subset(val_dataset, random.sample(range(len(val_dataset)), k=16))
133135

134-
# for testing
135-
train_dataset = Subset(train_dataset, range(100))
136-
val_dataset = Subset(train_dataset, range(100))
137-
138136
train_dataloader = idist.auto_dataloader(
139137
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4
140138
)
@@ -322,8 +320,15 @@ def main(
322320

323321
# to precent multiple download for preatrined checkpoint, create model in the main process
324322
model = getattr(detection, model)(pretrained=True)
325-
in_features = model.roi_heads.box_predictor.cls_score.in_features
326-
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 21)
323+
324+
if model.__class__.__name__ == "FasterRCNN":
325+
in_features = model.roi_heads.box_predictor.cls_score.in_features
326+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 21)
327+
elif model.__class__.__name__ == "RetinaNet":
328+
head = RetinaNetClassificationHead(
329+
model.backbone.out_channels, model.anchor_generator.num_anchors_per_location()[0], num_classes=21
330+
)
331+
model.head.classification_head = head
327332

328333
with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:
329334
parallel.run(

0 commit comments

Comments
 (0)