Skip to content

Commit 7be380b

Browse files
committed
Use raval_pytree to calculate gradient norm to make independent of array sizes
1 parent 6fffde4 commit 7be380b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

varipeps/optimization/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def random_noise(a):
844844
contraction = "precondition_operator"
845845

846846
grad_norm_squared = 1e-2 * (
847-
jnp.linalg.norm(jnp.asarray(working_gradient)) ** 2
847+
jnp.linalg.norm(ravel_pytree(working_gradient)[0]) ** 2
848848
)
849849

850850
tmp_descent_dir = [

0 commit comments

Comments
 (0)