1818
1919import pytest
2020
21- def test_flags ():
22- # Test defaults
21+ def test_flag_defaults ():
2322 flags = get_array_api_strict_flags ()
2423 assert flags == {
25- 'api_version' : '2022.12' ,
24+ 'api_version' : '2023.12' ,
25+ 'boolean_indexing' : True ,
26+ 'data_dependent_shapes' : True ,
27+ 'enabled_extensions' : ('linalg' , 'fft' ),
28+ }
29+
30+
31+ def test_reset_flags ():
32+ with pytest .warns (UserWarning ):
33+ set_array_api_strict_flags (
34+ api_version = '2021.12' ,
35+ boolean_indexing = False ,
36+ data_dependent_shapes = False ,
37+ enabled_extensions = ())
38+ reset_array_api_strict_flags ()
39+ flags = get_array_api_strict_flags ()
40+ assert flags == {
41+ 'api_version' : '2023.12' ,
2642 'boolean_indexing' : True ,
2743 'data_dependent_shapes' : True ,
2844 'enabled_extensions' : ('linalg' , 'fft' ),
2945 }
3046
31- # Test setting flags
47+
48+ def test_setting_flags ():
3249 set_array_api_strict_flags (data_dependent_shapes = False )
3350 flags = get_array_api_strict_flags ()
3451 assert flags == {
35- 'api_version' : '2022 .12' ,
52+ 'api_version' : '2023 .12' ,
3653 'boolean_indexing' : True ,
3754 'data_dependent_shapes' : False ,
3855 'enabled_extensions' : ('linalg' , 'fft' ),
3956 }
4057 set_array_api_strict_flags (enabled_extensions = ('fft' ,))
4158 flags = get_array_api_strict_flags ()
4259 assert flags == {
43- 'api_version' : '2022 .12' ,
60+ 'api_version' : '2023 .12' ,
4461 'boolean_indexing' : True ,
4562 'data_dependent_shapes' : False ,
4663 'enabled_extensions' : ('fft' ,),
4764 }
65+
66+ def test_flags_api_version_2021_12 ():
4867 # Make sure setting the version to 2021.12 disables fft and issues a
4968 # warning.
5069 with pytest .warns (UserWarning ) as record :
@@ -55,27 +74,23 @@ def test_flags():
5574 assert flags == {
5675 'api_version' : '2021.12' ,
5776 'boolean_indexing' : True ,
58- 'data_dependent_shapes' : False ,
59- 'enabled_extensions' : (),
77+ 'data_dependent_shapes' : True ,
78+ 'enabled_extensions' : ('linalg' , ),
6079 }
61- reset_array_api_strict_flags ()
6280
63- with pytest . warns ( UserWarning ):
64- set_array_api_strict_flags (api_version = '2021 .12' )
81+ def test_flags_api_version_2022_12 ( ):
82+ set_array_api_strict_flags (api_version = '2022 .12' )
6583 flags = get_array_api_strict_flags ()
6684 assert flags == {
67- 'api_version' : '2021 .12' ,
85+ 'api_version' : '2022 .12' ,
6886 'boolean_indexing' : True ,
6987 'data_dependent_shapes' : True ,
70- 'enabled_extensions' : ('linalg' ,),
88+ 'enabled_extensions' : ('linalg' , 'fft' ),
7189 }
72- reset_array_api_strict_flags ()
7390
74- # 2023.12 should issue a warning
75- with pytest .warns (UserWarning ) as record :
76- set_array_api_strict_flags (api_version = '2023.12' )
77- assert len (record ) == 1
78- assert '2023.12' in str (record [0 ].message )
91+
92+ def test_flags_api_version_2023_12 ():
93+ set_array_api_strict_flags (api_version = '2023.12' )
7994 flags = get_array_api_strict_flags ()
8095 assert flags == {
8196 'api_version' : '2023.12' ,
@@ -84,6 +99,7 @@ def test_flags():
8499 'enabled_extensions' : ('linalg' , 'fft' ),
85100 }
86101
102+ def test_setting_flags_invalid ():
87103 # Test setting flags with invalid values
88104 pytest .raises (ValueError , lambda :
89105 set_array_api_strict_flags (api_version = '2020.12' ))
@@ -94,35 +110,15 @@ def test_flags():
94110 api_version = '2021.12' ,
95111 enabled_extensions = ('linalg' , 'fft' )))
96112
97- # Test resetting flags
98- with pytest .warns (UserWarning ):
99- set_array_api_strict_flags (
100- api_version = '2021.12' ,
101- boolean_indexing = False ,
102- data_dependent_shapes = False ,
103- enabled_extensions = ())
104- reset_array_api_strict_flags ()
105- flags = get_array_api_strict_flags ()
106- assert flags == {
107- 'api_version' : '2022.12' ,
108- 'boolean_indexing' : True ,
109- 'data_dependent_shapes' : True ,
110- 'enabled_extensions' : ('linalg' , 'fft' ),
111- }
112-
113113def test_api_version ():
114114 # Test defaults
115- assert xp .__array_api_version__ == '2022 .12'
115+ assert xp .__array_api_version__ == '2023 .12'
116116
117117 # Test setting the version
118- with pytest .warns (UserWarning ):
119- set_array_api_strict_flags (api_version = '2021.12' )
120- assert xp .__array_api_version__ == '2021.12'
118+ set_array_api_strict_flags (api_version = '2022.12' )
119+ assert xp .__array_api_version__ == '2022.12'
121120
122121def test_data_dependent_shapes ():
123- with pytest .warns (UserWarning ):
124- set_array_api_strict_flags (api_version = '2023.12' ) # to enable repeat()
125-
126122 a = asarray ([0 , 0 , 1 , 2 , 2 ])
127123 mask = asarray ([True , False , True , False , True ])
128124 repeats = asarray ([1 , 1 , 2 , 2 , 2 ])
@@ -275,12 +271,16 @@ def test_fft(func_name):
275271def test_api_version_2023_12 (func_name ):
276272 func = api_version_2023_12_examples [func_name ]
277273
278- # By default, these functions should error
274+ # By default, these functions should not error
275+ func ()
276+
277+ # In 2022.12, these functions should error
278+ set_array_api_strict_flags (api_version = '2022.12' )
279279 pytest .raises (RuntimeError , func )
280280
281- with pytest . warns ( UserWarning ):
282- set_array_api_strict_flags (api_version = '2023.12' )
283- func ()
281+ # Test the behavior gets updated properly
282+ set_array_api_strict_flags (api_version = '2023.12' )
283+ func ()
284284
285285 set_array_api_strict_flags (api_version = '2022.12' )
286286 pytest .raises (RuntimeError , func )
@@ -371,16 +371,25 @@ def test_disabled_extensions():
371371 assert 'linalg' not in ns
372372 assert 'fft' not in ns
373373
374+ reset_array_api_strict_flags ()
375+ assert 'linalg' in xp .__all__
376+ assert 'fft' in xp .__all__
377+ xp .linalg # No error
378+ xp .fft # No error
379+ ns = {}
380+ exec ('from array_api_strict import *' , ns )
381+ assert 'linalg' in ns
382+ assert 'fft' in ns
374383
375384def test_environment_variables ():
376385 # Test that the environment variables work as expected
377386 subprocess_tests = [
378387 # ARRAY_API_STRICT_API_VERSION
379388 ('''\
380389 import array_api_strict as xp
381- assert xp.__array_api_version__ == '2022 .12'
390+ assert xp.__array_api_version__ == '2023 .12'
382391
383- assert xp.get_array_api_strict_flags()['api_version'] == '2022 .12'
392+ assert xp.get_array_api_strict_flags()['api_version'] == '2023 .12'
384393
385394''' , {}),
386395 * [
0 commit comments