11import numpy as np
22import tensorflow as tf
33
4- from methods .selection import fisher
4+ from methods .selection import fisher , selection_wrapper
55
66
77class TestFisherSelection (tf .test .TestCase ):
@@ -14,29 +14,52 @@ def testFisherCorrectScore(self):
1414 [5 , 6 ]])
1515 num_instances = [2 , 2 ]
1616 top_k = 2
17- actual_most_significant_features = test_session .run (fisher (data , num_instances , top_k ))
17+ actual_most_significant_features , _ = test_session .run (fisher (data , num_instances , top_k ))
1818 correct_most_significant_features = [3. , .5 ]
1919
20- self .assertAllEqual (actual_most_significant_features . values , correct_most_significant_features )
20+ self .assertAllEqual (actual_most_significant_features , correct_most_significant_features )
2121
2222 def testFisherPickFirstSignificantFeature (self ):
2323 with self .test_session () as test_session :
2424 data = np .array ([[2 , 2 ],
2525 [4 , 4 ],
2626 [3 , 6 ],
2727 [5 , 6 ]])
28+
2829 num_instances = [2 , 2 ]
2930 top_k = 1
30- actual_most_significant_features = test_session .run (fisher (data , num_instances , top_k ))
31+ _ , actual_most_significant_features = test_session .run (selection_wrapper (data ,
32+ num_instances ,
33+ fisher ,
34+ num_features = top_k ))
3135 correct_most_significant_features = [[2. ], [4. ], [6. ], [6. ]]
3236
3337 self .assertAllEqual (actual_most_significant_features , correct_most_significant_features )
3438
3539 def testFisherCorrectOrderOfFeatures (self ):
36- raise NotImplementedError
40+ with self .test_session () as test_session :
41+ data = np .array ([[2 , 2 ],
42+ [4 , 4 ],
43+ [3 , 6 ],
44+ [5 , 6 ]])
45+ num_instances = [2 , 2 ]
46+ top_k = 2
47+ _ , actual_most_significant_features = test_session .run (fisher (data , num_instances , top_k ))
48+ correct_most_significant_features = [1. , 0. ]
49+
50+ self .assertAllEqual (actual_most_significant_features , correct_most_significant_features )
3751
3852 def testMoreThan2ClassesIsNotAllowed (self ):
39- raise NotImplementedError
53+ with self .test_session () as test_session :
54+ data = np .array ([[2 , 2 ],
55+ [4 , 4 ],
56+ [3 , 6 ],
57+ [5 , 6 ]])
58+ num_instances = [2 , 2 , 2 ]
59+ top_k = 2
60+ with self .assertRaises (AssertionError ):
61+ _ , actual_most_significant_features = test_session .run (fisher (data , num_instances , top_k ))
62+
4063
4164if __name__ == '__main__' :
4265 tf .test .main ()
0 commit comments