Skip to content

Commit 6345c50

Browse files
committed
Add docstrings and clean up test slightly
1 parent a99d229 commit 6345c50

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

test/scripts/dummy_audio_folder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,14 @@ def rand(self, min, max):
3535
return min + (max - min) * np.random.random() * pr.buffer_t
3636

3737
def generate_samples(self, folder, name, value, duration):
38+
"""Generate sample file.
39+
40+
The file is generated in the specified folder, with the specified name,
41+
dummy value and duration.
42+
"""
3843
for i in range(self.count):
39-
save_audio(join(folder, name.format(i)), np.array([value] * int(duration * pr.sample_rate)))
44+
save_audio(join(folder, name.format(i)),
45+
np.array([value] * int(duration * pr.sample_rate)))
4046

4147
def subdir(self, *parts):
4248
folder = self.path(*parts)

test/scripts/test_combined.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,27 @@ def read_content(filename):
2626

2727

2828
def test_combined(train_folder, train_script):
29+
"""Test a "normal" development cycle, train, evaluate and calc threshold.
30+
"""
2931
train_script.run()
3032
params_file = train_folder.model + '.params'
3133
assert isfile(train_folder.model)
3234
assert isfile(params_file)
3335

34-
EvalScript.create(folder=train_folder.root, models=[train_folder.model]).run()
36+
EvalScript.create(folder=train_folder.root,
37+
models=[train_folder.model]).run()
3538

39+
# Ensure that the graph script generates a numpy savez file
3640
out_file = train_folder.path('outputs.npz')
37-
graph_script = GraphScript.create(folder=train_folder.root, models=[train_folder.model], output_file=out_file)
41+
graph_script = GraphScript.create(folder=train_folder.root,
42+
models=[train_folder.model],
43+
output_file=out_file)
3844
graph_script.run()
3945
assert isfile(out_file)
4046

47+
# Esure the params are updated after threshold is calculated
4148
params_before = read_content(params_file)
42-
CalcThresholdScript.create(folder=train_folder.root, model=train_folder.model, input_file=out_file).run()
49+
CalcThresholdScript.create(folder=train_folder.root,
50+
model=train_folder.model,
51+
input_file=out_file).run()
4352
assert params_before != read_content(params_file)

test/scripts/test_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def __init__(self):
3636

3737

3838
def test_engine(train_folder, train_script):
39+
"""
40+
Test t hat the output format of the engina matches a decimal form in the
41+
range 0.0 - 1.0.
42+
"""
3943
train_script.run()
4044
with open(glob.glob(join(train_folder.root, 'wake-word', '*.wav'))[0], 'rb') as f:
4145
data = f.read()

test/scripts/test_train.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,20 @@
2222
class DummyTrainFolder(DummyAudioFolder):
2323
def __init__(self, count=10):
2424
super().__init__(count)
25-
self.generate_samples(self.subdir('wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
26-
self.generate_samples(self.subdir('not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
27-
self.generate_samples(self.subdir('test', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
28-
self.generate_samples(self.subdir('test', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
25+
self.generate_samples(self.subdir('wake-word'), 'ww-{}.wav', 1.0,
26+
self.rand(0, 2 * pr.buffer_t))
27+
self.generate_samples(self.subdir('not-wake-word'), 'nww-{}.wav', 0.0,
28+
self.rand(0, 2 * pr.buffer_t))
29+
self.generate_samples(self.subdir('test', 'wake-word'), 'ww-{}.wav',
30+
1.0, self.rand(0, 2 * pr.buffer_t))
31+
self.generate_samples(self.subdir('test', 'not-wake-word'),
32+
'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
2933
self.model = self.path('model.net')
3034

3135

3236
class TestTrain:
3337
def test_run_basic(self):
38+
"""Run a training and check that a model is generated."""
3439
folders = DummyTrainFolder(10)
3540
script = TrainScript.create(model=folders.model, folder=folders.root)
3641
script.run()

0 commit comments

Comments
 (0)