Skip to content

Commit 3aa5923

Browse files
Merge pull request #85 from brandonwillard/custom-ndarray-class
Simplify NumPy array handling by using a custom ndarray class
2 parents 6f7f95e + 8b84ac1 commit 3aa5923

File tree

18 files changed

+203
-174
lines changed

18 files changed

+203
-174
lines changed

conftest.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,16 @@
11
import pytest
22

3+
34
@pytest.fixture()
45
def run_with_theano():
5-
from symbolic_pymc.theano.meta import load_dispatcher
6+
# import symbolic_pymc.meta
67

7-
load_dispatcher()
8+
# from symbolic_pymc.meta import base_metatize
89

10+
import symbolic_pymc.theano.meta as tm
911

10-
@pytest.fixture()
11-
def run_with_tensorflow():
12-
from symbolic_pymc.tensorflow.meta import load_dispatcher
12+
tm.load_dispatcher()
1313

14-
load_dispatcher()
14+
# yield
1515

16-
# Let's make sure we have a clean graph slate
17-
from tensorflow.compat.v1 import reset_default_graph
18-
reset_default_graph()
16+
# symbolic_pymc.meta._metatize = base_metatize

symbolic_pymc/meta.py

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,9 @@
22
import types
33
import reprlib
44

5+
import numpy as np
6+
7+
from copy import deepcopy
58
from itertools import chain
69
from functools import partial
710
from collections import OrderedDict
@@ -10,7 +13,7 @@
1013

1114
from unification import isvar, Var
1215

13-
from .utils import _check_eq
16+
from .utils import HashableNDArray
1417

1518
from multipledispatch import dispatch
1619

@@ -89,6 +92,18 @@ def _metatize_list(obj):
8992
return type(obj)([metatize(o) for o in obj])
9093

9194

95+
@_metatize.register(np.ndarray)
96+
def _metatize_ndarray(obj):
97+
"""Convert Numpy ndarrays into hashable objects."""
98+
return metatize(obj.view(HashableNDArray))
99+
100+
101+
@_metatize.register(HashableNDArray)
102+
def _metatize_HashableNDArray(obj):
103+
"""Fallback case for converted Numpy ndarrays."""
104+
return obj
105+
106+
92107
@_metatize.register(Iterator)
93108
@cached(metatize_cache)
94109
def _metatize_Iterator(obj):
@@ -183,9 +198,6 @@ def __metatize(cls, obj):
183198

184199
new_cls = super().__new__(cls, name, bases, clsdict)
185200

186-
if isinstance(new_cls.base, type):
187-
_metatize.add((new_cls.base,), new_cls._metatize)
188-
189201
# Wrap the class implementation of `__hash__` with this value-caching
190202
# code.
191203
if "_hash" in clsdict["__volatile_slots__"]:
@@ -226,12 +238,11 @@ def obj(self):
226238
return object.__getattribute__(self, "_obj")
227239

228240
@classmethod
229-
def base_classes(cls, mro_order=True):
230-
res = tuple(c.base for c in cls.__subclasses__())
231-
if hasattr(cls, "base"):
232-
res = (cls.base,) + res
233-
sorted(res, key=lambda c: len(c.mro()), reverse=mro_order)
234-
return res
241+
def base_subclasses(cls):
242+
for subclass in cls.__subclasses__():
243+
yield from subclass.base_subclasses()
244+
if isinstance(subclass.base, type):
245+
yield subclass
235246

236247
@classmethod
237248
def is_meta(cls, obj):
@@ -294,7 +305,7 @@ def __eq__(self, other):
294305
assert self.base == other.base
295306

296307
if self.rands():
297-
return all(_check_eq(s, o) for s, o in zip(self.rands(), other.rands()))
308+
return all(s == o for s, o in zip(self.rands(), other.rands()))
298309
else:
299310
return NotImplemented
300311

@@ -459,3 +470,6 @@ def _metatize_type(obj_type):
459470

460471
if obj_cls is not None:
461472
return obj_cls
473+
474+
475+
base_metatize = deepcopy(_metatize)

symbolic_pymc/relations/theano/distributions.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
from .. import concat
1212
from ...etuple import etuple
13+
from ...utils import HashableNDArray
1314
from ...theano.meta import TheanoMetaConstant, mt
1415

1516
from kanren.facts import Relation
@@ -61,18 +62,18 @@ def constant_neq(lvar, val):
6162
Scalar values are broadcast across arrays.
6263
"""
6364

64-
def _goal(s):
65+
if isinstance(val, np.ndarray):
66+
val = val.view(HashableNDArray)
67+
68+
def constant_neq_goal(s):
6569
lvar_val = walk(lvar, s)
6670
if isinstance(lvar_val, (tt.Constant, TheanoMetaConstant)):
67-
data = lvar_val.data
68-
if (isinstance(val, np.ndarray) and not np.array_equal(data, val)) or not all(
69-
np.atleast_1d(data) == val
70-
):
71+
if lvar_val.data != val:
7172
yield s
7273
else:
7374
yield s
7475

75-
return _goal
76+
return constant_neq_goal
7677

7778

7879
def scale_loc_transform(in_expr, out_expr):

symbolic_pymc/tensorflow/meta.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import types
22
import inspect
33

4-
import numpy as np
5-
64
import tensorflow as tf
75
import tensorflow_probability as tfp
86

@@ -39,6 +37,9 @@
3937

4038
from .. import meta
4139

40+
from ..utils import HashableNDArray
41+
42+
4243
tf_metatize_cache = Cache(50)
4344

4445

@@ -167,7 +168,7 @@ def _metatize_tf_object(obj):
167168
try:
168169
tf_obj = tf.convert_to_tensor(obj)
169170
except (TypeError, ValueError):
170-
raise ValueError("Could not find a TensorFlow MetaSymbol class for {obj}")
171+
raise ValueError(f"Error converting {obj} to a TensorFlow tensor.")
171172

172173
return _metatize(tf_obj)
173174

@@ -180,9 +181,9 @@ def load_dispatcher():
180181

181182
def _metatize_tf_svd(obj):
182183
"""Turn a TensorFlow `Svd` object/tuple into a standard tuple."""
183-
return _metatize(tuple(obj))
184+
return meta._metatize(tuple(obj))
184185

185-
_metatize.add((_SvdOutput,), _metatize_tf_svd)
186+
meta._metatize.add((_SvdOutput,), _metatize_tf_svd)
186187

187188
def _metatize_tf_eager(obj):
188189
"""Catch eager tensor metatize issues early."""
@@ -192,12 +193,17 @@ def _metatize_tf_eager(obj):
192193
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
193194
)
194195

195-
_metatize.add((EagerTensor,), _metatize_tf_eager)
196+
meta._metatize.add((EagerTensor,), _metatize_tf_eager)
196197

197-
_metatize.add((object,), _metatize_tf_object)
198+
meta._metatize.add((object,), _metatize_tf_object)
199+
meta._metatize.add((HashableNDArray,), _metatize_tf_object)
198200

201+
for new_cls in TFlowMetaSymbol.base_subclasses():
202+
meta._metatize.add((new_cls.base,), new_cls._metatize)
199203

200-
load_dispatcher()
204+
meta._metatize.add((TFlowMetaOpDef.base,), TFlowMetaOpDef._metatize)
205+
206+
return meta._metatize
201207

202208

203209
class TFlowMetaSymbol(MetaSymbol):
@@ -451,7 +457,7 @@ def _protobuf_convert(cls, k, v):
451457
elif k == "T":
452458
return tf.as_dtype(v.type).name
453459
elif k == "value":
454-
return tensor_util.MakeNdarray(v.tensor)
460+
return tensor_util.MakeNdarray(v.tensor).view(HashableNDArray)
455461
else:
456462
# Consider only the narrow case where a single object is converted
457463
# (e.g. a Python builtin type under `v.b`, `v.f`, etc.)
@@ -492,9 +498,7 @@ def frozen_attr(self):
492498
if isvar(self.attr):
493499
self._frozen_attr = self.attr
494500
else:
495-
self._frozen_attr = frozenset(
496-
(k, v.tostring() if isinstance(v, np.ndarray) else v) for k, v in self.attr.items()
497-
)
501+
self._frozen_attr = frozenset(self.attr.items())
498502
return self._frozen_attr
499503

500504
def __eq__(self, other):
@@ -993,3 +997,5 @@ def __getattr__(self, obj):
993997

994998

995999
mt = TFlowMetaAccessor()
1000+
1001+
load_dispatcher()

symbolic_pymc/theano/meta.py

Lines changed: 22 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -23,21 +23,35 @@
2323
_metatize,
2424
)
2525

26+
from .. import meta
27+
28+
from ..utils import HashableNDArray
29+
2630

2731
def _metatize_theano_object(obj):
2832
try:
2933
obj = tt.as_tensor_variable(obj)
3034
except (ValueError, tt.AsTensorError):
31-
raise ValueError("Could not find a MetaSymbol class for {}".format(obj))
35+
raise ValueError("Error converting {} to a Theano tensor.".format(obj))
36+
except AssertionError:
37+
# This is a work-around for a Theano bug; specifically,
38+
# an assert statement in `theano.scalar.basic` that unnecessarily
39+
# requires the object type be exclusively an ndarray or memmap.
40+
# See https://github.com/Theano/Theano/pull/6727
41+
obj = tt.as_tensor_variable(np.asarray(obj))
42+
3243
return _metatize(obj)
3344

3445

3546
def load_dispatcher():
3647
"""Set/override dispatcher to default to TF objects."""
37-
_metatize.add((object,), _metatize_theano_object)
48+
meta._metatize.add((object,), _metatize_theano_object)
49+
meta._metatize.add((HashableNDArray,), _metatize_theano_object)
3850

51+
for new_cls in TheanoMetaSymbol.base_subclasses():
52+
meta._metatize.add((new_cls.base,), new_cls._metatize)
3953

40-
load_dispatcher()
54+
return meta._metatize
4155

4256

4357
class TheanoMetaSymbol(MetaSymbol):
@@ -202,7 +216,7 @@ def __call__(self, *args, ttype=None, index=None, **kwargs):
202216
# XXX: We don't have a higher-order meta object model, so being
203217
# wrong about the exact type of output variable will cause
204218
# problems.
205-
out_meta_type, = self.out_meta_types(op_args)
219+
(out_meta_type,) = self.out_meta_types(op_args)
206220
res_var = out_meta_type(ttype, res_apply, index, name)
207221
res_var._obj = var()
208222

@@ -451,28 +465,9 @@ def _metatize(cls, obj):
451465
return res
452466

453467
def __init__(self, type, data, name=None, obj=None):
454-
self.data = data
468+
self.data = data if not isinstance(data, np.ndarray) else data.view(HashableNDArray)
455469
super().__init__(type, None, None, name, obj=obj)
456470

457-
def __eq__(self, other):
458-
if self is other:
459-
return True
460-
461-
if type(self) != type(other):
462-
return False
463-
464-
if all(
465-
(s.tostring() if isinstance(s, np.ndarray) else s)
466-
== (o.tostring() if isinstance(o, np.ndarray) else o)
467-
for s, o in zip(self.rands(), other.rands())
468-
):
469-
return True
470-
471-
return False
472-
473-
def __hash__(self):
474-
return hash(v.tostring() if isinstance(v, np.ndarray) else v for v in self.rands())
475-
476471

477472
class TheanoMetaTensorConstant(TheanoMetaConstant):
478473
# TODO: Could extend `theano.tensor.var._tensor_py_operators`, too.
@@ -606,7 +601,9 @@ def meta_obj(*args, **kwargs):
606601

607602
mt = TheanoMetaAccessor()
608603

609-
mt.dot = metatize(tt.basic._dot)
604+
_metatize = load_dispatcher()
605+
606+
mt.dot = _metatize(tt.basic._dot)
610607

611608

612609
#

symbolic_pymc/theano/ops.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ def _infer_shape(self, size, dist_params, param_shapes=None):
168168

169169
# _, out_bcasts, bcastd_inputs = tt.add.get_output_info(tt.DimShuffle, *dist_params)
170170

171-
bcast_ind, = out_bcasts
171+
(bcast_ind,) = out_bcasts
172172
ndim_ind = len(bcast_ind)
173173
shape_ind = bcastd_inputs[0].shape
174174

@@ -226,7 +226,7 @@ def compute_bcast(self, dist_params, size):
226226
s_x, s_idx = s.owner.inputs
227227
s_idx = tt.get_scalar_constant_value(s_idx)
228228
if isinstance(s_x.owner.op, tt.Shape):
229-
x_obj, = s_x.owner.inputs
229+
(x_obj,) = s_x.owner.inputs
230230
s_val = x_obj.type.broadcastable[s_idx]
231231
else:
232232
# TODO: Could go for an existing broadcastable here,

symbolic_pymc/theano/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@ def optimize_graph(x, optimization, return_graph=None, in_place=False):
122122
else:
123123
res = x_graph_opt.outputs
124124
if len(res) == 1:
125-
res, = res
125+
(res,) = res
126126
return res
127127

128128

0 commit comments

Comments
 (0)