@@ -106,18 +106,58 @@ the algorithm must predict the label (which its positive or negative) for the po
106106 classifier .train (Arrays .asList (pointA , pointB , pointC , pointD ));
107107 List <Neighbor > similarNeighbors = classifier .similarNeighbors (pointE , 2 );
108108
109- Neighbor n1 = new Neighbor (new LabeledInstance (null , pointA .getModel ()), 0d , null );
110- Neighbor n2 = new Neighbor (new LabeledInstance (null , pointB .getModel ()), 0d , null );
109+ Neighbor n1 = new Neighbor (new LabeledInstance (negativeLabel , pointA .getModel ()), 0d , null );
110+ Neighbor n2 = new Neighbor (new LabeledInstance (negativeLabel , pointB .getModel ()), 0d , null );
111111
112112 Truth .assertThat (similarNeighbors )
113113 .containsAllIn (Arrays .asList (n1 , n2 ));
114114 }
115115
116+ @ Test
117+ public void when_trained_with_k_fold_it_should_predict_a_positive_label (){
118+ /*
119+ given a set of negative points:
120+ - A(2,4); B(3,2); C(4,4)
121+ and a set of positive points:
122+ - D(4,1); E(5,5); F(6,3)
123+ the algorithm must predict the label (which its positive or negative) for the point G(10,7)
124+ */
125+
126+ String positiveLabel = "positive" ;
127+ String negativeLabel = "negative" ;
128+
129+ LabeledInstance pointA = new LabeledInstance (negativeLabel , new TestModel (null , Arrays .asList (2d , 4d )));
130+ LabeledInstance pointB = new LabeledInstance (negativeLabel , new TestModel (null , Arrays .asList (3d , 2d )));
131+ LabeledInstance pointC = new LabeledInstance (negativeLabel , new TestModel (null , Arrays .asList (4d , 4d )));
132+
133+ LabeledInstance pointD = new LabeledInstance (positiveLabel , new TestModel (null , Arrays .asList (4d , 1d )));
134+ LabeledInstance pointE = new LabeledInstance (positiveLabel , new TestModel (null , Arrays .asList (5d , 5d )));
135+ LabeledInstance pointF = new LabeledInstance (positiveLabel , new TestModel (null , Arrays .asList (6d , 3d )));
136+
137+ classifier .setK (3 );
138+ classifier .train (Arrays .asList (pointA , pointB , pointC , pointD , pointE , pointF ), 3 );
139+
140+ double scoreExpected = Math .sqrt (29 )/100 ;
141+ Prediction predictedInstance = new Prediction (positiveLabel , scoreExpected );
142+
143+ Prediction predictedInstance1 = classifier .predict (pointF );
144+ Truth .assertThat (predictedInstance .getLabel ())
145+ .isEqualTo (positiveLabel );
146+ Truth .assertThat (predictedInstance .getScore ())
147+ .isEqualTo (scoreExpected );
148+ }
149+
150+
151+
116152 @ Test (expected = IllegalArgumentException .class )
117153 public void when_similarNeighbors_its_called_with_null_neighbors_args_it_should_raise_an_exception (){
118154 classifier .similarNeighbors (null , 10 );
119155 }
120156
157+ @ Test (expected = IllegalArgumentException .class )
158+ public void it_should_raise_an_exception_when_predict_list_method_its_called (){
159+ classifier .predict (Collections .emptyList ());
160+ }
121161}
122162
123163
0 commit comments