22
33from functools import reduce as _reduce , wraps as _wraps
44from builtins import all as _builtin_all , any as _builtin_any
5- from typing import Any , List , Optional , Sequence , Tuple , Union
5+ from typing import Any , List , Optional , Sequence , Tuple , Union , Literal
66
77import torch
88
@@ -547,8 +547,12 @@ def count_nonzero(
547547) -> Array :
548548 result = torch .count_nonzero (x , dim = axis )
549549 if keepdims :
550- if axis is not None :
550+ if isinstance ( axis , int ) :
551551 return result .unsqueeze (axis )
552+ elif isinstance (axis , tuple ):
553+ n_axis = [x .ndim + ax if ax < 0 else ax for ax in axis ]
554+ sh = [1 if i in n_axis else x .shape [i ] for i in range (x .ndim )]
555+ return torch .reshape (result , sh )
552556 return _axis_none_keepdims (result , x .ndim , keepdims )
553557 else :
554558 return result
@@ -823,6 +827,12 @@ def sign(x: Array, /) -> Array:
823827 return out
824828
825829
830+ def meshgrid (* arrays : Array , indexing : Literal ['xy' , 'ij' ] = 'xy' ) -> List [Array ]:
831+ # enforce the default of 'xy'
832+ # TODO: is the return type a list or a tuple
833+ return list (torch .meshgrid (* arrays , indexing = 'xy' ))
834+
835+
826836__all__ = ['asarray' , 'result_type' , 'can_cast' ,
827837 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
828838 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
@@ -839,4 +849,4 @@ def sign(x: Array, /) -> Array:
839849 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
840850 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
841851 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
842- 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' ]
852+ 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' , 'meshgrid' ]
0 commit comments