Skip to content

Commit c69b666

Browse files
feat: add map_location for torchensemble.utils.io.load (#155)
* add map location to the load function of io * Add Atiqur Rahman to Contributors list * Update CONTRIBUTORS.md --------- Co-authored-by: Yi-Xuan Xu <xuyx@lamda.nju.edu.cn>
1 parent 81ab9ec commit c69b666

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchensemble/utils/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def save(model, save_dir, logger):
4545
return
4646

4747

48-
def load(model, save_dir="./", logger=None):
48+
def load(model, save_dir="./", map_location=None, logger=None):
4949
"""Implement model deserialization from the specified directory."""
5050
if not os.path.exists(save_dir):
5151
raise FileExistsError("`{}` does not exist".format(save_dir))
@@ -67,7 +67,7 @@ def load(model, save_dir="./", logger=None):
6767
if logger:
6868
logger.info("Loading the model from `{}`".format(save_dir))
6969

70-
state = torch.load(save_dir)
70+
state = torch.load(save_dir, map_location=map_location)
7171
n_estimators = state["n_estimators"]
7272
model_params = state["model"]
7373
model._criterion = state["_criterion"]

0 commit comments

Comments
 (0)