Skip to content

Commit a40520c

Browse files
author
hhsecond
committed
review changes
1 parent 46c0ca3 commit a40520c

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

redisai/client.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
np = None
88

99
from .constants import Backend, Device, DType
10-
from .utils import str_or_strlist, to_string
10+
from .utils import str_or_strsequence, to_string
1111
from .tensor import Tensor, BlobTensor
1212

1313

@@ -32,16 +32,16 @@ def modelset(self,
3232
backend: Backend,
3333
device: Device,
3434
data: ByteString,
35-
input: Union[AnyStr, Collection[AnyStr], None] = None,
36-
output: Union[AnyStr, Collection[AnyStr], None] = None
35+
inputs: Union[AnyStr, Collection[AnyStr], None] = None,
36+
outputs: Union[AnyStr, Collection[AnyStr], None] = None
3737
) -> AnyStr:
3838
args = ['AI.MODELSET', name, backend.value, device.value]
3939
if backend == Backend.tf:
40-
if not(all((input, output))):
40+
if not(all((inputs, outputs))):
4141
raise ValueError(
4242
'Require keyword arguments input and output for TF models')
43-
args += ['INPUTS'] + str_or_strlist(input)
44-
args += ['OUTPUTS'] + str_or_strlist(output)
43+
args += ['INPUTS'] + str_or_strsequence(inputs)
44+
args += ['OUTPUTS'] + str_or_strsequence(outputs)
4545
args += [data]
4646
return self.execute_command(*args)
4747

@@ -58,12 +58,12 @@ def modeldel(self, name: AnyStr) -> AnyStr:
5858

5959
def modelrun(self,
6060
name: AnyStr,
61-
input: Union[AnyStr, Collection[AnyStr]],
62-
output: Union[AnyStr, Collection[AnyStr]]
61+
inputs: Union[AnyStr, Collection[AnyStr]],
62+
outputs: Union[AnyStr, Collection[AnyStr]]
6363
) -> AnyStr:
6464
args = ['AI.MODELRUN', name]
65-
args += ['INPUTS'] + str_or_strlist(input)
66-
args += ['OUTPUTS'] + str_or_strlist(output)
65+
args += ['INPUTS'] + str_or_strsequence(inputs)
66+
args += ['OUTPUTS'] + str_or_strsequence(outputs)
6767
return self.execute_command(*args)
6868

6969
def tensorset(self,
@@ -75,16 +75,16 @@ def tensorset(self,
7575
Set the values of the tensor on the server using the provided Tensor object
7676
:param key: The name of the tensor
7777
:param tensor: a `Tensor` object
78-
:param shape: Shape of the tensor. Required if input is a sequence of ints/floats
78+
:param shape: Shape of the tensor
7979
:param dtype: data type of the tensor. Required if input is a sequence of ints/floats
8080
"""
8181
# TODO: tensorset will not accept BlobTensor or Tensor object in the future.
8282
# Keeping it in the current version for compatibility with the example repo
8383
if np and isinstance(tensor, np.ndarray):
8484
tensor = BlobTensor.from_numpy(tensor)
85-
elif hasattr(tensor, 'shape') and hasattr(tensor, 'dtype'):
86-
raise TypeError('Numpy is not installed but the input tensor seem to be a numpy array')
8785
elif isinstance(tensor, (list, tuple)):
86+
if shape is None:
87+
shape = (len(tensor),)
8888
tensor = Tensor(dtype, shape, tensor)
8989
args = ['AI.TENSORSET', key, tensor.type.value]
9090
args += tensor.shape
@@ -139,11 +139,11 @@ def scriptdel(self, name):
139139
def scriptrun(self,
140140
name: AnyStr,
141141
function: AnyStr,
142-
input: Union[AnyStr, Collection[AnyStr]],
143-
output: Union[AnyStr, Collection[AnyStr]]
142+
inputs: Union[AnyStr, Collection[AnyStr]],
143+
outputs: Union[AnyStr, Collection[AnyStr]]
144144
) -> AnyStr:
145145
args = ['AI.SCRIPTRUN', name, function, 'INPUTS']
146-
args += str_or_strlist(input)
146+
args += str_or_strsequence(inputs)
147147
args += ['OUTPUTS']
148-
args += str_or_strlist(output)
148+
args += str_or_strsequence(outputs)
149149
return self.execute_command(*args)

redisai/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@ def to_string(s):
1111
return s # Not a string we care about
1212

1313

14-
def str_or_strlist(v):
15-
if isinstance(v, six.string_types):
16-
return [v]
14+
def str_or_strsequence(v):
15+
if not isinstance(v, (list, tuple)):
16+
if isinstance(v, six.string_types):
17+
return [v]
18+
else:
19+
raise TypeError('Argument must be a string, list or a tuple')
1720
return v
1821

1922

0 commit comments

Comments
 (0)