File tree Expand file tree Collapse file tree 2 files changed +23
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -692,6 +692,11 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
692692 fuseable_clients : FUSEABLE_MAPPING = defaultdict (list )
693693 unfuseable_clients : UNFUSEABLE_MAPPING = defaultdict (set )
694694 for out , clients in fg .clients .items ():
695+ # Old FunctionGraph nodes remain in the clients dictionary
696+ # even after they are removed by rewrites
697+ if not clients :
698+ continue
699+
695700 out_maybe_fuseable = (
696701 out .owner
697702 and isinstance (out .owner .op , Elemwise )
Original file line number Diff line number Diff line change 1+ import warnings
2+
13import numpy as np
24import pytest
35
3638 invert ,
3739 iround ,
3840 log ,
41+ log1mexp ,
3942 log2 ,
4043 log10 ,
4144 mul ,
@@ -1370,6 +1373,21 @@ def rewrite_func():
13701373
13711374 assert benchmark (rewrite_func ) == 103
13721375
1376+ def test_no_warning_from_old_client (self ):
1377+ # There used to be a warning issued when creating fuseable mapping
1378+ # for nodes that are no longer in the FunctionGraph
1379+ with warnings .catch_warnings ():
1380+ warnings .simplefilter ("error" )
1381+ # The -2 integer array cannot be passed directly to the C method
1382+ # of log1mexp as that can only handle floats. There is a rewrite
1383+ # that casts it to a float, but the FunctionGraph client retains
1384+ # the original log1mexp of the integer input, which caused
1385+ # a misleading warning for non C implementation in the FusionRewrite
1386+ assert np .isclose (
1387+ log1mexp (np .array (- 2 , dtype = "int64" )).eval (),
1388+ np .log (1 - np .exp (- 2 )),
1389+ )
1390+
13731391
13741392class TimesN (aes .basic .UnaryScalarOp ):
13751393 """
You can’t perform that action at this time.
0 commit comments