@@ -35,16 +35,22 @@ def printVector(f, ft, vector, name):
3535
3636def printLayer (f , ft , layer ):
3737 weights = layer .get_weights ()
38+ activation = re .search ('function (.*) at' , str (layer .activation )).group (1 ).upper ()
3839 if len (weights ) > 2 :
39- ft .write ('{} {}\n ' .format (weights [0 ].shape [0 ], weights [0 ].shape [1 ]/ 3 ))
40+ ft .write ('{} {} ' .format (weights [0 ].shape [0 ], weights [0 ].shape [1 ]/ 3 ))
41+ else :
42+ ft .write ('{} {} ' .format (weights [0 ].shape [0 ], weights [0 ].shape [1 ]))
43+ if activation == 'SIGMOID' :
44+ ft .write ('1\n ' )
45+ elif activation == 'RELU' :
46+ ft .write ('2\n ' )
4047 else :
41- ft .write ('{} {} \n ' . format ( weights [ 0 ]. shape [ 0 ], weights [ 0 ]. shape [ 1 ]) )
48+ ft .write ('0 \n ' )
4249 printVector (f , ft , weights [0 ], layer .name + '_weights' )
4350 if len (weights ) > 2 :
4451 printVector (f , ft , weights [1 ], layer .name + '_recurrent_weights' )
4552 printVector (f , ft , weights [- 1 ], layer .name + '_bias' )
4653 name = layer .name
47- activation = re .search ('function (.*) at' , str (layer .activation )).group (1 ).upper ()
4854 if len (weights ) > 2 :
4955 f .write ('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n }};\n \n '
5056 .format (name , name , name , name , weights [0 ].shape [0 ], weights [0 ].shape [1 ]/ 3 , activation ))
@@ -77,7 +83,7 @@ def mean_squared_sqrt_error(y_true, y_pred):
7783ft = open (sys .argv [3 ], 'w' )
7884
7985f .write ('/*This file is automatically generated from a Keras model*/\n \n ' )
80- f .write ('#ifdef HAVE_CONFIG_H\n #include "config.h"\n #endif\n \n #include "rnn.h"\n \n ' )
86+ f .write ('#ifdef HAVE_CONFIG_H\n #include "config.h"\n #endif\n \n #include "rnn.h"\n #include "rnn_data.h" \n \n ' )
8187ft .write ('rnnoise-nu model file version 1\n ' )
8288
8389layer_list = []
0 commit comments