Skip to content

Commit e552729

Browse files
author
Tomasz Latkowski
committed
fixed bug with num of features
1 parent 9a474cd commit e552729

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

methods/selection.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import pandas as pd
21
import tensorflow as tf
32

4-
data_file = '../data/autism.tsv'
5-
df = pd.read_csv(data_file, sep='\t', header=None, index_col=0).T
63

7-
8-
def fisher(data, num_instances: list, top_k_features=10):
4+
def fisher(data, num_instances: list, top_k_features=2):
95
"""
106
Performs Fisher feature selection method according to the following formula:
117
D(f) = |m1(f) - m2(f)| / (std1(f) + std2(f))
@@ -18,13 +14,15 @@ def fisher(data, num_instances: list, top_k_features=10):
1814
assert len(num_instances) == 2, "Fisher selection method can be performed for two-class problems."
1915
data = tf.convert_to_tensor(data)
2016
_, num_features = data.get_shape().as_list()
21-
if top_k_features < num_features:
17+
if top_k_features > num_features:
2218
top_k_features = num_features
2319
class1, class2 = tf.split(data, num_instances)
2420
mean1, std1 = tf.nn.moments(class1, axes=0)
2521
mean2, std2 = tf.nn.moments(class2, axes=0)
2622
fisher_coeffs = tf.abs(mean1 - mean2) / (std1 + std2)
27-
return tf.nn.top_k(fisher_coeffs, k=top_k_features)
23+
values, indices = tf.nn.top_k(fisher_coeffs, k=top_k_features)
24+
most_sig_f = tf.gather(data, indices, axis=1)
25+
return most_sig_f
2826

2927

3028
def feature_correlation_with_class(data, num_instances: list, top_k_features=10):
@@ -35,7 +33,7 @@ def feature_correlation_with_class(data, num_instances: list, top_k_features=10)
3533
"""
3634
data = tf.convert_to_tensor(data)
3735
_, num_features = data.get_shape().as_list()
38-
if top_k_features < num_features:
36+
if top_k_features > num_features:
3937
top_k_features = num_features
4038
class1, class2 = tf.split(data, num_instances)
4139
mean1, std1 = tf.nn.moments(class1, axes=0)
@@ -53,20 +51,10 @@ def t_test(data, num_instances: list, top_k_features=10):
5351
"""
5452
data = tf.convert_to_tensor(data)
5553
_, num_features = data.get_shape().as_list()
56-
if top_k_features < num_features:
54+
if top_k_features > num_features:
5755
top_k_features = num_features
5856
class1, class2 = tf.split(data, num_instances)
5957
mean1, std1 = tf.nn.moments(class1, axes=0)
6058
mean2, std2 = tf.nn.moments(class2, axes=0)
6159
t_test_coeffs = tf.abs(mean1 - mean2) / tf.sqrt(tf.square(std1)/num_instances[0] + tf.square(std2) / num_instances[1])
6260
return tf.nn.top_k(t_test_coeffs, k=top_k_features)
63-
64-
with tf.Session() as session:
65-
input_data = df.as_matrix()
66-
instances_per_class = [82, 64]
67-
fisher_coeffs = session.run(fisher(data=input_data, num_instances=instances_per_class, top_k_features=5))
68-
corr_coeffs = session.run(feature_correlation_with_class(data=input_data, num_instances=instances_per_class, top_k_features=5))
69-
t_test_coeff = session.run(t_test(data=input_data, num_instances=instances_per_class, top_k_features=5))
70-
print(fisher_coeffs)
71-
print(corr_coeffs)
72-
print(t_test_coeff)

tests/test_fisher.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,22 @@ def testFisherCorrectScore(self):
1515
num_instances = [2, 2]
1616
top_k = 2
1717
actual_most_significant_features = test_session.run(fisher(data, num_instances, top_k))
18+
correct_most_significant_features = [3., .5]
1819

19-
correct_most_significant_features = tf.constant([3., .5])
20-
self.assertAllEqual(actual_most_significant_features.values, correct_most_significant_features.eval())
20+
self.assertAllEqual(actual_most_significant_features.values, correct_most_significant_features)
21+
22+
def testFisherPickFirstSignificantFeature(self):
23+
with self.test_session() as test_session:
24+
data = np.array([[2, 2],
25+
[4, 4],
26+
[3, 6],
27+
[5, 6]])
28+
num_instances = [2, 2]
29+
top_k = 1
30+
actual_most_significant_features = test_session.run(fisher(data, num_instances, top_k))
31+
correct_most_significant_features = [[2.], [4.], [6.], [6.]]
32+
33+
self.assertAllEqual(actual_most_significant_features, correct_most_significant_features)
2134

2235
def testFisherCorrectOrderOfFeatures(self):
2336
raise NotImplementedError

0 commit comments

Comments
 (0)