@@ -491,7 +491,7 @@ def solve_args():
491491 Strategy for the x1 and x2 arguments to test_solve()
492492
493493 solve() takes x1, x2, where x1 is any stack of square invertible matrices
494- of shape (..., M, M), and x2 is either shape (..., M ) or (..., M, K),
494+ of shape (..., M, M), and x2 is either shape (M, ) or (..., M, K),
495495 where the ... parts of x1 and x2 are broadcast compatible.
496496 """
497497 stack_shapes = shared (two_mutually_broadcastable_shapes )
@@ -501,26 +501,18 @@ def solve_args():
501501 pair [0 ])))
502502
503503 @composite
504- def x2_shapes (draw ):
505- end = draw (xps .array_shapes (min_dims = 0 , max_dims = 1 , min_side = 0 ,
506- max_side = SQRT_MAX_ARRAY_SIZE ))
507- return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + end
504+ def _x2_shapes (draw ):
505+ end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
506+ return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
508507
509- x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes ())
508+ x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
509+ x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes )
510510 return x1 , x2
511511
512512@pytest .mark .xp_extension ('linalg' )
513513@given (* solve_args ())
514514def test_solve (x1 , x2 ):
515- # TODO: solve() is currently ambiguous, in that some inputs can be
516- # interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
517- # and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
518- # of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
519- # broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
520- # (2, 2, 2, 2).
521-
522- # res = linalg.solve(x1, x2)
523- pass
515+ res = linalg .solve (x1 , x2 )
524516
525517@pytest .mark .xp_extension ('linalg' )
526518@given (
0 commit comments