@@ -245,15 +245,21 @@ def all_names(mod):
245245 return list (objs )
246246
247247
248+ def get_mod (library , module , * , compat ):
249+ if compat :
250+ library = f"array_api_compat.{ library } "
251+ xp = pytest .importorskip (library )
252+ return getattr (xp , module ) if module else xp
253+
254+
248255@pytest .mark .parametrize ("func" , [all_names , dir ])
249256@pytest .mark .parametrize ("module" , list (NAMES ))
250257@pytest .mark .parametrize ("library" , wrapped_libraries )
251258def test_array_api_names (library , module , func ):
252259 """Test that __all__ and dir() aren't missing any exports
253260 dictated by the Standard.
254261 """
255- xp = pytest .importorskip (f"array_api_compat.{ library } " )
256- mod = getattr (xp , module ) if module else xp
262+ mod = get_mod (library , module , compat = True )
257263 missing = set (NAMES [module ]) - set (func (mod ))
258264 xfail = set (XFAILS .get ((library , module ), []))
259265 xpass = xfail - missing
@@ -269,10 +275,8 @@ def test_compat_doesnt_hide_names(library, module, func):
269275 """The base namespace can have more names than the ones explicitly exported
270276 by array-api-compat. Test that we're not suppressing them.
271277 """
272- bare_xp = pytest .importorskip (library )
273- compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
274- bare_mod = getattr (bare_xp , module ) if module else bare_xp
275- compat_mod = getattr (compat_xp , module ) if module else compat_xp
278+ bare_mod = get_mod (library , module , compat = False )
279+ compat_mod = get_mod (library , module , compat = True )
276280
277281 missing = set (func (bare_mod )) - set (func (compat_mod ))
278282 missing = {name for name in missing if not name .startswith ("_" )}
@@ -286,10 +290,8 @@ def test_compat_doesnt_add_names(library, module, func):
286290 """Test that array-api-compat isn't adding names to the namespace
287291 besides those defined by the Array API Standard.
288292 """
289- bare_xp = pytest .importorskip (library )
290- compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
291- bare_mod = getattr (bare_xp , module ) if module else bare_xp
292- compat_mod = getattr (compat_xp , module ) if module else compat_xp
293+ bare_mod = get_mod (library , module , compat = False )
294+ compat_mod = get_mod (library , module , compat = True )
293295
294296 aapi_names = set (NAMES [module ])
295297 spurious = set (func (compat_mod )) - set (func (bare_mod )) - aapi_names - {"__all__" }
0 commit comments