Skip to content

Commit 04053a5

Browse files
authored
Merge pull request #147 from MatthewScholefield/feature/documentation
Add more docstrings
2 parents 20be762 + 822b098 commit 04053a5

27 files changed

+533
-381
lines changed

precise/functions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
Mathematical functions used to customize
16+
computation in various places
17+
"""
1418
from math import exp, log, sqrt, pi
1519
import numpy as np
1620
from typing import *
@@ -20,6 +24,11 @@
2024

2125
def set_loss_bias(bias: float):
2226
"""
27+
Changes the loss bias
28+
29+
This allows customizing the acceptable tolerance between
30+
false negatives and false positives
31+
2332
Near 1.0 reduces false positives
2433
Near 0.0 reduces false negatives
2534
"""
@@ -42,6 +51,7 @@ def weighted_log_loss(yt, yp) -> Any:
4251

4352

4453
def weighted_mse_loss(yt, yp) -> Any:
54+
"""Standard mse loss with a weighting between false negatives and positives"""
4555
from keras import backend as K
4656

4757
total = K.sum(K.ones_like(yt))
@@ -52,16 +62,27 @@ def weighted_mse_loss(yt, yp) -> Any:
5262

5363

5464
def false_pos(yt, yp) -> Any:
65+
"""
66+
Metric for Keras that *estimates* false positives while training
67+
This will not be completely accurate because it weights batches
68+
equally
69+
"""
5570
from keras import backend as K
5671
return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(1 - yt))
5772

5873

5974
def false_neg(yt, yp) -> Any:
75+
"""
76+
Metric for Keras that *estimates* false negatives while training
77+
This will not be completely accurate because it weights batches
78+
equally
79+
"""
6080
from keras import backend as K
6181
return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(0 + yt))
6282

6383

6484
def load_keras() -> Any:
85+
"""Imports Keras injecting custom functions to prevent exceptions"""
6586
import keras
6687
keras.losses.weighted_log_loss = weighted_log_loss
6788
keras.metrics.false_pos = false_pos

precise/model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
Loads model
16+
"""
1417
import attr
1518
from os.path import isfile
1619
from typing import *
@@ -26,17 +29,20 @@
2629
class ModelParams:
2730
"""
2831
Attributes:
29-
recurrent_units:
30-
dropout:
31-
extra_metrics: Whether to include false positive and false negative metrics
32+
recurrent_units: Number of GRU units. Higher values increase computation
33+
but allow more complex learning. Too high of a value causes overfitting
34+
dropout: Reduces overfitting but can potentially decrease accuracy if too high
35+
extra_metrics: Whether to include false positive and false negative metrics while training
3236
skip_acc: Whether to skip accuracy calculation while training
37+
loss_bias: Near 1.0 reduces false positives. See <set_loss_bias>
38+
freeze_till: Layer number from start to freeze after loading (allows for partial training)
3339
"""
3440
recurrent_units = attr.ib(20) # type: int
3541
dropout = attr.ib(0.2) # type: float
3642
extra_metrics = attr.ib(False) # type: bool
3743
skip_acc = attr.ib(False) # type: bool
3844
loss_bias = attr.ib(0.7) # type: float
39-
freeze_till = attr.ib(0) # type: bool
45+
freeze_till = attr.ib(0) # type: int
4046

4147

4248
def load_precise_model(model_name: str) -> Any:
@@ -70,7 +76,8 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
7076
model = Sequential()
7177
model.add(GRU(
7278
params.recurrent_units, activation='linear',
73-
input_shape=(pr.n_features, pr.feature_size), dropout=params.dropout, name='net'
79+
input_shape=(
80+
pr.n_features, pr.feature_size), dropout=params.dropout, name='net'
7481
))
7582
model.add(Dense(1, activation='sigmoid'))
7683

@@ -79,5 +86,6 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
7986
set_loss_bias(params.loss_bias)
8087
for i in model.layers[:params.freeze_till]:
8188
i.trainable = False
82-
model.compile('rmsprop', weighted_log_loss, metrics=(not params.skip_acc) * metrics)
89+
model.compile('rmsprop', weighted_log_loss,
90+
metrics=(not params.skip_acc) * metrics)
8391
return model

precise/network_runner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
Pieces that convert audio to predictions
16+
"""
1417
import numpy as np
1518
from abc import abstractmethod, ABCMeta
1619
from importlib import import_module
@@ -26,6 +29,10 @@
2629

2730

2831
class Runner(metaclass=ABCMeta):
32+
"""
33+
Classes that execute trained models on vectorized audio
34+
and produce prediction values
35+
"""
2936
@abstractmethod
3037
def predict(self, inputs: np.ndarray) -> np.ndarray:
3138
pass
@@ -36,6 +43,7 @@ def run(self, inp: np.ndarray) -> float:
3643

3744

3845
class TensorFlowRunner(Runner):
46+
"""Executes a frozen Tensorflow model created from precise-convert"""
3947
def __init__(self, model_name: str):
4048
if model_name.endswith('.net'):
4149
print('Warning: ', model_name, 'looks like a Keras model.')
@@ -67,6 +75,7 @@ def run(self, inp: np.ndarray) -> float:
6775

6876

6977
class KerasRunner(Runner):
78+
""" Executes a regular Keras model created from precise-train"""
7079
def __init__(self, model_name: str):
7180
import tensorflow as tf
7281
# ISSUE 88 - Following 3 lines added to resolve issue 88 - JM 2020-02-04 per liny90626

precise/params.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
Parameters used in the audio pipeline
16+
These configure the following stages:
17+
- Conversion from audio to input vectors
18+
- Interpretation of the network output to a confidence value
19+
"""
1420
from math import floor
1521

1622
import attr
@@ -21,42 +27,78 @@
2127

2228
@attr.s(frozen=True)
2329
class ListenerParams:
30+
"""
31+
General pipeline information:
32+
- Audio goes through a series of transformations to convert raw audio into machine readable data
33+
- These transformations are as follows:
34+
- Raw audio -> chopped audio
35+
- buffer_t, sample_depth: Input audio loaded and truncated using these value
36+
- window_t, hop_t: Linear audio chopped into overlapping frames using a sliding window
37+
- Chopped audio -> FFT spectrogram
38+
- n_fft, sample_rate: Each audio frame is converted to n_fft frequency intensities
39+
- FFT spectrogram -> Mel spectrogram (compressed)
40+
- n_filt: Each fft frame is compressed to n_filt summarized mel frequency bins/bands
41+
- Mel spectrogram -> MFCC
42+
- n_mfcc: Each mel frame is converted to MFCCs and the first n_mfcc values are taken
43+
- Disabled by default: Last phase -> Delta vectors
44+
- use_delta: If this value is true, the difference between consecutive vectors is concatenated to each frame
45+
46+
Parameters for audio pipeline:
47+
- buffer_t: Input size of audio. Wakeword must fit within this time
48+
- window_t: Time of the window used to calculate a single spectrogram frame
49+
- hop_t: Time the window advances forward to calculate the next spectrogram frame
50+
- sample_rate: Input audio sample rate
51+
- sample_depth: Bytes per input audio sample
52+
- n_fft: Size of FFT to generate from audio frame
53+
- n_filt: Number of filters to compress FFT to
54+
- n_mfcc: Number of MFCC coefficients to use
55+
- use_delta: If True, generates "delta vectors" before sending to network
56+
- vectorizer: The type of input fed into the network. Options listed in class Vectorizer
57+
- threshold_config: Output distribution configuration automatically generated from precise-calc-threshold
58+
- threshold_center: Output distribution center automatically generated from precise-calc-threshold
59+
"""
60+
buffer_t = attr.ib() # type: float
2461
window_t = attr.ib() # type: float
2562
hop_t = attr.ib() # type: float
26-
buffer_t = attr.ib() # type: float
2763
sample_rate = attr.ib() # type: int
2864
sample_depth = attr.ib() # type: int
29-
n_mfcc = attr.ib() # type: int
30-
n_filt = attr.ib() # type: int
3165
n_fft = attr.ib() # type: int
66+
n_filt = attr.ib() # type: int
67+
n_mfcc = attr.ib() # type: int
3268
use_delta = attr.ib() # type: bool
3369
vectorizer = attr.ib() # type: int
3470
threshold_config = attr.ib() # type: tuple
3571
threshold_center = attr.ib() # type: float
3672

3773
@property
3874
def buffer_samples(self):
75+
"""buffer_t converted to samples, truncating partial frames"""
3976
samples = int(self.sample_rate * self.buffer_t + 0.5)
4077
return self.hop_samples * (samples // self.hop_samples)
4178

4279
@property
4380
def n_features(self):
81+
"""Number of timesteps in one input to the network"""
4482
return 1 + int(floor((self.buffer_samples - self.window_samples) / self.hop_samples))
4583

4684
@property
4785
def window_samples(self):
86+
"""window_t converted to samples"""
4887
return int(self.sample_rate * self.window_t + 0.5)
4988

5089
@property
5190
def hop_samples(self):
91+
"""hop_t converted to samples"""
5292
return int(self.sample_rate * self.hop_t + 0.5)
5393

5494
@property
5595
def max_samples(self):
96+
"""The input size converted to audio samples"""
5697
return int(self.buffer_t * self.sample_rate)
5798

5899
@property
59100
def feature_size(self):
101+
"""The size of an input vector generated with these parameters"""
60102
num_features = {
61103
Vectorizer.mfccs: self.n_mfcc,
62104
Vectorizer.mels: self.n_filt,
@@ -77,15 +119,27 @@ def vectorization_md5_hash(self):
77119

78120

79121
class Vectorizer:
122+
"""
123+
Chooses which function to call to vectorize audio
124+
125+
Options:
126+
mels: Convert to a compressed Mel spectrogram
127+
mfccs: Convert to a MFCC spectrogram
128+
speechpy_mfccs: Legacy option to convert to MFCCs using old library
129+
"""
80130
mels = 1
81131
mfccs = 2
82132
speechpy_mfccs = 3
83133

84134

85135
# Global listener parameters
136+
# These are the default values for all parameters
137+
# These were selected tentatively to balance CPU usage with accuracy
138+
# For the Hey Mycroft wake word, small changes to these parameters
139+
# did not make a significant difference in accuracy
86140
pr = ListenerParams(
87-
window_t=0.1, hop_t=0.05, buffer_t=1.5, sample_rate=16000,
88-
sample_depth=2, n_mfcc=13, n_filt=20, n_fft=512, use_delta=False,
141+
buffer_t=1.5, window_t=0.1, hop_t=0.05, sample_rate=16000,
142+
sample_depth=2, n_fft=512, n_filt=20, n_mfcc=13, use_delta=False,
89143
threshold_config=((6, 4),), threshold_center=0.2, vectorizer=Vectorizer.mfccs
90144
)
91145

precise/pocketsphinx/listener.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
"""
16+
Conversion of audio data to predictions using Pocketsphinx
17+
Used for comparison with Precise
18+
"""
1519
import numpy as np
1620
from typing import *
1721
from typing import BinaryIO

precise/pocketsphinx/scripts/listen.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,24 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
"""
16+
Run Pocketsphinx on microphone audio input
17+
18+
:key_phrase str
19+
Key phrase composed of words from dictionary
20+
21+
:dict_file str
22+
Filename of dictionary with word pronunciations
23+
24+
:hmm_folder str
25+
Folder containing hidden markov model
26+
27+
:-th --threshold str 1e-90
28+
Threshold for activations
29+
30+
:-c --chunk-size int 2048
31+
Samples between inferences
32+
"""
1533
from precise_runner import PreciseRunner
1634
from precise_runner.runner import ListenerEngine
1735
from prettyparse import Usage
@@ -23,24 +41,7 @@
2341

2442

2543
class PocketsphinxListenScript(BaseScript):
26-
usage = Usage('''
27-
Run Pocketsphinx on microphone audio input
28-
29-
:key_phrase str
30-
Key phrase composed of words from dictionary
31-
32-
:dict_file str
33-
Filename of dictionary with word pronunciations
34-
35-
:hmm_folder str
36-
Folder containing hidden markov model
37-
38-
:-th --threshold str 1e-90
39-
Threshold for activations
40-
41-
:-c --chunk-size int 2048
42-
Samples between inferences
43-
''')
44+
usage = Usage(__doc__)
4445

4546
def run(self):
4647
def on_activation():

precise/pocketsphinx/scripts/test.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,29 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
"""
16+
Test a dataset using Pocketsphinx
17+
18+
:key_phrase str
19+
Key phrase composed of words from dictionary
20+
21+
:dict_file str
22+
Filename of dictionary with word pronunciations
23+
24+
:hmm_folder str
25+
Folder containing hidden markov model
26+
27+
:-th --threshold str 1e-90
28+
Threshold for activations
29+
30+
:-t --use-train
31+
Evaluate training data instead of test data
32+
33+
:-nf --no-filenames
34+
Don't show the names of files that failed
35+
36+
...
37+
"""
1538
import wave
1639
from prettyparse import Usage
1740
from subprocess import check_output, PIPE
@@ -23,29 +46,7 @@
2346

2447

2548
class PocketsphinxTestScript(BaseScript):
26-
usage = Usage('''
27-
Test a dataset using Pocketsphinx
28-
29-
:key_phrase str
30-
Key phrase composed of words from dictionary
31-
32-
:dict_file str
33-
Filename of dictionary with word pronunciations
34-
35-
:hmm_folder str
36-
Folder containing hidden markov model
37-
38-
:-th --threshold str 1e-90
39-
Threshold for activations
40-
41-
:-t --use-train
42-
Evaluate training data instead of test data
43-
44-
:-nf --no-filenames
45-
Don't show the names of files that failed
46-
47-
...
48-
''') | TrainData.usage
49+
usage = Usage(__doc__) | TrainData.usage
4950

5051
def __init__(self, args):
5152
super().__init__(args)

0 commit comments

Comments
 (0)