| 1 | +from functools import update_wrapper, wraps
1 | 2 | import torch |
2 | 3 | from torch import Tensor |
3 | | -from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach |
| 4 | +from torch.optim.optimizer import Optimizer |
| 5 | +try: |
| 6 | + from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach |
| 7 | + has_recent_pt = True |
| 8 | +except ImportError: |
| 9 | + has_recent_pt = False |
| 10 | + |
4 | 11 | from typing import List, Optional |
5 | 12 |
6 | 13 | __all__ = ['SGDW', 'sgdw'] |
@@ -62,7 +69,9 @@ def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): |
62 | 69 |
63 | 70 | return has_sparse_grad |
64 | 71 |
65 | | - @_use_grad_for_differentiable |
| 72 | + # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator |
| 73 | + # without args, for backwards compatibility with old pytorch |
| 74 | + @torch.no_grad() |
66 | 75 | def step(self, closure=None): |
67 | 76 | """Performs a single optimization step. |
68 | 77 |
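A possible way to resolve the FIXME in the hunk above, sketched below and not part of this commit: both `_use_grad_for_differentiable` and a `torch.no_grad()` instance can be applied as a bare `@decorator`, so the guard could be picked once at import time using the same try/except added to the imports. `_step_guard` and `SGDWSketch` are hypothetical names used only for illustration.

```python
import torch
from torch.optim.optimizer import Optimizer

try:
    from torch.optim.optimizer import _use_grad_for_differentiable
    has_recent_pt = True
except ImportError:
    has_recent_pt = False

if has_recent_pt:
    _step_guard = _use_grad_for_differentiable  # differentiable-aware guard on recent PyTorch
else:
    _step_guard = torch.no_grad()  # plain no-grad fallback for older PyTorch


class SGDWSketch(Optimizer):
    """Skeleton only; the real update logic lives in the diff."""

    def __init__(self, params, lr=1e-3):
        # 'differentiable' must be present in defaults because the wrapper produced by
        # _use_grad_for_differentiable reads self.defaults['differentiable']
        super().__init__(params, dict(lr=lr, differentiable=False))

    @_step_guard
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # parameter update would go here
        return loss
```

This keeps `step()` identical across PyTorch versions; only the module-level guard selection differs.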
@@ -124,17 +133,19 @@ def sgdw( |
124 | 133 |
125 | 134 | See :class:`~torch.optim.SGD` for details. |
126 | 135 | """ |
| 136 | + if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'): |
| 137 | + if foreach is None: |
| 138 | + # why must we be explicit about an if statement for torch.jit.is_scripting here? |
| 139 | + # because JIT can't handle Optionals nor fancy conditionals when scripting |
| 140 | + if not torch.jit.is_scripting(): |
| 141 | + _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False) |
| 142 | + else: |
| 143 | + foreach = False |
127 | 144 |
128 | | - if foreach is None: |
129 | | - # why must we be explicit about an if statement for torch.jit.is_scripting here? |
130 | | - # because JIT can't handle Optionals nor fancy conditionals when scripting |
131 | | - if not torch.jit.is_scripting(): |
132 | | - _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False) |
133 | | - else: |
134 | | - foreach = False |
135 | | - |
136 | | - if foreach and torch.jit.is_scripting(): |
137 | | - raise RuntimeError('torch.jit.script not supported with foreach optimizers') |
| 145 | + if foreach and torch.jit.is_scripting(): |
| 146 | + raise RuntimeError('torch.jit.script not supported with foreach optimizers') |
| 147 | + else: |
| 148 | +        foreach = False  # disabling altogether for older pytorch, as the foreach path relies on _group_tensors_by_device_and_dtype
138 | 149 |
139 | 150 | if foreach and not torch.jit.is_scripting(): |
140 | 151 | func = _multi_tensor_sgdw |
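For reference, a self-contained sketch of how the guarded foreach defaulting behaves on different PyTorch builds. `resolve_foreach` is a hypothetical helper that simply mirrors the block added in the hunk above; it is not a function in this repository.

```python
import torch
from torch.optim.optimizer import Optimizer

try:
    from torch.optim.optimizer import _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False


def resolve_foreach(params, foreach=None):
    # Hypothetical helper mirroring the guarded block added to sgdw() above.
    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
        if foreach is None:
            if not torch.jit.is_scripting():
                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
            else:
                foreach = False
        if foreach and torch.jit.is_scripting():
            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    else:
        # Older PyTorch lacks the grouping helpers the multi-tensor path relies on.
        foreach = False
    return foreach


params = [torch.zeros(3, requires_grad=True)]
print(resolve_foreach(params))        # PyTorch's heuristic on recent builds, always False on older ones
print(resolve_foreach(params, True))  # explicit foreach=True is honoured on recent builds, forced off on older ones
```

On recent PyTorch the heuristic from `_default_to_fused_or_foreach` picks the default, while the older-PyTorch branch disables the multi-tensor path entirely, which is the intent of the `foreach = False` fallback in the diff.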