diff --git a/backend/python/rerankers/backend.py b/backend/python/rerankers/backend.py
index aadb5b9afcae..8ce2636d7a13 100755
--- a/backend/python/rerankers/backend.py
+++ b/backend/python/rerankers/backend.py
@@ -75,12 +75,13 @@ def Rerank(self, request, context):
             documents.append(doc)
         ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
         # Prepare results to return
+        cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
         results = [
            backend_pb2.DocumentResult(
                index=res.doc_id,
                text=res.text,
                relevance_score=res.score
-            ) for res in ranked_results.top_k(request.top_n)
+            ) for res in cropped_results
        ]
 
         # Calculate the usage and total tokens
diff --git a/backend/python/rerankers/test.py b/backend/python/rerankers/test.py
index 3f2ddf0b7700..f5890fc25d24 100755
--- a/backend/python/rerankers/test.py
+++ b/backend/python/rerankers/test.py
@@ -76,7 +76,35 @@ def test_rerank(self):
                 )
                 response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
                 self.assertTrue(response.success)
-
+
+                rerank_response = stub.Rerank(request)
+                print(rerank_response.results[0])
+                self.assertIsNotNone(rerank_response.results)
+                self.assertEqual(len(rerank_response.results), 2)
+                self.assertEqual(rerank_response.results[0].text, "I really like you")
+                self.assertEqual(rerank_response.results[1].text, "I hate you")
+        except Exception as err:
+            print(err)
+            self.fail("Reranker service failed")
+        finally:
+            self.tearDown()
+
+    def test_rerank_omit_top_n(self):
+        """
+        This method tests that rerank results are returned successfully even when top_n is omitted
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                request = backend_pb2.RerankRequest(
+                    query="I love you",
+                    documents=["I hate you", "I really like you"],
+                    top_n=0  # default value when top_n is omitted: no cropping expected
+                )
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+
                 rerank_response = stub.Rerank(request)
                 print(rerank_response.results[0])
                 self.assertIsNotNone(rerank_response.results)
@@ -91,7 +119,7 @@ def test_rerank(self):
 
     def test_rerank_crop(self):
         """
-        This method tests if the embeddings are generated successfully
+        This method tests top_n cropping
         """
         try:
             self.setUp()
@@ -104,7 +132,7 @@ def test_rerank_crop(self):
                 )
                 response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
                 self.assertTrue(response.success)
-
+
                 rerank_response = stub.Rerank(request)
                 print(rerank_response.results[0])
                 self.assertIsNotNone(rerank_response.results)
@@ -115,4 +143,4 @@ def test_rerank_crop(self):
             print(err)
             self.fail("Reranker service failed")
         finally:
-            self.tearDown()
\ No newline at end of file
+            self.tearDown()