11from __future__ import annotations
22
3- from functools import wraps as _wraps
3+ from functools import reduce as _reduce , wraps as _wraps
44from builtins import all as _builtin_all , any as _builtin_any
55
66from ..common import _aliases
@@ -124,43 +124,35 @@ def _fix_promotion(x1, x2, only_scalar=True):
124124
125125
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
    """Return the dtype resulting from type promotion of the inputs.

    Accepts any mix of torch tensors, torch dtypes, and Python scalars
    (``bool``/``int``/``float``/``complex``).  Scalars participate in
    promotion but cannot determine a result dtype on their own, so at
    least one tensor or dtype must be present.

    Returns:
        The promoted ``torch.dtype``.

    Raises:
        ValueError: if no arguments are given, or if every argument is a
            Python scalar.
    """
    num = len(arrays_and_dtypes)

    if num == 0:
        raise ValueError("At least one array or dtype must be provided")

    if num == 1:
        # Single input: a dtype is returned as-is, an array contributes
        # its own dtype.
        x = arrays_and_dtypes[0]
        if isinstance(x, torch.dtype):
            return x
        return x.dtype

    if num == 2:
        x, y = arrays_and_dtypes
        # Two Python scalars alone cannot fix a dtype; reject them here
        # for consistency with the all-scalar check in the n-ary path
        # below (otherwise they would silently fall through to
        # torch.result_type).
        if isinstance(x, _py_scalars) and isinstance(y, _py_scalars):
            raise ValueError("At least one array or dtype must be provided")
        return _result_type(x, y)

    # Three or more inputs: fold the binary helper across the arguments.
    if _builtin_all(isinstance(x, _py_scalars) for x in arrays_and_dtypes):
        raise ValueError("At least one array or dtype must be provided")

    return _reduce(_result_type, arrays_and_dtypes)
def _result_type(x, y):
    """Binary promotion helper backing ``result_type``.

    ``x`` and ``y`` may each be a torch tensor, a ``torch.dtype``, or a
    Python scalar.  When neither operand is a scalar, the explicit
    ``_promotion_table`` is consulted first so that array-API promotion
    rules win over torch's own where they differ; otherwise we defer to
    ``torch.result_type``.
    """
    if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
        # Normalize each operand to its dtype for the table lookup.
        xdt = x.dtype if not isinstance(x, torch.dtype) else x
        ydt = y.dtype if not isinstance(y, torch.dtype) else y

        if (xdt, ydt) in _promotion_table:
            return _promotion_table[xdt, ydt]

    # torch.result_type only accepts tensors (and scalars), not dtypes,
    # so wrap any bare dtype in an empty tensor of that dtype before
    # delegating.  NOTE(review): this delegates promotion of pairs not
    # covered by _promotion_table entirely to torch's rules.
    x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
    y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
    return torch.result_type(x, y)
172165def can_cast (from_ : Union [Dtype , array ], to : Dtype , / ) -> bool :
173166 if not isinstance (from_ , torch .dtype ):
174167 from_ = from_ .dtype
0 commit comments