@@ -116,7 +116,8 @@ def __init__(
116116 precond_dtype : Optional [torch .dtype ] = None ,
117117 decoupled_decay : bool = False ,
118118 flatten : bool = False ,
119- flatten_start_end : Tuple [int , int ] = (2 , - 1 ),
119+ flatten_start_dim : int = 2 ,
120+ flatten_end_dim : int = - 1 ,
120121 deterministic : bool = False ,
121122 ):
122123 if not has_opt_einsum :
@@ -144,7 +145,8 @@ def __init__(
144145 precond_dtype = precond_dtype ,
145146 decoupled_decay = decoupled_decay ,
146147 flatten = flatten ,
147- flatten_start_end = flatten_start_end ,
148+ flatten_start_dim = flatten_start_dim ,
149+ flatten_end_dim = flatten_end_dim ,
148150 )
149151 super (Kron , self ).__init__ (params , defaults )
150152
@@ -235,7 +237,7 @@ def step(self, closure=None):
235237
236238 flattened = False
237239 if group ['flatten' ]:
238- grad = safe_flatten (grad , * group ["flatten_start_end" ])
240+ grad = safe_flatten (grad , group ["flatten_start_dim" ], group [ "flatten_end_dim" ])
239241 flattened = True
240242
241243 if len (state ) == 0 :
0 commit comments