Skip to content

Commit 61d8c6e

Browse files
author
hhsecond
committed
dagrun testcases
1 parent 053c8b7 commit 61d8c6e

File tree

3 files changed

+147
-22
lines changed

3 files changed

+147
-22
lines changed

redisai/client.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self, load, persist, executor, readonly=False):
2727
if persist:
2828
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
2929
"have PERSISTing values")
30-
self.commands = ['AI.DAGRUNRO']
30+
self.commands = ['AI.DAGRUN_RO']
3131
else:
3232
self.commands = ['AI.DAGRUN']
3333
if load:
@@ -40,6 +40,8 @@ def __init__(self, load, persist, executor, readonly=False):
4040
self.commands += ["PERSIST", 1, persist, '|>']
4141
else:
4242
self.commands += ["PERSIST", len(persist), *persist, '|>']
43+
elif load:
44+
self.commands.append('|>')
4345
self.executor = executor
4446

4547
def tensorset(self,
@@ -66,8 +68,8 @@ def tensorget(self,
6668

6769
def modelrun(self,
6870
name: AnyStr,
69-
inputs: List[AnyStr],
70-
outputs: List[AnyStr]) -> Any:
71+
inputs: Union[AnyStr, List[AnyStr]],
72+
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
7173
args = builder.modelrun(name, inputs, outputs)
7274
self.commands.extend(args)
7375
self.commands.append("|>")
@@ -117,8 +119,8 @@ def modelset(self,
117119
batch: int = None,
118120
minbatch: int = None,
119121
tag: AnyStr = None,
120-
inputs: List[AnyStr] = None,
121-
outputs: List[AnyStr] = None) -> str:
122+
inputs: Union[AnyStr, List[AnyStr]] = None,
123+
outputs: Union[AnyStr, List[AnyStr]] = None) -> str:
122124
"""
123125
Set the model on provided key.
124126
:param name: str, Key name
@@ -151,16 +153,17 @@ def modeldel(self, name: AnyStr) -> str:
151153

152154
def modelrun(self,
153155
name: AnyStr,
154-
inputs: List[AnyStr],
155-
outputs: List[AnyStr]) -> str:
156+
inputs: Union[AnyStr, List[AnyStr]],
157+
outputs: Union[AnyStr, List[AnyStr]]) -> str:
156158
args = builder.modelrun(name, inputs, outputs)
157159
return self.execute_command(*args).decode()
158160

159161
def modelscan(self) -> list:
160162
warnings.warn("Experimental: Model List API is experimental and might change "
161163
"in the future without any notice", UserWarning)
162164
args = builder.modelscan()
163-
return utils.recursive_bytetransform(self.execute_command(*args), lambda x: x.decode())
165+
result = self.execute_command(*args)
166+
return utils.recursive_bytetransform(result, lambda x: x.decode())
164167

165168
def tensorset(self,
166169
key: AnyStr,

redisai/command_builder.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ def loadbackend(identifier: AnyStr, path: AnyStr) -> Sequence:
1212
return 'AI.CONFIG LOADBACKEND', identifier, path
1313

1414
def modelset(self, name: AnyStr, backend: str, device: str, data: ByteString,
15-
batch: int, minbatch: int, tag: AnyStr, inputs: List[AnyStr],
16-
outputs: List[AnyStr]) -> Sequence:
15+
batch: int, minbatch: int, tag: AnyStr,
16+
inputs: Union[AnyStr, List[AnyStr]],
17+
outputs: Union[AnyStr, List[AnyStr]]) -> Sequence:
1718
args = ['AI.MODELSET', name, backend, device]
1819

1920
if batch is not None:
@@ -27,8 +28,8 @@ def modelset(self, name: AnyStr, backend: str, device: str, data: ByteString,
2728
if not(all((inputs, outputs))):
2829
raise ValueError(
2930
'Require keyword arguments input and output for TF models')
30-
args += ['INPUTS'] + utils.listify(inputs)
31-
args += ['OUTPUTS'] + utils.listify(outputs)
31+
args += ['INPUTS', *utils.listify(inputs)]
32+
args += ['OUTPUTS', *utils.listify(outputs)]
3233
chunk_size = 500 * 1024 * 1024
3334
data_chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
3435
# TODO: need a test case for this
@@ -64,6 +65,11 @@ def tensorset(self,
6465
dtype, shape, blob = utils.numpy2blob(tensor)
6566
args = ['AI.TENSORSET', key, dtype, *shape, 'BLOB', blob]
6667
elif isinstance(tensor, (list, tuple)):
68+
try:
69+
dtype = utils.dtype_dict[dtype.lower()]
70+
except KeyError:
71+
raise TypeError(f'``{dtype}`` is not supported by RedisAI. Currently '
72+
f'supported types are {list(utils.dtype_dict.keys())}')
6773
if shape is None:
6874
shape = (len(tensor),)
6975
args = ['AI.TENSORSET', key, dtype, *shape, 'VALUES', *tensor]
@@ -92,7 +98,6 @@ def scriptset(self, name: AnyStr, device: str, script: str, tag: AnyStr = None)
9298
return args
9399

94100
def scriptget(self, name: AnyStr, meta_only=False) -> Sequence:
95-
# TODO scripget test
96101
args = ['AI.SCRIPTGET', name, 'META']
97102
if not meta_only:
98103
args.append('SOURCE')

test/test.py

Lines changed: 126 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,36 @@ def bar(a, b):
2929
return a + b
3030
"""
3131

32-
33-
class ClientTestCase(TestCase):
32+
class RedisAITestBase(TestCase):
3433
def setUp(self):
35-
super(ClientTestCase, self).setUp()
34+
super().setUp()
3635
self.get_client().flushall()
3736

3837
def get_client(self, debug=DEBUG):
3938
return Client(debug)
4039

40+
41+
class ClientTestCase(RedisAITestBase):
42+
4143
def test_set_non_numpy_tensor(self):
4244
con = self.get_client()
4345
con.tensorset('x', (2, 3, 4, 5), dtype='float')
4446
result = con.tensorget('x', as_numpy=False)
4547
self.assertEqual([2, 3, 4, 5], result['values'])
4648
self.assertEqual([4], result['shape'])
4749

50+
con.tensorset('x', (2, 3, 4, 5), dtype='float64')
51+
result = con.tensorget('x', as_numpy=False)
52+
self.assertEqual([2, 3, 4, 5], result['values'])
53+
self.assertEqual([4], result['shape'])
54+
self.assertEqual('DOUBLE', result['dtype'])
55+
4856
con.tensorset('x', (2, 3, 4, 5), dtype='int16', shape=(2, 2))
4957
result = con.tensorget('x', as_numpy=False)
5058
self.assertEqual([2, 3, 4, 5], result['values'])
5159
self.assertEqual([2, 2], result['shape'])
5260

53-
with self.assertRaises(ResponseError):
61+
with self.assertRaises(TypeError):
5462
con.tensorset('x', (2, 3, 4, 5), dtype='wrongtype', shape=(2, 2))
5563
con.tensorset('x', (2, 3, 4, 5), dtype='int8', shape=(2, 2))
5664
result = con.tensorget('x', as_numpy=False)
@@ -63,7 +71,7 @@ def test_set_non_numpy_tensor(self):
6371
con.tensorset('x')
6472
con.tensorset(1)
6573

66-
def test_meta(self):
74+
def test_tensorget_meta(self):
6775
con = self.get_client()
6876
con.tensorset('x', (2, 3, 4, 5), dtype='float')
6977
result = con.tensorget('x', meta_only=True)
@@ -73,9 +81,20 @@ def test_meta(self):
7381
def test_numpy_tensor(self):
7482
con = self.get_client()
7583

84+
input_array = np.array([2, 3], dtype=np.float32)
85+
con.tensorset('x', input_array)
86+
values = con.tensorget('x')
87+
self.assertEqual(values.dtype, np.float32)
88+
89+
input_array = np.array([2, 3], dtype=np.float64)
90+
con.tensorset('x', input_array)
91+
values = con.tensorget('x')
92+
self.assertEqual(values.dtype, np.float64)
93+
7694
input_array = np.array([2, 3])
7795
con.tensorset('x', input_array)
7896
values = con.tensorget('x')
97+
7998
self.assertTrue(np.allclose([2, 3], values))
8099
self.assertEqual(values.dtype, np.int64)
81100
self.assertEqual(values.shape, (2,))
@@ -87,6 +106,15 @@ def test_numpy_tensor(self):
87106
with self.assertRaises(TypeError):
88107
con.tensorset('trying', stringarr)
89108

109+
def test_modelget_meta(self):
110+
model_path = os.path.join(MODEL_DIR, 'graph.pb')
111+
model_pb = load_model(model_path)
112+
con = self.get_client()
113+
con.modelset('m', 'tf', 'cpu', model_pb,
114+
inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
115+
model = con.modelget('m', meta_only=True)
116+
self.assertEqual(model, {'backend': 'TF', 'device': 'cpu', 'tag': 'v1.0'})
117+
90118
def test_modelrun_non_list_input_output(self):
91119
model_path = os.path.join(MODEL_DIR, 'graph.pb')
92120
model_pb = load_model(model_path)
@@ -121,6 +149,10 @@ def test_run_tf_model(self):
121149
con = self.get_client()
122150
con.modelset('m', 'tf', 'cpu', model_pb,
123151
inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
152+
con.modeldel('m')
153+
self.assertRaises(ResponseError, con.modelget, 'm')
154+
con.modelset('m', 'tf', 'cpu', model_pb,
155+
inputs=['a', 'b'], outputs='mul', tag='v1.0')
124156

125157
# wrong model
126158
self.assertRaises(ResponseError,
@@ -166,6 +198,9 @@ def test_scripts(self):
166198
script_det = con.scriptget('ket')
167199
self.assertTrue(script_det['device'] == 'cpu')
168200
self.assertTrue(script_det['source'] == script)
201+
script_det = con.scriptget('ket', meta_only=True)
202+
self.assertTrue(script_det['device'] == 'cpu')
203+
self.assertNotIn('source', script_det)
169204
con.scriptdel('ket')
170205
self.assertRaises(ResponseError, con.scriptget, 'ket')
171206

@@ -245,7 +280,7 @@ def test_model_scan(self):
245280
ptmodel = load_model(model_path)
246281
con = self.get_client()
247282
con.modelset("pt_model", 'torch', 'cpu', ptmodel)
248-
mlist = con.modelscan()
283+
mlist = con.modelscan() # TODO: modelscan issues in RedisAI
249284
self.assertEqual(mlist, [['pt_model', ''], ['m', 'v1.2']])
250285

251286
def test_script_scan(self):
@@ -259,21 +294,103 @@ def test_debug(self):
259294
con = self.get_client(debug=True)
260295
with Capturing() as output:
261296
con.tensorset('x', (2, 3, 4, 5), dtype='float')
262-
self.assertEqual(['AI.TENSORSET x float 4 VALUES 2 3 4 5'], output)
297+
self.assertEqual(['AI.TENSORSET x FLOAT 4 VALUES 2 3 4 5'], output)
298+
263299

264-
def test_z_dag(self): # TODO: z in the name is to make it run in the end
300+
class DagTestCase(RedisAITestBase):
301+
def setUp(self):
302+
super().setUp()
303+
con = self.get_client()
265304
model_path = os.path.join(MODEL_DIR, 'pt-minimal.pt')
266305
ptmodel = load_model(model_path)
267-
con = self.get_client()
268306
con.modelset("pt_model", 'torch', 'cpu', ptmodel, tag='v1.0')
307+
308+
def test_dagrun_with_load(self):
309+
con = self.get_client()
310+
con.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float')
311+
312+
dag = con.dag(load='a')
313+
dag.tensorset('b', [2, 3, 2, 3], shape=(2, 2), dtype='float')
314+
dag.modelrun("pt_model", ["a", "b"], ["output"])
315+
dag.tensorget('output')
316+
result = dag.run()
317+
expected = ['OK', 'OK', np.array([[4., 6.], [4., 6.]], dtype=np.float32)]
318+
self.assertTrue(np.allclose(expected.pop(), result.pop()))
319+
self.assertEqual(expected, result)
320+
self.assertRaises(ResponseError, con.tensorget, 'b')
321+
322+
def test_dagrun_with_persist(self):
323+
con = self.get_client()
324+
325+
dag = con.dag(persist='wrongkey') # this won't raise Error
326+
dag.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float').run()
327+
328+
dag = con.dag(persist=['b'])
329+
dag.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float')
330+
dag.tensorset('b', [2, 3, 2, 3], shape=(2, 2), dtype='float')
331+
dag.tensorget('b')
332+
result = dag.run()
333+
b = con.tensorget('b')
334+
self.assertTrue(np.allclose(b, result[-1]))
335+
self.assertEqual(b.dtype, np.float32)
336+
self.assertEqual(len(result), 3)
337+
338+
def test_dagrun_calling_on_return(self):
339+
con = self.get_client()
340+
con.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float')
341+
result = con.\
342+
dag(load='a').\
343+
tensorset('b', [2, 3, 2, 3], shape=(2, 2), dtype='float').\
344+
modelrun("pt_model", ["a", "b"], ["output"]).\
345+
tensorget('output').\
346+
run()
347+
expected = ['OK', 'OK', np.array([[4., 6.], [4., 6.]], dtype=np.float32)]
348+
self.assertTrue(np.allclose(expected.pop(), result.pop()))
349+
self.assertEqual(expected, result)
350+
351+
def test_dagrun_without_load_and_persist(self):
352+
con = self.get_client()
353+
354+
dag = con.dag(load='wrongkey')
355+
with self.assertRaises(ResponseError):
356+
dag.tensorget('wrongkey').run()
357+
269358
dag = con.dag()
270359
dag.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float')
271360
dag.tensorset('b', [2, 3, 2, 3], shape=(2, 2), dtype='float')
272361
dag.modelrun("pt_model", ["a", "b"], ["output"])
273362
dag.tensorget('output')
363+
result = dag.run()
274364
expected = ['OK', 'OK', 'OK', np.array([[4., 6.], [4., 6.]], dtype=np.float32)]
365+
self.assertTrue(np.allclose(expected.pop(), result.pop()))
366+
self.assertEqual(expected, result)
367+
368+
def test_dagrun_with_load_and_persist(self):
369+
con = self.get_client()
370+
con.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float')
371+
con.tensorset('b', [2, 3, 2, 3], shape=(2, 2), dtype='float')
372+
dag = con.dag(load=['a', 'b'], persist='output')
373+
dag.modelrun("pt_model", ["a", "b"], ["output"])
374+
dag.tensorget('output')
275375
result = dag.run()
376+
expected = ['OK', np.array([[4., 6.], [4., 6.]], dtype=np.float32)]
377+
result_outside_dag = con.tensorget('output')
276378
self.assertTrue(np.allclose(expected.pop(), result.pop()))
379+
result = dag.run()
380+
self.assertTrue(np.allclose(result_outside_dag, result.pop()))
277381
self.assertEqual(expected, result)
278382

383+
def test_dagrunRO(self):
384+
con = self.get_client()
385+
con.tensorset('a', [2, 3, 2, 3], shape=(2, 2), dtype='float')
386+
con.tensorset('b', [2, 3, 2, 3], shape=(2, 2), dtype='float')
387+
with self.assertRaises(RuntimeError):
388+
con.dag(load=['a', 'b'], persist='output', readonly=True)
389+
dag = con.dag(load=['a', 'b'], readonly=True)
390+
dag.modelrun("pt_model", ["a", "b"], ["output"])
391+
dag.tensorget('output')
392+
result = dag.run()
393+
expected = ['OK', np.array([[4., 6.], [4., 6.]], dtype=np.float32)]
394+
self.assertTrue(np.allclose(expected.pop(), result.pop()))
395+
279396

0 commit comments

Comments
 (0)