Skip to content

Commit 7fdd126

Browse files
authored
Merge pull request #54 from RedisAI/tensorget_as_mutable_numpy
TESNORGET as mutable numpy ndarray
2 parents 740fac0 + 6427625 commit 7fdd126

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

redisai/client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def tensorset(self,
335335
return res if not self.enable_postprocess else processor.tensorset(res)
336336

337337
def tensorget(self,
338-
key: AnyStr, as_numpy: bool = True,
338+
key: AnyStr, as_numpy: bool = True, as_numpy_mutable: bool = False,
339339
meta_only: bool = False) -> Union[dict, np.ndarray]:
340340
"""
341341
Retrieve the value of a tensor from the server. By default it returns the numpy
@@ -349,6 +349,9 @@ def tensorget(self,
349349
If True, returns a numpy.ndarray. Returns the value as a list and the
350350
metadata in a dictionary if False. This flag also decides how to fetch
351351
the value from the RedisAI server, which also has performance implications
352+
as_numpy_mutable : bool
353+
If True, returns a a mutable numpy.ndarray object by copy the tensor data. Otherwise (as long as_numpy=True)
354+
the returned numpy.ndarray will use the original tensor buffer and will be for read-only
352355
meta_only : bool
353356
If True, the value is not retrieved, only the shape and the type
354357
@@ -368,8 +371,7 @@ def tensorget(self,
368371
"""
369372
args = builder.tensorget(key, as_numpy, meta_only)
370373
res = self.execute_command(*args)
371-
return res if not self.enable_postprocess else processor.tensorget(res,
372-
as_numpy, meta_only)
374+
return res if not self.enable_postprocess else processor.tensorget(res, as_numpy, as_numpy_mutable, meta_only)
373375

374376
def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -> str:
375377
"""
@@ -587,11 +589,12 @@ def __init__(self, enable_postprocess, *args, **kwargs):
587589
def dag(self, *args, **kwargs):
588590
raise RuntimeError("Pipeline object doesn't allow DAG creation currently")
589591

590-
def tensorget(self, key, as_numpy=True, meta_only=False):
592+
def tensorget(self, key, as_numpy=True, as_numpy_mutable=False, meta_only=False):
591593
self.tensorget_processors.append(partial(processor.tensorget,
592594
as_numpy=as_numpy,
595+
as_numpy_mutable=as_numpy_mutable,
593596
meta_only=meta_only))
594-
return super().tensorget(key, as_numpy, meta_only)
597+
return super().tensorget(key, as_numpy, as_numpy_mutable, meta_only)
595598

596599
def _execute_transaction(self, *args, **kwargs):
597600
# TODO: Blocking commands like MODELRUN, SCRIPTRUN and DAGRUN won't work
@@ -648,13 +651,14 @@ def tensorset(self,
648651
return self
649652

650653
def tensorget(self,
651-
key: AnyStr, as_numpy: bool = True,
654+
key: AnyStr, as_numpy: bool = True, as_numpy_mutable: bool = False,
652655
meta_only: bool = False) -> Any:
653-
args = builder.tensorget(key, as_numpy, meta_only)
656+
args = builder.tensorget(key, as_numpy, as_numpy_mutable)
654657
self.commands.extend(args)
655658
self.commands.append("|>")
656659
self.result_processors.append(partial(processor.tensorget,
657660
as_numpy=as_numpy,
661+
as_numpy_mutable=as_numpy_mutable,
658662
meta_only=meta_only))
659663
return self
660664

redisai/postprocessor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,19 @@ def modelscan(res):
1919
return utils.recursive_bytetransform(res, lambda x: x.decode())
2020

2121
@staticmethod
22-
def tensorget(res, as_numpy, meta_only):
22+
def tensorget(res, as_numpy, as_numpy_mutable, meta_only):
2323
"""Process the tensorget output.
2424
2525
If ``as_numpy`` is True, it'll be converted to a numpy array. The required
2626
information such as datatype and shape must be in ``rai_result`` itself.
2727
"""
2828
rai_result = utils.list2dict(res)
29-
if meta_only:
29+
if meta_only is True:
3030
return rai_result
31+
elif as_numpy_mutable is True:
32+
return utils.blob2numpy(rai_result['blob'], rai_result['shape'], rai_result['dtype'], mutable=True)
3133
elif as_numpy is True:
32-
return utils.blob2numpy(rai_result['blob'], rai_result['shape'], rai_result['dtype'])
34+
return utils.blob2numpy(rai_result['blob'], rai_result['shape'], rai_result['dtype'], mutable=False)
3335
else:
3436
target = float if rai_result['dtype'] in ('FLOAT', 'DOUBLE') else int
3537
utils.recursive_bytetransform(rai_result['values'], target)

redisai/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@ def numpy2blob(tensor: np.ndarray) -> tuple:
3131
return dtype, shape, blob
3232

3333

34-
def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str) -> np.ndarray:
34+
def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str, mutable: bool) -> np.ndarray:
3535
"""Convert `BLOB` result from RedisAI to `np.ndarray`."""
3636
mm = {
3737
'FLOAT': 'float32',
3838
'DOUBLE': 'float64'
3939
}
4040
dtype = mm.get(dtype, dtype.lower())
41-
a = np.frombuffer(value, dtype=dtype)
41+
if mutable:
42+
a = np.fromstring(value, dtype=dtype)
43+
else:
44+
a = np.frombuffer(value, dtype=dtype)
4245
return a.reshape(shape)
4346

4447

test/test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ def test_numpy_tensor(self):
102102
ret = con.tensorset('x', values)
103103
self.assertEqual(ret, 'OK')
104104

105+
# By default tensorget returns immutable, unless as_numpy_mutable is set as True
106+
ret = con.tensorget('x')
107+
self.assertRaises(ValueError, np.put, ret, 0, 1)
108+
ret = con.tensorget('x', as_numpy_mutable=True)
109+
np.put(ret, 0, 1)
110+
self.assertEqual(ret[0], 1)
111+
105112
stringarr = np.array('dummy')
106113
with self.assertRaises(TypeError):
107114
con.tensorset('trying', stringarr)

0 commit comments

Comments
 (0)