number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
25+ from copy import deepcopy
2526
2627import torch .nn as nn
2728
@@ -69,6 +70,59 @@ def load_state_dict(self, state_dict, **kwargs):
6970 super ().load_state_dict (state_dict , ** kwargs )
7071
7172
# Extra pretrained-weight sources not shipped with the upstream
# `pretrained_settings` mapping: Facebook's semi-supervised ("ssl") and
# semi-weakly-supervised ("swsl") checkpoints, the Instagram-pretrained
# ("instagram") ResNeXt checkpoints, and torchvision ImageNet weights
# for models missing them.
new_settings = {
    "resnet18": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth",
    },
    "resnet50": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth",
    },
    "resnext50_32x4d": {
        "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth",
    },
    "resnext101_32x4d": {
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth",
    },
    "resnext101_32x8d": {
        "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth",
    },
    "resnext101_32x16d": {
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
        "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth",
        "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth",
    },
    "resnext101_32x32d": {
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
    },
    "resnext101_32x48d": {
        "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
    },
}

# Deep-copy so we do not mutate the `pretrained_settings` dict imported
# from the pretrainedmodels package (other importers share that object).
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    # Ensure the model has an entry before attaching new weight sources.
    model_settings = pretrained_settings.setdefault(model_name, {})

    for source_name, source_url in sources.items():
        # Mirror the full settings schema used by the inline definitions
        # this mapping replaces, including "input_space" — dropping it
        # would break callers that read that key for preprocessing.
        model_settings[source_name] = {
            "url": source_url,
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
125+
72126resnet_encoders = {
73127 "resnet18" : {
74128 "encoder" : ResNetEncoder ,
@@ -117,17 +171,7 @@ def load_state_dict(self, state_dict, **kwargs):
117171 },
118172 "resnext50_32x4d" : {
119173 "encoder" : ResNetEncoder ,
120- "pretrained_settings" : {
121- "imagenet" : {
122- "url" : "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" ,
123- "input_space" : "RGB" ,
124- "input_size" : [3 , 224 , 224 ],
125- "input_range" : [0 , 1 ],
126- "mean" : [0.485 , 0.456 , 0.406 ],
127- "std" : [0.229 , 0.224 , 0.225 ],
128- "num_classes" : 1000 ,
129- }
130- },
174+ "pretrained_settings" : pretrained_settings ["resnext50_32x4d" ],
131175 "params" : {
132176 "out_channels" : (3 , 64 , 256 , 512 , 1024 , 2048 ),
133177 "block" : Bottleneck ,
@@ -136,28 +180,20 @@ def load_state_dict(self, state_dict, **kwargs):
136180 "width_per_group" : 4 ,
137181 },
138182 },
139- "resnext101_32x8d " : {
183+ "resnext101_32x4d " : {
140184 "encoder" : ResNetEncoder ,
141- "pretrained_settings" : {
142- "imagenet" : {
143- "url" : "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth" ,
144- "input_space" : "RGB" ,
145- "input_size" : [3 , 224 , 224 ],
146- "input_range" : [0 , 1 ],
147- "mean" : [0.485 , 0.456 , 0.406 ],
148- "std" : [0.229 , 0.224 , 0.225 ],
149- "num_classes" : 1000 ,
150- },
151- "instagram" : {
152- "url" : "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth" ,
153- "input_space" : "RGB" ,
154- "input_size" : [3 , 224 , 224 ],
155- "input_range" : [0 , 1 ],
156- "mean" : [0.485 , 0.456 , 0.406 ],
157- "std" : [0.229 , 0.224 , 0.225 ],
158- "num_classes" : 1000 ,
159- },
185+ "pretrained_settings" : pretrained_settings ["resnext101_32x4d" ],
186+ "params" : {
187+ "out_channels" : (3 , 64 , 256 , 512 , 1024 , 2048 ),
188+ "block" : Bottleneck ,
189+ "layers" : [3 , 4 , 23 , 3 ],
190+ "groups" : 32 ,
191+ "width_per_group" : 4 ,
160192 },
193+ },
194+ "resnext101_32x8d" : {
195+ "encoder" : ResNetEncoder ,
196+ "pretrained_settings" : pretrained_settings ["resnext101_32x8d" ],
161197 "params" : {
162198 "out_channels" : (3 , 64 , 256 , 512 , 1024 , 2048 ),
163199 "block" : Bottleneck ,
@@ -168,17 +204,7 @@ def load_state_dict(self, state_dict, **kwargs):
168204 },
169205 "resnext101_32x16d" : {
170206 "encoder" : ResNetEncoder ,
171- "pretrained_settings" : {
172- "instagram" : {
173- "url" : "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth" ,
174- "input_space" : "RGB" ,
175- "input_size" : [3 , 224 , 224 ],
176- "input_range" : [0 , 1 ],
177- "mean" : [0.485 , 0.456 , 0.406 ],
178- "std" : [0.229 , 0.224 , 0.225 ],
179- "num_classes" : 1000 ,
180- }
181- },
207+ "pretrained_settings" : pretrained_settings ["resnext101_32x16d" ],
182208 "params" : {
183209 "out_channels" : (3 , 64 , 256 , 512 , 1024 , 2048 ),
184210 "block" : Bottleneck ,
@@ -189,17 +215,7 @@ def load_state_dict(self, state_dict, **kwargs):
189215 },
190216 "resnext101_32x32d" : {
191217 "encoder" : ResNetEncoder ,
192- "pretrained_settings" : {
193- "instagram" : {
194- "url" : "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth" ,
195- "input_space" : "RGB" ,
196- "input_size" : [3 , 224 , 224 ],
197- "input_range" : [0 , 1 ],
198- "mean" : [0.485 , 0.456 , 0.406 ],
199- "std" : [0.229 , 0.224 , 0.225 ],
200- "num_classes" : 1000 ,
201- }
202- },
218+ "pretrained_settings" : pretrained_settings ["resnext101_32x32d" ],
203219 "params" : {
204220 "out_channels" : (3 , 64 , 256 , 512 , 1024 , 2048 ),
205221 "block" : Bottleneck ,
@@ -210,17 +226,7 @@ def load_state_dict(self, state_dict, **kwargs):
210226 },
211227 "resnext101_32x48d" : {
212228 "encoder" : ResNetEncoder ,
213- "pretrained_settings" : {
214- "instagram" : {
215- "url" : "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth" ,
216- "input_space" : "RGB" ,
217- "input_size" : [3 , 224 , 224 ],
218- "input_range" : [0 , 1 ],
219- "mean" : [0.485 , 0.456 , 0.406 ],
220- "std" : [0.229 , 0.224 , 0.225 ],
221- "num_classes" : 1000 ,
222- }
223- },
229+ "pretrained_settings" : pretrained_settings ["resnext101_32x48d" ],
224230 "params" : {
225231 "out_channels" : (3 , 64 , 256 , 512 , 1024 , 2048 ),
226232 "block" : Bottleneck ,
0 commit comments