1- from functools import wraps
2- from typing import Union , AnyStr , ByteString , List , Sequence
1+ from functools import wraps , partial
2+ from typing import Union , AnyStr , ByteString , List , Sequence , Any
33import warnings
44
55from redis import StrictRedis
66import numpy as np
77
88from . import utils
9+ from .command_builder import Builder
10+
11+
12+ builder = Builder ()
913
1014
1115def enable_debug (f ):
@@ -16,7 +20,69 @@ def wrapper(*args):
1620 return wrapper
1721
1822
19- # TODO: typing to use AnyStr
23+ class Dag :
24+ def __init__ (self , load , persist , executor , readonly = False ):
25+ self .result_processors = []
26+ if readonly :
27+ if persist :
28+ raise RuntimeError ("READONLY requests cannot write (duh!) and should not "
29+ "have PERSISTing values" )
30+ self .commands = ['AI.DAGRUN_RO' ]
31+ else :
32+ self .commands = ['AI.DAGRUN' ]
33+ if load :
34+ if not isinstance (load , (list , tuple )):
35+ self .commands += ["LOAD" , 1 , load ]
36+ else :
37+ self .commands += ["LOAD" , len (load ), * load ]
38+ if persist :
39+ if not isinstance (persist , (list , tuple )):
40+ self .commands += ["PERSIST" , 1 , persist , '|>' ]
41+ else :
42+ self .commands += ["PERSIST" , len (persist ), * persist , '|>' ]
43+ elif load :
44+ self .commands .append ('|>' )
45+ self .executor = executor
46+
47+ def tensorset (self ,
48+ key : AnyStr ,
49+ tensor : Union [np .ndarray , list , tuple ],
50+ shape : Sequence [int ] = None ,
51+ dtype : str = None ) -> Any :
52+ args = builder .tensorset (key , tensor , shape , dtype )
53+ self .commands .extend (args )
54+ self .commands .append ("|>" )
55+ self .result_processors .append (bytes .decode )
56+ return self
57+
58+ def tensorget (self ,
59+ key : AnyStr , as_numpy : bool = True ,
60+ meta_only : bool = False ) -> Any :
61+ args = builder .tensorget (key , as_numpy , meta_only )
62+ self .commands .extend (args )
63+ self .commands .append ("|>" )
64+ self .result_processors .append (partial (utils .tensorget_postprocessor ,
65+ as_numpy ,
66+ meta_only ))
67+ return self
68+
69+ def modelrun (self ,
70+ name : AnyStr ,
71+ inputs : Union [AnyStr , List [AnyStr ]],
72+ outputs : Union [AnyStr , List [AnyStr ]]) -> Any :
73+ args = builder .modelrun (name , inputs , outputs )
74+ self .commands .extend (args )
75+ self .commands .append ("|>" )
76+ self .result_processors .append (bytes .decode )
77+ return self
78+
79+ def run (self ):
80+ results = self .executor (* self .commands )
81+ out = []
82+ for res , fn in zip (results , self .result_processors ):
83+ out .append (fn (res ))
84+ return out
85+
2086
2187class Client (StrictRedis ):
2288 """
@@ -27,6 +93,11 @@ def __init__(self, debug=False, *args, **kwargs):
2793 if debug :
2894 self .execute_command = enable_debug (super ().execute_command )
2995
96+ def dag (self , load : Sequence = None , persist : Sequence = None ,
97+ readonly : bool = False ) -> Dag :
98+ """ Special function to return a dag object """
99+ return Dag (load , persist , self .execute_command , readonly )
100+
30101 def loadbackend (self , identifier : AnyStr , path : AnyStr ) -> str :
31102 """
32103 RedisAI by default won't load any backends. User can either explicitly
@@ -37,7 +108,8 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
37108 :param path: Path to the shared object of the backend
38109 :return: byte string represents success or failure
39110 """
40- return self .execute_command ('AI.CONFIG LOADBACKEND' , identifier , path ).decode ()
111+ args = builder .loadbackend (identifier , path )
112+ return self .execute_command (* args ).decode ()
41113
42114 def modelset (self ,
43115 name : AnyStr ,
@@ -46,9 +118,9 @@ def modelset(self,
46118 data : ByteString ,
47119 batch : int = None ,
48120 minbatch : int = None ,
49- tag : str = None ,
50- inputs : List [AnyStr ] = None ,
51- outputs : List [AnyStr ] = None ) -> str :
121+ tag : AnyStr = None ,
122+ inputs : Union [ AnyStr , List [AnyStr ] ] = None ,
123+ outputs : Union [ AnyStr , List [AnyStr ] ] = None ) -> str :
52124 """
53125 Set the model on provided key.
54126 :param name: str, Key name
@@ -66,50 +138,32 @@ def modelset(self,
66138
67139 :return:
68140 """
69- args = ['AI.MODELSET' , name , backend , device ]
70-
71- if batch is not None :
72- args += ['BATCHSIZE' , batch ]
73- if minbatch is not None :
74- args += ['MINBATCHSIZE' , minbatch ]
75- if tag is not None :
76- args += ['TAG' , tag ]
77-
78- if backend .upper () == 'TF' :
79- if not (all ((inputs , outputs ))):
80- raise ValueError (
81- 'Require keyword arguments input and output for TF models' )
82- args += ['INPUTS' ] + utils .listify (inputs )
83- args += ['OUTPUTS' ] + utils .listify (outputs )
84- args .append (data )
141+ args = builder .modelset (name , backend , device , data ,
142+ batch , minbatch , tag , inputs , outputs )
85143 return self .execute_command (* args ).decode ()
86144
87145 def modelget (self , name : AnyStr , meta_only = False ) -> dict :
88- args = ['AI.MODELGET' , name , 'META' ]
89- if not meta_only :
90- args .append ('BLOB' )
146+ args = builder .modelget (name , meta_only )
91147 rv = self .execute_command (* args )
92148 return utils .list2dict (rv )
93149
94150 def modeldel (self , name : AnyStr ) -> str :
95- return self .execute_command ('AI.MODELDEL' , name ).decode ()
151+ args = builder .modeldel (name )
152+ return self .execute_command (* args ).decode ()
96153
97154 def modelrun (self ,
98155 name : AnyStr ,
99- inputs : List [AnyStr ],
100- outputs : List [AnyStr ]
101- ) -> str :
102- out = self .execute_command (
103- 'AI.MODELRUN' , name ,
104- 'INPUTS' , * utils .listify (inputs ),
105- 'OUTPUTS' , * utils .listify (outputs )
106- )
107- return out .decode ()
156+ inputs : Union [AnyStr , List [AnyStr ]],
157+ outputs : Union [AnyStr , List [AnyStr ]]) -> str :
158+ args = builder .modelrun (name , inputs , outputs )
159+ return self .execute_command (* args ).decode ()
108160
109161 def modelscan (self ) -> list :
110162 warnings .warn ("Experimental: Model List API is experimental and might change "
111163 "in the future without any notice" , UserWarning )
112- return utils .un_bytize (self .execute_command ("AI._MODELSCAN" ), lambda x : x .decode ())
164+ args = builder .modelscan ()
165+ result = self .execute_command (* args )
166+ return utils .recursive_bytetransform (result , lambda x : x .decode ())
113167
114168 def tensorset (self ,
115169 key : AnyStr ,
@@ -123,20 +177,11 @@ def tensorset(self,
123177 :param shape: Shape of the tensor. Required if `tensor` is list or tuple
124178 :param dtype: data type of the tensor. Required if `tensor` is list or tuple
125179 """
126- if np and isinstance (tensor , np .ndarray ):
127- dtype , shape , blob = utils .numpy2blob (tensor )
128- args = ['AI.TENSORSET' , key , dtype , * shape , 'BLOB' , blob ]
129- elif isinstance (tensor , (list , tuple )):
130- if shape is None :
131- shape = (len (tensor ),)
132- args = ['AI.TENSORSET' , key , dtype , * shape , 'VALUES' , * tensor ]
133- else :
134- raise TypeError (f"``tensor`` argument must be a numpy array or a list or a "
135- f"tuple, but got { type (tensor )} " )
180+ args = builder .tensorset (key , tensor , shape , dtype )
136181 return self .execute_command (* args ).decode ()
137182
138183 def tensorget (self ,
139- key : str , as_numpy : bool = True ,
184+ key : AnyStr , as_numpy : bool = True ,
140185 meta_only : bool = False ) -> Union [dict , np .ndarray ]:
141186 """
142187 Retrieve the value of a tensor from the server. By default it returns the numpy array
@@ -149,63 +194,45 @@ def tensorget(self,
149194 only the shape and the type
150195 :return: an instance of as_type
151196 """
152- args = ['AI.TENSORGET' , key , 'META' ]
153- if not meta_only :
154- if as_numpy is True :
155- args .append ('BLOB' )
156- else :
157- args .append ('VALUES' )
158-
197+ args = builder .tensorget (key , as_numpy , meta_only )
159198 res = self .execute_command (* args )
160- res = utils .list2dict (res )
161- if meta_only :
162- return res
163- elif as_numpy is True :
164- return utils .blob2numpy (res ['blob' ], res ['shape' ], res ['dtype' ])
165- else :
166- target = float if res ['dtype' ] in ('FLOAT' , 'DOUBLE' ) else int
167- utils .un_bytize (res ['values' ], target )
168- return res
169-
170- def scriptset (self , name : str , device : str , script : str , tag : str = None ) -> str :
171- args = ['AI.SCRIPTSET' , name , device ]
172- if tag :
173- args += ['TAG' , tag ]
174- args .append (script )
199+ return utils .tensorget_postprocessor (as_numpy , meta_only , res )
200+
201+ def scriptset (self , name : AnyStr , device : str , script : str , tag : AnyStr = None ) -> str :
202+ args = builder .scriptset (name , device , script , tag )
175203 return self .execute_command (* args ).decode ()
176204
177205 def scriptget (self , name : AnyStr , meta_only = False ) -> dict :
178206 # TODO scripget test
179- args = ['AI.SCRIPTGET' , name , 'META' ]
180- if not meta_only :
181- args .append ('SOURCE' )
207+ args = builder .scriptget (name , meta_only )
182208 ret = self .execute_command (* args )
183209 return utils .list2dict (ret )
184210
185- def scriptdel (self , name : str ) -> str :
186- return self .execute_command ('AI.SCRIPTDEL' , name ).decode ()
211+ def scriptdel (self , name : AnyStr ) -> str :
212+ args = builder .scriptdel (name )
213+ return self .execute_command (* args ).decode ()
187214
188215 def scriptrun (self ,
189216 name : AnyStr ,
190217 function : AnyStr ,
191218 inputs : Union [AnyStr , Sequence [AnyStr ]],
192219 outputs : Union [AnyStr , Sequence [AnyStr ]]
193- ) -> AnyStr :
194- out = self .execute_command (
195- 'AI.SCRIPTRUN' , name , function ,
196- 'INPUTS' , * utils .listify (inputs ),
197- 'OUTPUTS' , * utils .listify (outputs )
198- )
220+ ) -> str :
221+ args = builder .scriptrun (name , function , inputs , outputs )
222+ out = self .execute_command (* args )
199223 return out .decode ()
200224
201225 def scriptscan (self ) -> list :
202226 warnings .warn ("Experimental: Script List API is experimental and might change "
203227 "in the future without any notice" , UserWarning )
204- return utils .un_bytize (self .execute_command ("AI._SCRIPTSCAN" ), lambda x : x .decode ())
228+ args = builder .scriptscan ()
229+ return utils .recursive_bytetransform (self .execute_command (* args ), lambda x : x .decode ())
205230
206- def infoget (self , key : str ) -> dict :
207- ret = self .execute_command ('AI.INFO' , key )
231+ def infoget (self , key : AnyStr ) -> dict :
232+ args = builder .infoget (key )
233+ ret = self .execute_command (* args )
208234 return utils .list2dict (ret )
209235
210- def inforeset (self , key : str ) -> str :
211- return self .execute_command ('AI.INFO' , key , 'RESETSTAT' ).decode ()
236+ def inforeset (self , key : AnyStr ) -> str :
237+ args = builder .inforeset (key )
238+ return self .execute_command (* args ).decode ()
0 commit comments