22
33import pytensor .tensor as pt
44from pytensor .compile .debugmode import _lessbroken_deepcopy
5- from pytensor .configdefaults import config
65from pytensor .graph .basic import Apply , Constant , Variable
76from pytensor .graph .op import Op
87from pytensor .link .c .op import COp
9- from pytensor .tensor .type import scalar
8+ from pytensor .tensor .type import lscalar
109from pytensor .tensor .type_other import SliceType
1110from pytensor .tensor .variable import TensorVariable
1211from pytensor .typed_list .type import TypedListType
@@ -508,7 +507,7 @@ class Index(Op):
508507 def make_node (self , x , elem ):
509508 assert isinstance (x .type , TypedListType )
510509 assert x .ttype == elem .type
511- return Apply (self , [x , elem ], [scalar ()])
510+ return Apply (self , [x , elem ], [lscalar ()])
512511
513512 def perform (self , node , inputs , outputs ):
514513 """
@@ -520,7 +519,7 @@ def perform(self, node, inputs, outputs):
520519 (out ,) = outputs
521520 for y in range (len (x )):
522521 if node .inputs [0 ].ttype .values_eq (x [y ], elem ):
523- out [0 ] = np .asarray (y , dtype = config . floatX )
522+ out [0 ] = np .asarray (y , dtype = "int64" )
524523 break
525524
526525 def __str__ (self ):
@@ -537,7 +536,7 @@ class Count(Op):
537536 def make_node (self , x , elem ):
538537 assert isinstance (x .type , TypedListType )
539538 assert x .ttype == elem .type
540- return Apply (self , [x , elem ], [scalar ()])
539+ return Apply (self , [x , elem ], [lscalar ()])
541540
542541 def perform (self , node , inputs , outputs ):
543542 """
@@ -551,7 +550,7 @@ def perform(self, node, inputs, outputs):
551550 for y in range (len (x )):
552551 if node .inputs [0 ].ttype .values_eq (x [y ], elem ):
553552 out [0 ] += 1
554- out [0 ] = np .asarray (out [0 ], dtype = config . floatX )
553+ out [0 ] = np .asarray (out [0 ], "int64" )
555554
556555 def __str__ (self ):
557556 return self .__class__ .__name__
@@ -583,7 +582,7 @@ class Length(COp):
583582
584583 def make_node (self , x ):
585584 assert isinstance (x .type , TypedListType )
586- return Apply (self , [x ], [scalar ( dtype = "int64" )])
585+ return Apply (self , [x ], [lscalar ( )])
587586
588587 def perform (self , node , x , outputs ):
589588 (out ,) = outputs
0 commit comments