@@ -99,18 +99,20 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
9999 if not all (params .type .broadcastable ):
100100 return None
101101
102- # Check whether axis covers all dimensions
103- axis = set (node .op .axis )
104- base_var_dims = set (range (base_var .ndim ))
105- if axis != base_var_dims :
106- return None
102+ if node .op .axis is None :
103+ axis = tuple (range (base_var .ndim ))
104+ else :
105+ # Check whether axis covers all dimensions
106+ axis = tuple (sorted (node .op .axis ))
107+ if axis != tuple (range (base_var .ndim )):
108+ return None
107109
108110 # distinguish measurable discrete and continuous (because logprob is different)
109111 measurable_max : Max
110112 if base_var .type .dtype .startswith ("int" ):
111- measurable_max = MeasurableMaxDiscrete (list ( axis ) )
113+ measurable_max = MeasurableMaxDiscrete (axis )
112114 else :
113- measurable_max = MeasurableMax (list ( axis ) )
115+ measurable_max = MeasurableMax (axis )
114116
115117 max_rv_node = measurable_max .make_node (base_var )
116118 max_rv = max_rv_node .outputs
@@ -206,21 +208,23 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVa
206208 if not all (params .type .broadcastable ):
207209 return None
208210
209- # Check whether axis is supported or not
210- axis = set (node .op .axis )
211- base_var_dims = set (range (base_var .ndim ))
212- if axis != base_var_dims :
213- return None
211+ if node .op .axis is None :
212+ axis = tuple (range (base_var .ndim ))
213+ else :
214+ # Check whether axis is supported or not
215+ axis = tuple (sorted (node .op .axis ))
216+ if axis != tuple (range (base_var .ndim )):
217+ return None
214218
215219 if not rv_map_feature .request_measurable ([base_rv ]):
216220 return None
217221
218222 # distinguish measurable discrete and continuous (because logprob is different)
219223 measurable_min : Max
220224 if base_rv .type .dtype .startswith ("int" ):
221- measurable_min = MeasurableDiscreteMaxNeg (list ( axis ) )
225+ measurable_min = MeasurableDiscreteMaxNeg (axis )
222226 else :
223- measurable_min = MeasurableMaxNeg (list ( axis ) )
227+ measurable_min = MeasurableMaxNeg (axis )
224228
225229 return measurable_min .make_node (base_rv ).outputs
226230
0 commit comments