Skip to content

Commit 707b230

Browse files
authored
Remove deprecated return_state_dict in bundle load (#8454)
Fixes #8453 ### Description Remove deprecated `return_state_dict ` in bundle `load` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 03a0d85 commit 707b230

File tree

3 files changed

+4
-24
lines changed

3 files changed

+4
-24
lines changed

monai/bundle/scripts.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,6 @@ def load(
702702
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
703703
the corresponding metadata dict, and extra files dict.
704704
please check `monai.data.load_net_with_metadata` for more details.
705-
4. If `return_state_dict` is True, return model weights, only used for compatibility
706-
when `model` and `net_name` are all `None`.
707705
708706
"""
709707
bundle_dir_ = _process_bundle_dir(bundle_dir)

tests/bundle/test_bundle_download.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -268,19 +268,17 @@ class TestLoad(unittest.TestCase):
268268
@skip_if_quick
269269
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
270270
with skip_if_downloading_fails():
271-
# download bundle, and load weights from the downloaded path
272271
with tempfile.TemporaryDirectory() as tempdir:
273272
bundle_root = os.path.join(tempdir, bundle_name)
274273
# load weights
275-
weights = load(
274+
model_1 = load(
276275
name=bundle_name,
277276
model_file=model_file,
278277
bundle_dir=tempdir,
279278
repo=repo,
280279
source="github",
281280
progress=False,
282281
device=device,
283-
return_state_dict=True,
284282
)
285283
# prepare network
286284
with open(os.path.join(bundle_root, bundle_files[2])) as f:
@@ -289,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
289287
del net_args["_target_"]
290288
model = getattr(nets, model_name)(**net_args)
291289
model.to(device)
292-
model.load_state_dict(weights)
290+
model.load_state_dict(model_1)
293291
model.eval()
294292

295293
# prepare data and test
@@ -313,13 +311,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
313311
progress=False,
314312
device=device,
315313
source="github",
316-
return_state_dict=False,
317314
)
318315
model_2.eval()
319316
output_2 = model_2.forward(input_tensor)
320317
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
321318

322-
# test compatibility with return_state_dict=True.
323319
model_3 = load(
324320
name=bundle_name,
325321
model_file=model_file,
@@ -328,7 +324,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
328324
device=device,
329325
net_name=model_name,
330326
source="github",
331-
return_state_dict=False,
332327
**net_args,
333328
)
334329
model_3.eval()
@@ -343,14 +338,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
343338
# download bundle, and load weights from the downloaded path
344339
with tempfile.TemporaryDirectory() as tempdir:
345340
# load weights
346-
model = load(
347-
name=bundle_name,
348-
bundle_dir=tempdir,
349-
source="monaihosting",
350-
progress=False,
351-
device=device,
352-
return_state_dict=False,
353-
)
341+
model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)
354342

355343
# prepare data and test
356344
input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
@@ -371,7 +359,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
371359
source="monaihosting",
372360
progress=False,
373361
device=device,
374-
return_state_dict=False,
375362
net_override=net_override,
376363
)
377364

tests/ngc_bundle_download.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
8383
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))
8484

8585
model = load(
86-
name=bundle_name,
87-
source="ngc",
88-
version=version,
89-
bundle_dir=tempdir,
90-
remove_prefix=remove_prefix,
91-
return_state_dict=False,
86+
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
9287
)
9388
assert_allclose(
9489
model.state_dict()[TESTCASE_WEIGHTS["key"]],

0 commit comments

Comments
 (0)