Skip to content

Commit 6d49c2d

Browse files
committed
Neural network model files
Extending the neural network dumper to dump to a simple text file format, and adding reader functions to read a neural network description from a FILE *.
1 parent d2071d9 commit 6d49c2d

File tree

4 files changed

+167
-8
lines changed

4 files changed

+167
-8
lines changed

Makefile.am

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ librnnoise_la_SOURCES = \
2222
src/denoise.c \
2323
src/rnn.c \
2424
src/rnn_data.c \
25+
src/rnn_reader.c \
2526
src/pitch.c \
2627
src/kiss_fft.c \
2728
src/celt_lpc.c

include/rnnoise.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#ifndef RNNOISE_H
2828
#define RNNOISE_H 1
2929

30+
#include <stdio.h>
31+
32+
3033
#ifndef RNNOISE_EXPORT
3134
# if defined(WIN32)
3235
# if defined(RNNOISE_BUILD) && defined(DLL_EXPORT)
@@ -41,7 +44,6 @@
4144
# endif
4245
#endif
4346

44-
4547
typedef struct DenoiseState DenoiseState;
4648
typedef struct RNNModel RNNModel;
4749

@@ -55,4 +57,8 @@ RNNOISE_EXPORT void rnnoise_destroy(DenoiseState *st);
5557

5658
RNNOISE_EXPORT float rnnoise_process_frame(DenoiseState *st, float *out, const float *in);
5759

60+
RNNOISE_EXPORT RNNModel *rnnoise_model_from_file(FILE *f);
61+
62+
RNNOISE_EXPORT void rnnoise_model_free(RNNModel *model);
63+
5864
#endif

src/rnn_reader.c

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/* Copyright (c) 2018 Gregor Richards */
2+
/*
3+
Redistribution and use in source and binary forms, with or without
4+
modification, are permitted provided that the following conditions
5+
are met:
6+
7+
- Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
10+
- Redistributions in binary form must reproduce the above copyright
11+
notice, this list of conditions and the following disclaimer in the
12+
documentation and/or other materials provided with the distribution.
13+
14+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15+
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
18+
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
21+
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
22+
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23+
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
*/
26+
27+
#ifdef HAVE_CONFIG_H
28+
#include "config.h"
29+
#endif
30+
31+
#include <stdio.h>
32+
#include <stdlib.h>
33+
#include <sys/types.h>
34+
35+
#include "rnn.h"
36+
#include "rnn_data.h"
37+
#include "rnnoise.h"
38+
39+
RNNModel *rnnoise_model_from_file(FILE *f)
40+
{
41+
int i, in;
42+
43+
if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
44+
return NULL;
45+
46+
RNNModel *ret = calloc(1, sizeof(RNNModel));
47+
if (!ret)
48+
return NULL;
49+
50+
#define ALLOC_LAYER(type, name) \
51+
type *name; \
52+
name = calloc(1, sizeof(type)); \
53+
if (!name) { \
54+
rnnoise_model_free(ret); \
55+
return NULL; \
56+
} \
57+
ret->name = name
58+
59+
ALLOC_LAYER(DenseLayer, input_dense);
60+
ALLOC_LAYER(GRULayer, vad_gru);
61+
ALLOC_LAYER(GRULayer, noise_gru);
62+
ALLOC_LAYER(GRULayer, denoise_gru);
63+
ALLOC_LAYER(DenseLayer, denoise_output);
64+
ALLOC_LAYER(DenseLayer, vad_output);
65+
66+
#define INPUT_VAL(name) do { \
67+
if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
68+
rnnoise_model_free(ret); \
69+
return NULL; \
70+
} \
71+
name = in; \
72+
} while (0)
73+
74+
#define INPUT_ARRAY(name, len) do { \
75+
rnn_weight *values = malloc((len) * sizeof(rnn_weight)); \
76+
if (!values) { \
77+
rnnoise_model_free(ret); \
78+
return NULL; \
79+
} \
80+
name = values; \
81+
for (i = 0; i < (len); i++) { \
82+
if (fscanf(f, "%d", &in) != 1) { \
83+
rnnoise_model_free(ret); \
84+
return NULL; \
85+
} \
86+
values[i] = in; \
87+
} \
88+
} while (0)
89+
90+
#define INPUT_DENSE(name) do { \
91+
INPUT_VAL(name->nb_inputs); \
92+
INPUT_VAL(name->nb_neurons); \
93+
INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
94+
INPUT_ARRAY(name->bias, name->nb_neurons); \
95+
} while (0)
96+
97+
#define INPUT_GRU(name) do { \
98+
INPUT_VAL(name->nb_inputs); \
99+
INPUT_VAL(name->nb_neurons); \
100+
INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons * 3); \
101+
INPUT_ARRAY(name->recurrent_weights, name->nb_neurons * name->nb_neurons * 3); \
102+
INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
103+
} while (0)
104+
105+
INPUT_DENSE(input_dense);
106+
INPUT_GRU(vad_gru);
107+
INPUT_GRU(noise_gru);
108+
INPUT_GRU(denoise_gru);
109+
INPUT_DENSE(denoise_output);
110+
INPUT_DENSE(vad_output);
111+
112+
return ret;
113+
}
114+
115+
void rnnoise_model_free(RNNModel *model)
116+
{
117+
#define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
118+
#define FREE_DENSE(name) do { \
119+
if (model->name) { \
120+
free((void *) model->name->input_weights); \
121+
free((void *) model->name->bias); \
122+
free((void *) model->name); \
123+
} \
124+
} while (0)
125+
#define FREE_GRU(name) do { \
126+
if (model->name) { \
127+
free((void *) model->name->input_weights); \
128+
free((void *) model->name->recurrent_weights); \
129+
free((void *) model->name->bias); \
130+
free((void *) model->name); \
131+
} \
132+
} while (0)
133+
134+
if (!model)
135+
return;
136+
FREE_DENSE(input_dense);
137+
FREE_GRU(vad_gru);
138+
FREE_GRU(noise_gru);
139+
FREE_GRU(denoise_gru);
140+
FREE_DENSE(denoise_output);
141+
FREE_DENSE(vad_output);
142+
free(model);
143+
}

training/dump_rnn.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,37 @@
1212
import re
1313
import numpy as np
1414

15-
def printVector(f, vector, name):
15+
def printVector(f, ft, vector, name):
1616
v = np.reshape(vector, (-1));
1717
#print('static const float ', name, '[', len(v), '] = \n', file=f)
1818
f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(v)))
1919
for i in range(0, len(v)):
2020
f.write('{}'.format(min(127, int(round(256*v[i])))))
21+
ft.write('{}'.format(min(127, int(round(256*v[i])))))
2122
if (i!=len(v)-1):
2223
f.write(',')
2324
else:
2425
break;
26+
ft.write(" ")
2527
if (i%8==7):
2628
f.write("\n ")
2729
else:
2830
f.write(" ")
2931
#print(v, file=f)
3032
f.write('\n};\n\n')
33+
ft.write("\n")
3134
return;
3235

33-
def printLayer(f, layer):
36+
def printLayer(f, ft, layer):
3437
weights = layer.get_weights()
35-
printVector(f, weights[0], layer.name + '_weights')
3638
if len(weights) > 2:
37-
printVector(f, weights[1], layer.name + '_recurrent_weights')
38-
printVector(f, weights[-1], layer.name + '_bias')
39+
ft.write('{} {}\n'.format(weights[0].shape[0], weights[0].shape[1]/3))
40+
else:
41+
ft.write('{} {}\n'.format(weights[0].shape[0], weights[0].shape[1]))
42+
printVector(f, ft, weights[0], layer.name + '_weights')
43+
if len(weights) > 2:
44+
printVector(f, ft, weights[1], layer.name + '_recurrent_weights')
45+
printVector(f, ft, weights[-1], layer.name + '_bias')
3946
name = layer.name
4047
activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
4148
if len(weights) > 2:
@@ -67,18 +74,20 @@ def mean_squared_sqrt_error(y_true, y_pred):
6774
weights = model.get_weights()
6875

6976
f = open(sys.argv[2], 'w')
77+
ft = open(sys.argv[3], 'w')
7078

7179
f.write('/*This file is automatically generated from a Keras model*/\n\n')
7280
f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n')
81+
ft.write('rnnoise-nu model file version 1\n')
7382

7483
layer_list = []
7584
for i, layer in enumerate(model.layers):
7685
if len(layer.get_weights()) > 0:
77-
printLayer(f, layer)
86+
printLayer(f, ft, layer)
7887
if len(layer.get_weights()) > 2:
7988
layer_list.append(layer.name)
8089

81-
f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(sys.argv[3]))
90+
f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(sys.argv[4]))
8291
for i, layer in enumerate(model.layers):
8392
if len(layer.get_weights()) > 0:
8493
structLayer(f, layer)

0 commit comments

Comments
 (0)