Skip to content

Commit 61f587a

Browse files
committed
Some fixes to make model files work.
1 parent 6d49c2d commit 61f587a

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/rnn_reader.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
#include "rnn_data.h"
3737
#include "rnnoise.h"
3838

39+
/* Although these values are the same as in rnn.h, we make them separate to
40+
* avoid accidentally burning internal values into a file format */
41+
#define F_ACTIVATION_TANH 0
42+
#define F_ACTIVATION_SIGMOID 1
43+
#define F_ACTIVATION_RELU 2
44+
3945
RNNModel *rnnoise_model_from_file(FILE *f)
4046
{
4147
int i, in;
@@ -71,6 +77,21 @@ RNNModel *rnnoise_model_from_file(FILE *f)
7177
name = in; \
7278
} while (0)
7379

80+
#define INPUT_ACTIVATION(name) do { \
81+
int activation; \
82+
INPUT_VAL(activation); \
83+
switch (activation) { \
84+
case F_ACTIVATION_SIGMOID: \
85+
name = ACTIVATION_SIGMOID; \
86+
break; \
87+
case F_ACTIVATION_RELU: \
88+
name = ACTIVATION_RELU; \
89+
break; \
90+
default: \
91+
name = ACTIVATION_TANH; \
92+
} \
93+
} while (0)
94+
7495
#define INPUT_ARRAY(name, len) do { \
7596
rnn_weight *values = malloc((len) * sizeof(rnn_weight)); \
7697
if (!values) { \
@@ -90,13 +111,15 @@ RNNModel *rnnoise_model_from_file(FILE *f)
90111
#define INPUT_DENSE(name) do { \
91112
INPUT_VAL(name->nb_inputs); \
92113
INPUT_VAL(name->nb_neurons); \
114+
INPUT_ACTIVATION(name->activation); \
93115
INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
94116
INPUT_ARRAY(name->bias, name->nb_neurons); \
95117
} while (0)
96118

97119
#define INPUT_GRU(name) do { \
98120
INPUT_VAL(name->nb_inputs); \
99121
INPUT_VAL(name->nb_neurons); \
122+
INPUT_ACTIVATION(name->activation); \
100123
INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons * 3); \
101124
INPUT_ARRAY(name->recurrent_weights, name->nb_neurons * name->nb_neurons * 3); \
102125
INPUT_ARRAY(name->bias, name->nb_neurons * 3); \

training/dump_rnn.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,22 @@ def printVector(f, ft, vector, name):
3535

3636
def printLayer(f, ft, layer):
3737
weights = layer.get_weights()
38+
activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
3839
if len(weights) > 2:
39-
ft.write('{} {}\n'.format(weights[0].shape[0], weights[0].shape[1]/3))
40+
ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]/3))
41+
else:
42+
ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]))
43+
if activation == 'SIGMOID':
44+
ft.write('1\n')
45+
elif activation == 'RELU':
46+
ft.write('2\n')
4047
else:
41-
ft.write('{} {}\n'.format(weights[0].shape[0], weights[0].shape[1]))
48+
ft.write('0\n')
4249
printVector(f, ft, weights[0], layer.name + '_weights')
4350
if len(weights) > 2:
4451
printVector(f, ft, weights[1], layer.name + '_recurrent_weights')
4552
printVector(f, ft, weights[-1], layer.name + '_bias')
4653
name = layer.name
47-
activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
4854
if len(weights) > 2:
4955
f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
5056
.format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation))
@@ -77,7 +83,7 @@ def mean_squared_sqrt_error(y_true, y_pred):
7783
ft = open(sys.argv[3], 'w')
7884

7985
f.write('/*This file is automatically generated from a Keras model*/\n\n')
80-
f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n')
86+
f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n')
8187
ft.write('rnnoise-nu model file version 1\n')
8288

8389
layer_list = []

0 commit comments

Comments
 (0)