1- import pickle
21import os
32import warnings
4-
5- from .client import Device , Backend
3+ import sys
64
75try :
86 import tensorflow as tf
97except (ModuleNotFoundError , ImportError ):
10- pass # that's Okey if you don't have TF
8+ pass
119
1210try :
1311 import torch
1412except (ModuleNotFoundError , ImportError ):
15- pass # it's Okey if you don't have PT either
13+ pass
1614
15+ try :
16+ import onnx
17+ except (ModuleNotFoundError , ImportError ):
18+ pass
19+
20+ try :
21+ import skl2onnx
22+ import sklearn
23+ except (ModuleNotFoundError , ImportError ):
24+ pass
1725
1826
1927class Model :
2028
2129 __slots__ = ['graph' , 'backend' , 'device' , 'inputs' , 'outputs' ]
22-
23- def __init__ (self , path , device = Device . cpu , inputs = None , outputs = None ):
30+
31+ def __init__ (self , path , device = None , inputs = None , outputs = None ):
2432 """
2533 Declare a model suitable for passing to modelset
2634 :param path: Filepath from where the stored model can be read
@@ -37,9 +45,9 @@ def __init__(self, path, device=Device.cpu, inputs=None, outputs=None):
3745 raise NotImplementedError ('Instance creation is not impelemented yet' )
3846
3947 @classmethod
40- def save (cls , obj , path : str , input = None , output = None , as_native = True ):
48+ def save (cls , obj , path : str , input = None , output = None , as_native = True , prototype = None ):
4149 """
42- Infer the backend (TF/PyTorch) by inspecting the class hierarchy
50+ Infer the backend (TF/PyTorch/ONNX ) by inspecting the class hierarchy
4351 and calls the appropriate serialization utility. It is essentially a
4452 wrapper over serialization mechanism of each backend
4553 :param path: Path to which the graph/model will be saved
@@ -54,15 +62,25 @@ def save(cls, obj, path: str, input=None, output=None, as_native=True):
5462 mechanism if True. If False, custom saving utility will be called
5563 which saves other informations required for modelset. Defaults to True
5664 """
57- if issubclass (type (obj ), tf .Session ):
65+ if 'tensorflow' in sys . modules and issubclass (type (obj ), tf .Session ):
5866 cls ._save_tf_graph (obj , path , output , as_native )
59- elif issubclass (type (type (obj )), torch .jit .ScriptMeta ):
67+ elif 'torch' in sys .modules and issubclass (
68+ type (type (obj )), torch .jit .ScriptMeta ):
6069 # TODO Is there a better way to check this
61- cls ._save_pt_graph (obj , path , as_native )
70+ cls ._save_torch_graph (obj , path , as_native )
71+ elif 'onnx' in sys .modules and issubclass (
72+ type (obj ), onnx .onnx_ONNX_RELEASE_ml_pb2 .ModelProto ):
73+ cls ._save_onnx_graph (obj , path , as_native )
74+ elif 'skl2onnx' in sys .modules and issubclass (
75+ type (obj ), sklearn .base .BaseEstimator ):
76+ cls ._save_sklearn_graph (obj , path , as_native , prototype )
6277 else :
63- raise TypeError (('Invalid Object. '
64- 'Need traced graph or scripted graph from PyTorch or '
65- 'Session object from Tensorflow' ))
78+ message = ("Could not find the required dependancy to export the graph object. "
79+ "`save_model` relies on serialization mechanism provided by the"
80+ " supported backends such as Tensorflow, PyTorch, ONNX or skl2onnx. "
81+ "Please install package required for serializing your graph. "
82+ "For more information, checkout the redisia-py documentation" )
83+ raise RuntimeError (message )
6684
6785 @classmethod
6886 def _save_tf_graph (cls , sess , path , output , as_native ):
@@ -81,10 +99,10 @@ def _save_tf_graph(cls, sess, path, output, as_native):
8199 raise NotImplementedError ('Saving non-native graph is not supported yet' )
82100
83101 @classmethod
84- def _save_pt_graph (cls , graph , path , as_native ):
102+ def _save_torch_graph (cls , graph , path , as_native ):
85103 # TODO how to handle the cpu/gpu
86104 if as_native :
87- if graph .training == True :
105+ if graph .training is True :
88106 warnings .warn (
89107 'Graph is in training mode. Converting to evaluation mode' )
90108 graph .eval ()
@@ -93,25 +111,33 @@ def _save_pt_graph(cls, graph, path, as_native):
93111 else :
94112 raise NotImplementedError ('Saving non-native graph is not supported yet' )
95113
96- @staticmethod
97- def _get_filled_dict ( graph , backend , input = None , output = None ):
98- return {
99- 'graph' : graph ,
100- 'backend' : backend ,
101- 'input' : input ,
102- 'output' : output }
114+ @classmethod
115+ def _save_onnx_graph ( cls , graph , path , as_native ):
116+ if as_native :
117+ with open ( path , 'wb' ) as f :
118+ f . write ( graph . SerializeToString ())
119+ else :
120+ raise NotImplementedError ( 'Saving non-native graph is not supported yet' )
103121
104- @staticmethod
105- def _write_custom_model (outdict , path ):
106- with open (path , 'wb' ) as file :
107- pickle .dump (outdict , file )
122+ @classmethod
123+ def _save_sklearn_graph (cls , graph , path , as_native , prototype ):
124+ if not as_native :
125+ raise NotImplementedError ('Saving non-native graph is not supported yet' )
126+ if hasattr (prototype , 'shape' ) and hasattr (prototype , 'dtype' ):
127+ datatype = skl2onnx .common .data_types .guess_data_type (prototype )
128+ serialized = skl2onnx .convert_sklearn (graph , initial_types = datatype )
129+ cls ._save_onnx_graph (serialized , path , as_native )
130+ else :
131+ raise TypeError (
132+ "Serializing scikit learn model needs to know shape and dtype"
133+ " of input data which will be inferred from `prototype` "
134+ "parameter. It has to be a valid `numpy.ndarray` of shape of your input" )
108135
109136 @classmethod
110- def load (cls , path :str ):
137+ def load (cls , path : str ):
111138 """
112139 Return the binary data if saved with `as_native` otherwise return the dict
113- that contains binary graph/model on `graph` key. Check `_get_filled_dict`
114- for more details.
140+ that contains binary graph/model on `graph` key (Not implemented yet).
115141 :param path: File path from where the native model or the rai models are saved
116142 """
117143 with open (path , 'rb' ) as f :
0 commit comments