22import json
33import os
44import re
5+ from typing import Optional
56
67import requests
78import structlog
89import yaml
10+ from checks import CheckLoader
911from dotenv import find_dotenv , load_dotenv
10- from sklearn .metrics .pairwise import cosine_similarity
11-
12- from codegate .inference .inference_engine import LlamaCppInferenceEngine
12+ from requesters import RequesterFactory
1313
1414logger = structlog .get_logger ("codegate" )
1515
1616
1717class CodegateTestRunner :
def __init__(self) -> None:
    """Set up the runner with the factory it uses to obtain requesters."""
    # The factory is asked for a requester each time call_codegate runs.
    self.requester_factory = RequesterFactory()
20+
def call_codegate(
    self, url: str, headers: dict, data: dict, provider: str
) -> Optional[requests.Response]:
    """Send one request through the requester selected for ``provider``.

    Args:
        url: Endpoint handed to the requester.
        headers: Request headers, passed through unchanged.
        data: Request payload, passed through unchanged.
        provider: Key given to the requester factory to choose a requester.

    Returns:
        Whatever the requester's ``make_request`` returned. A ``None``
        result is logged as "No response received" and returned as-is;
        a non-200 status is logged but the response is still returned.
    """
    logger.debug(f"Creating requester for provider: {provider}")
    client = self.requester_factory.create_requester(provider)
    logger.debug(f"Using requester type: {client.__class__.__name__}")

    logger.debug(f"Making request to URL: {url}")
    # NOTE(review): headers are logged verbatim and may contain credentials.
    logger.debug(f"Headers: {headers}")
    logger.debug(f"Data: {data}")

    response = client.make_request(url, headers, data)

    # Guard clauses: nothing more to report for a missing or a 200 response.
    if response is None:
        logger.error("No response received")
        return response
    if response.status_code == 200:
        return response

    logger.debug(f"Response error status: {response.status_code}")
    logger.debug(f"Response error headers: {dict(response.headers)}")
    try:
        error_payload = response.json()
        logger.error(f"Request error as JSON: {error_payload}")
    except ValueError:
        # Body is not JSON; fall back to the raw text.
        logger.error(f"Raw request error: {response.text}")

    return response
3050
3151 @staticmethod
@@ -50,6 +70,8 @@ def parse_response_message(response, streaming=True):
5070
5171 message_content = None
5272 if "choices" in json_line :
73+ if "finish_reason" in json_line ["choices" ][0 ]:
74+ break
5375 if "delta" in json_line ["choices" ][0 ]:
5476 message_content = json_line ["choices" ][0 ]["delta" ].get ("content" , "" )
5577 elif "text" in json_line ["choices" ][0 ]:
@@ -75,12 +97,6 @@ def parse_response_message(response, streaming=True):
7597
7698 return response_message
7799
78- async def calculate_string_similarity (self , str1 , str2 ):
79- vector1 = await self .inference_engine .embed (self .embedding_model , [str1 ])
80- vector2 = await self .inference_engine .embed (self .embedding_model , [str2 ])
81- similarity = cosine_similarity (vector1 , vector2 )
82- return similarity [0 ]
83-
84100 @staticmethod
85101 def replace_env_variables (input_string , env ):
86102 """
@@ -103,51 +119,115 @@ def replacement(match):
103119 pattern = r"ENV\w*"
104120 return re .sub (pattern , replacement , input_string )
105121
async def run_test(self, test: dict, test_headers: dict) -> None:
    """Execute a single test case and log whether it passed.

    Args:
        test: Test case mapping. Reads ``name``, ``url``, ``data`` (a JSON
            string) and ``provider``; the whole mapping is also passed to
            ``CheckLoader.load`` and to every check.
        test_headers: Headers sent with the request.

    The outcome is only logged. Nothing is returned, and exceptions raised
    while parsing the response or running checks are logged, not re-raised.
    """
    test_name = test["name"]
    url = test["url"]
    data = json.loads(test["data"])
    streaming = data.get("stream", False)
    provider = test["provider"]

    response = self.call_codegate(url, test_headers, data, provider)

    # Compare with None explicitly: requests.Response is falsy for every
    # 4xx/5xx status, so a plain truthiness test would report a received
    # error response as "No response received".
    if response is None:
        logger.error(f"Test {test_name} failed: No response received")
        return
    if not response:
        # A response arrived but carries an error status; stop here as
        # before, but say what actually happened.
        logger.error(f"Test {test_name} failed: HTTP status {response.status_code}")
        return

    # Debug response info
    logger.debug(f"Response status: {response.status_code}")
    logger.debug(f"Response headers: {dict(response.headers)}")

    try:
        parsed_response = self.parse_response_message(response, streaming=streaming)

        # Load appropriate checks for this test
        checks = CheckLoader.load(test)

        # Run every check even after a failure; each check is awaited
        # regardless of earlier results.
        passed = True
        for check in checks:
            if not await check.run_check(parsed_response, test):
                passed = False
        logger.info(f"Test {test_name} passed" if passed else f"Test {test_name} failed")

    except Exception as e:
        logger.exception("Could not parse response: %s", e)
127154
async def run_tests(
    self,
    testcases_file: str,
    providers: Optional[list[str]] = None,
    test_names: Optional[list[str]] = None,
) -> None:
    """Load test cases from a YAML file and run those matching the filters.

    Args:
        testcases_file: Path to a YAML file with top-level ``headers`` and
            ``testcases`` mappings.
        providers: If given and non-empty, only test cases whose
            ``provider`` matches one of these (case-insensitively) run.
        test_names: If given and non-empty, only test cases whose ``name``
            matches one of these (case-insensitively) run.

    When filters are given and nothing matches, a warning is logged and no
    test is run.
    """
    with open(testcases_file, "r", encoding="utf-8") as f:
        tests = yaml.safe_load(f)

    headers = tests["headers"]
    testcases = tests["testcases"]

    # Lower-case the filters once instead of once per test case.
    wanted_providers = {p.lower() for p in providers} if providers else None
    wanted_names = {t.lower() for t in test_names} if test_names else None

    # Human-readable description of the active filters, shared by both logs.
    filter_msg = []
    if providers:
        filter_msg.append(f"providers: {', '.join(providers)}")
    if test_names:
        filter_msg.append(f"test names: {', '.join(test_names)}")
    filter_desc = " and ".join(filter_msg)

    if wanted_providers or wanted_names:
        filtered_testcases = {}
        for test_id, test_data in testcases.items():
            if wanted_providers and test_data.get("provider", "").lower() not in wanted_providers:
                continue
            if wanted_names and test_data.get("name", "").lower() not in wanted_names:
                continue
            filtered_testcases[test_id] = test_data
        testcases = filtered_testcases

        if not testcases:
            logger.warning(f"No tests found for {filter_desc}")
            return

    test_count = len(testcases)
    logger.info(f"Running {test_count} tests" + (f" for {filter_desc}" if filter_desc else ""))

    for test_data in testcases.values():
        # A provider key that exists in the YAML with no value loads as
        # None; treat it like a missing entry instead of failing on .items().
        raw_headers = headers.get(test_data["provider"]) or {}
        test_headers = {
            k: self.replace_env_variables(v, os.environ) for k, v in raw_headers.items()
        }
        await self.run_test(test_data, test_headers)
145210
146211
async def main():
    """Load the .env file, read optional filters from the environment, run the tests.

    ``CODEGATE_PROVIDERS`` and ``CODEGATE_TEST_NAMES`` are comma-separated
    lists; an unset or empty variable means "no filter" for that dimension.
    """

    def _comma_list(var_name):
        # Unset or empty variable -> None; otherwise the non-blank,
        # whitespace-trimmed items.
        raw_value = os.environ.get(var_name)
        if not raw_value:
            return None
        trimmed = (item.strip() for item in raw_value.split(","))
        return [item for item in trimmed if item]

    load_dotenv(find_dotenv())
    test_runner = CodegateTestRunner()

    # Get providers and test names from environment variables
    providers = _comma_list("CODEGATE_PROVIDERS")
    test_names = _comma_list("CODEGATE_TEST_NAMES")

    await test_runner.run_tests(
        "./tests/integration/testcases.yaml", providers=providers, test_names=test_names
    )
151231
152232
153233if __name__ == "__main__" :
0 commit comments