33from collections .abc import Sequence
44from functools import reduce as _reduce , wraps as _wraps
55from builtins import all as _builtin_all , any as _builtin_any
6- from typing import Any
6+ from typing import Any , List , Optional , Sequence , Tuple , Union , Literal
77
88import torch
99
@@ -547,8 +547,12 @@ def count_nonzero(
547547) -> Array :
548548 result = torch .count_nonzero (x , dim = axis )
549549 if keepdims :
550- if axis is not None :
550+ if isinstance ( axis , int ) :
551551 return result .unsqueeze (axis )
552+ elif isinstance (axis , tuple ):
553+ n_axis = [x .ndim + ax if ax < 0 else ax for ax in axis ]
554+ sh = [1 if i in n_axis else x .shape [i ] for i in range (x .ndim )]
555+ return torch .reshape (result , sh )
552556 return _axis_none_keepdims (result , x .ndim , keepdims )
553557 else :
554558 return result
@@ -820,6 +824,12 @@ def sign(x: Array, /) -> Array:
820824 return out
821825
822826
827+ def meshgrid (* arrays : Array , indexing : Literal ['xy' , 'ij' ] = 'xy' ) -> List [Array ]:
828+ # enforce the default of 'xy'
829+ # TODO: is the return type a list or a tuple
830+ return list (torch .meshgrid (* arrays , indexing = 'xy' ))
831+
832+
823833__all__ = ['__array_namespace_info__' , 'asarray' , 'result_type' , 'can_cast' ,
824834 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
825835 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
@@ -836,6 +846,6 @@ def sign(x: Array, /) -> Array:
836846 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
837847 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
838848 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
839- 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' ]
849+ 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' , 'meshgrid' ]
840850
841851_all_ignore = ['torch' , 'get_xp' ]
0 commit comments