Skip to content

Commit 20ab91c

Browse files
author
Tomasz Latkowski
committed
changes in tests
1 parent e552729 commit 20ab91c

File tree

4 files changed

+58
-15
lines changed

4 files changed

+58
-15
lines changed

config/main.ini

Whitespace-only changes.

methods/selection.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11
import tensorflow as tf
22

33

4+
def selection_wrapper(data, num_instances, selection_method=None, num_features=None):
5+
if data is None:
6+
raise ValueError('Provide data to make selection.')
7+
8+
if selection_method is None:
9+
raise ValueError('Provide selection method.')
10+
11+
if num_features is None:
12+
data = tf.convert_to_tensor(data)
13+
num_features = data.get_shape().as_list()[-1]
14+
15+
values, indices = selection_method(data, num_instances, num_features)
16+
return values, tf.gather(data, indices, axis=1)
17+
18+
419
def fisher(data, num_instances: list, top_k_features=2):
520
"""
621
Performs Fisher feature selection method according to the following formula:
@@ -20,9 +35,7 @@ def fisher(data, num_instances: list, top_k_features=2):
2035
mean1, std1 = tf.nn.moments(class1, axes=0)
2136
mean2, std2 = tf.nn.moments(class2, axes=0)
2237
fisher_coeffs = tf.abs(mean1 - mean2) / (std1 + std2)
23-
values, indices = tf.nn.top_k(fisher_coeffs, k=top_k_features)
24-
most_sig_f = tf.gather(data, indices, axis=1)
25-
return most_sig_f
38+
return tf.nn.top_k(fisher_coeffs, k=top_k_features)
2639

2740

2841
def feature_correlation_with_class(data, num_instances: list, top_k_features=10):

tests/test_fisher.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import tensorflow as tf
33

4-
from methods.selection import fisher
4+
from methods.selection import fisher, selection_wrapper
55

66

77
class TestFisherSelection(tf.test.TestCase):
@@ -14,29 +14,52 @@ def testFisherCorrectScore(self):
1414
[5, 6]])
1515
num_instances = [2, 2]
1616
top_k = 2
17-
actual_most_significant_features = test_session.run(fisher(data, num_instances, top_k))
17+
actual_most_significant_features, _ = test_session.run(fisher(data, num_instances, top_k))
1818
correct_most_significant_features = [3., .5]
1919

20-
self.assertAllEqual(actual_most_significant_features.values, correct_most_significant_features)
20+
self.assertAllEqual(actual_most_significant_features, correct_most_significant_features)
2121

2222
def testFisherPickFirstSignificantFeature(self):
2323
with self.test_session() as test_session:
2424
data = np.array([[2, 2],
2525
[4, 4],
2626
[3, 6],
2727
[5, 6]])
28+
2829
num_instances = [2, 2]
2930
top_k = 1
30-
actual_most_significant_features = test_session.run(fisher(data, num_instances, top_k))
31+
_, actual_most_significant_features = test_session.run(selection_wrapper(data,
32+
num_instances,
33+
fisher,
34+
num_features=top_k))
3135
correct_most_significant_features = [[2.], [4.], [6.], [6.]]
3236

3337
self.assertAllEqual(actual_most_significant_features, correct_most_significant_features)
3438

3539
def testFisherCorrectOrderOfFeatures(self):
36-
raise NotImplementedError
40+
with self.test_session() as test_session:
41+
data = np.array([[2, 2],
42+
[4, 4],
43+
[3, 6],
44+
[5, 6]])
45+
num_instances = [2, 2]
46+
top_k = 2
47+
_, actual_most_significant_features = test_session.run(fisher(data, num_instances, top_k))
48+
correct_most_significant_features = [1., 0.]
49+
50+
self.assertAllEqual(actual_most_significant_features, correct_most_significant_features)
3751

3852
def testMoreThan2ClassesIsNotAllowed(self):
39-
raise NotImplementedError
53+
with self.test_session() as test_session:
54+
data = np.array([[2, 2],
55+
[4, 4],
56+
[3, 6],
57+
[5, 6]])
58+
num_instances = [2, 2, 2]
59+
top_k = 2
60+
with self.assertRaises(AssertionError):
61+
_, actual_most_significant_features = test_session.run(fisher(data, num_instances, top_k))
62+
4063

4164
if __name__ == '__main__':
4265
tf.test.main()

tests/test_pearson.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,32 @@ def testPearsonCoefficientValueForTwoVectors(self):
99
with self.test_session() as test_session:
1010
x1 = np.array([2., 3., 4.])
1111
x2 = np.array([3., 1., 5.])
12+
1213
actual_pearson_coefficient = test_session.run(pearson_correlation(x1, x2))
13-
correct_pearson_coefficient = tf.constant([.5])
14-
self.assertEqual(actual_pearson_coefficient, correct_pearson_coefficient.eval())
14+
correct_pearson_coefficient = [.5]
15+
16+
self.assertEqual(actual_pearson_coefficient, correct_pearson_coefficient)
1517

1618
def testNegativePearsonCoefficientValueForTwoVectors(self):
1719
with self.test_session() as test_session:
1820
x1 = np.array([1., 2., 3.])
1921
x2 = np.array([-1., -2., -3.])
22+
2023
actual_pearson_coefficient = test_session.run(pearson_correlation(x1, x2))
21-
correct_pearson_coefficient = tf.constant([-1.])
22-
self.assertEqual(actual_pearson_coefficient, correct_pearson_coefficient.eval())
24+
correct_pearson_coefficient = [-1.]
25+
26+
self.assertEqual(actual_pearson_coefficient, correct_pearson_coefficient)
2327

2428
def testPositivePearsonCoefficientValueForTwoVectors(self):
2529
with self.test_session() as test_session:
2630
x1 = np.array([1., 2., 3.])
2731
x2 = np.array([1., 2., 3.])
32+
2833
actual_pearson_coefficient = test_session.run(pearson_correlation(x1, x2))
29-
correct_pearson_coefficient = tf.constant([1.])
30-
self.assertEqual(actual_pearson_coefficient, correct_pearson_coefficient.eval())
34+
correct_pearson_coefficient = [1.]
35+
36+
self.assertEqual(actual_pearson_coefficient, correct_pearson_coefficient)
37+
3138

3239
if __name__ == '__main__':
3340
tf.test.main()

0 commit comments

Comments
 (0)