@@ -22,6 +22,34 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
2222 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
2323 return torch_linalg .cross (x1 , x2 , dim = axis )
2424
25- __all__ = linalg_all + ['outer' , 'trace' , 'matrix_transpose' , 'tensordot' ]
25+ def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
26+ from ._aliases import isdtype
27+
28+ x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
29+
30+ # torch.linalg.vecdot doesn't support integer dtypes
31+ if isdtype (x1 .dtype , 'integral' ) or isdtype (x2 .dtype , 'integral' ):
32+ if kwargs :
33+ raise RuntimeError ("vecdot kwargs not supported for integral dtypes" )
34+ ndim = max (x1 .ndim , x2 .ndim )
35+ x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
36+ x2_shape = (1 ,)* (ndim - x2 .ndim ) + tuple (x2 .shape )
37+ if x1_shape [axis ] != x2_shape [axis ]:
38+ raise ValueError ("x1 and x2 must have the same size along the given axis" )
39+
40+ x1_ , x2_ = torch .broadcast_tensors (x1 , x2 )
41+ x1_ = torch .moveaxis (x1_ , axis , - 1 )
42+ x2_ = torch .moveaxis (x2_ , axis , - 1 )
43+
44+ res = x1_ [..., None , :] @ x2_ [..., None ]
45+ return res [..., 0 , 0 ]
46+ return torch .linalg .vecdot (x1 , x2 , dim = axis , ** kwargs )
47+
48+ def solve (x1 : array , x2 : array , / , ** kwargs ) -> array :
49+ x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
50+ return torch .linalg .solve (x1 , x2 , ** kwargs )
51+
52+ __all__ = linalg_all + ['outer' , 'trace' , 'matrix_transpose' , 'tensordot' ,
53+ 'vecdot' , 'solve' ]
2654
2755del linalg_all
0 commit comments