11from typing import Any
22
33import numpy as np
4+ import pytest
45
56from kirin .prelude import basic
67from kirin .dialects import vmath
@@ -22,6 +23,7 @@ def add_scalar_lhs():
2223 return add_kernel (x = 3.0 , y = [3.0 , 4 , 5 ])
2324
2425
26+ @pytest .mark .xfail ()
2527def test_add_scalar_lhs ():
2628 # out = add_scalar_lhs()
2729 add_scalar_lhs .print ()
@@ -31,9 +33,10 @@ def test_add_scalar_lhs():
3133 assert np .allclose (np .asarray (res ), np .array ([6 , 7 , 8 ]))
3234
3335
36+ @pytest .mark .xfail ()
3437def test_typed_kernel_add ():
3538 add_scalar_rhs_typed .print ()
36- res = add_scalar_rhs_typed (IList ([0 , 1 , 2 ]), 3.1 )
39+ res = add_scalar_rhs_typed (IList ([0.0 , 1.0 , 2.0 ]), 3.1 )
3740 assert np .allclose (np .asarray (res ), np .asarray ([3.1 , 4.1 , 5.1 ]))
3841
3942
@@ -52,6 +55,7 @@ def sub_scalar_rhs_typed(x: IList[float, Any], y: float):
5255 return x - y
5356
5457
58+ @pytest .mark .xfail ()
5559def test_sub_scalar_typed ():
5660 res = sub_scalar_rhs_typed (IList ([0 , 1 , 2 ]), 3.1 )
5761 assert np .allclose (np .asarray (res ), np .asarray ([- 3.1 , - 2.1 , - 1.1 ]))
@@ -62,11 +66,30 @@ def mult_scalar_lhs_typed(x: float, y: IList[float, Any]):
6266 return x * y
6367
6468
69+ @basic .union ([vmath ])(typeinfer = True )
70+ def mult_kernel (x , y ):
71+ return x * y
72+
73+
74+ @basic .union ([vmath ])(typeinfer = True , aggressive = True )
75+ def mult_scalar_lhs ():
76+ return mult_kernel (x = 3.0 , y = [3.0 , 4.0 , 5.0 ])
77+
78+
79+ @pytest .mark .xfail ()
6580def test_mult_scalar_typed ():
6681 res = mult_scalar_lhs_typed (3 , IList ([0 , 1 , 2 ]))
6782 assert np .allclose (np .asarray (res ), np .asarray ([0 , 3 , 6 ]))
6883
6984
85+ @pytest .mark .xfail ()
86+ def test_mult_scalar_lhs ():
87+ res = mult_scalar_lhs ()
88+ assert isinstance (res , IList )
89+ assert res .type .vars [0 ].typ is float
90+ assert np .allclose (np .asarray (res ), np .array ([9 , 12 , 15 ]))
91+
92+
7093@basic .union ([vmath ])(typeinfer = True )
7194def div_scalar_lhs_typed (x : float , y : IList [float , Any ]):
7295 return x / y
0 commit comments