Skip to content

Commit 7ccde64

Browse files
committed
Change TypedList Count and Index output to int64
1 parent f49a6c5 commit 7ccde64

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

pytensor/typed_list/basic.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22

33
import pytensor.tensor as pt
44
from pytensor.compile.debugmode import _lessbroken_deepcopy
5-
from pytensor.configdefaults import config
65
from pytensor.graph.basic import Apply, Constant, Variable
76
from pytensor.graph.op import Op
87
from pytensor.link.c.op import COp
9-
from pytensor.tensor.type import scalar
8+
from pytensor.tensor.type import lscalar
109
from pytensor.tensor.type_other import SliceType
1110
from pytensor.tensor.variable import TensorVariable
1211
from pytensor.typed_list.type import TypedListType
@@ -508,7 +507,7 @@ class Index(Op):
508507
def make_node(self, x, elem):
509508
assert isinstance(x.type, TypedListType)
510509
assert x.ttype == elem.type
511-
return Apply(self, [x, elem], [scalar()])
510+
return Apply(self, [x, elem], [lscalar()])
512511

513512
def perform(self, node, inputs, outputs):
514513
"""
@@ -520,7 +519,7 @@ def perform(self, node, inputs, outputs):
520519
(out,) = outputs
521520
for y in range(len(x)):
522521
if node.inputs[0].ttype.values_eq(x[y], elem):
523-
out[0] = np.asarray(y, dtype=config.floatX)
522+
out[0] = np.asarray(y, dtype="int64")
524523
break
525524

526525
def __str__(self):
@@ -537,7 +536,7 @@ class Count(Op):
537536
def make_node(self, x, elem):
538537
assert isinstance(x.type, TypedListType)
539538
assert x.ttype == elem.type
540-
return Apply(self, [x, elem], [scalar()])
539+
return Apply(self, [x, elem], [lscalar()])
541540

542541
def perform(self, node, inputs, outputs):
543542
"""
@@ -551,7 +550,7 @@ def perform(self, node, inputs, outputs):
551550
for y in range(len(x)):
552551
if node.inputs[0].ttype.values_eq(x[y], elem):
553552
out[0] += 1
554-
out[0] = np.asarray(out[0], dtype=config.floatX)
553+
out[0] = np.asarray(out[0], "int64")
555554

556555
def __str__(self):
557556
return self.__class__.__name__
@@ -583,7 +582,7 @@ class Length(COp):
583582

584583
def make_node(self, x):
585584
assert isinstance(x.type, TypedListType)
586-
return Apply(self, [x], [scalar(dtype="int64")])
585+
return Apply(self, [x], [lscalar()])
587586

588587
def perform(self, node, x, outputs):
589588
(out,) = outputs

0 commit comments

Comments
 (0)