2525
2626API_VERSION = default_version = "2022.12"
2727
28+ BOOLEAN_INDEXING = True
29+
2830DATA_DEPENDENT_SHAPES = True
2931
3032all_extensions = (
4648def set_array_api_strict_flags (
4749 * ,
4850 api_version = None ,
51+ boolean_indexing = None ,
4952 data_dependent_shapes = None ,
5053 enabled_extensions = None ,
5154):
@@ -67,6 +70,12 @@ def set_array_api_strict_flags(
6770 Note that 2021.12 is supported, but currently gives the same thing as
6871 2022.12 (except that the fft extension will be disabled).
6972
73+
74+ - `boolean_indexing`: Whether indexing by a boolean array is supported.
75+ Note that although boolean array indexing does result in data-dependent
76+ shapes, this flag is independent of the `data_dependent_shapes` flag
77+ (see below).
78+
7079 - `data_dependent_shapes`: Whether data-dependent shapes are enabled in
7180 array-api-strict.
7281
@@ -79,10 +88,12 @@ def set_array_api_strict_flags(
7988
8089 - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
8190 - `nonzero`
82- - Boolean array indexing
8391 - `repeat` when the `repeats` argument is an array (requires 2023.12
8492 version of the standard)
8593
94+ Note that while boolean indexing is also data-dependent, it is
95+ controlled by a separate `boolean_indexing` flag (see above).
96+
8697 See
8798 https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
8899 for more details.
@@ -102,8 +113,8 @@ def set_array_api_strict_flags(
102113 >>> # Set the standard version to 2021.12
103114 >>> set_array_api_strict_flags(api_version="2021.12")
104115
105- >>> # Disable data-dependent shapes
106- >>> set_array_api_strict_flags(data_dependent_shapes=False)
116+ >>> # Disable data-dependent shapes and boolean indexing
117+ >>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False )
107118
108119 >>> # Enable only the linalg extension (disable the fft extension)
109120 >>> set_array_api_strict_flags(enabled_extensions=["linalg"])
@@ -116,7 +127,7 @@ def set_array_api_strict_flags(
116127 ArrayAPIStrictFlags: A context manager to temporarily set the flags.
117128
118129 """
119- global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
130+ global API_VERSION , BOOLEAN_INDEXING , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
120131
121132 if api_version is not None :
122133 if api_version not in supported_versions :
@@ -126,6 +137,9 @@ def set_array_api_strict_flags(
126137 API_VERSION = api_version
127138 array_api_strict .__array_api_version__ = API_VERSION
128139
140+ if boolean_indexing is not None :
141+ BOOLEAN_INDEXING = boolean_indexing
142+
129143 if data_dependent_shapes is not None :
130144 DATA_DEPENDENT_SHAPES = data_dependent_shapes
131145
@@ -169,7 +183,11 @@ def get_array_api_strict_flags():
169183 >>> from array_api_strict import get_array_api_strict_flags
170184 >>> flags = get_array_api_strict_flags()
171185 >>> flags
172- {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
186+ {'api_version': '2022.12',
187+ 'boolean_indexing': True,
188+ 'data_dependent_shapes': True,
189+ 'enabled_extensions': ('linalg', 'fft')
190+ }
173191
174192 See Also
175193 --------
@@ -181,6 +199,7 @@ def get_array_api_strict_flags():
181199 """
182200 return {
183201 "api_version" : API_VERSION ,
202+ "boolean_indexing" : BOOLEAN_INDEXING ,
184203 "data_dependent_shapes" : DATA_DEPENDENT_SHAPES ,
185204 "enabled_extensions" : ENABLED_EXTENSIONS ,
186205 }
@@ -215,9 +234,10 @@ def reset_array_api_strict_flags():
215234 ArrayAPIStrictFlags: A context manager to temporarily set the flags.
216235
217236 """
218- global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
237+ global API_VERSION , BOOLEAN_INDEXING , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
219238 API_VERSION = default_version
220239 array_api_strict .__array_api_version__ = API_VERSION
240+ BOOLEAN_INDEXING = True
221241 DATA_DEPENDENT_SHAPES = True
222242 ENABLED_EXTENSIONS = default_extensions
223243
@@ -242,10 +262,11 @@ class ArrayAPIStrictFlags:
242262 reset_array_api_strict_flags: Reset the flags to their default values.
243263
244264 """
245- def __init__ (self , * , api_version = None , data_dependent_shapes = None ,
246- enabled_extensions = None ):
265+ def __init__ (self , * , api_version = None , boolean_indexing = None ,
266+ data_dependent_shapes = None , enabled_extensions = None ):
247267 self .kwargs = {
248268 "api_version" : api_version ,
269+ "boolean_indexing" : boolean_indexing ,
249270 "data_dependent_shapes" : data_dependent_shapes ,
250271 "enabled_extensions" : enabled_extensions ,
251272 }
@@ -265,6 +286,11 @@ def set_flags_from_environment():
265286 api_version = os .environ ["ARRAY_API_STRICT_API_VERSION" ]
266287 )
267288
289+ if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os .environ :
290+ set_array_api_strict_flags (
291+ boolean_indexing = os .environ ["ARRAY_API_STRICT_BOOLEAN_INDEXING" ].lower () == "true"
292+ )
293+
268294 if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
269295 set_array_api_strict_flags (
270296 data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
0 commit comments