|
35 | 35 | # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF |
36 | 36 | # DATA. |
37 | 37 | # |
| 38 | +import io |
| 39 | +import json |
| 40 | +import random |
38 | 41 | from contextlib import AbstractContextManager |
39 | 42 | from typing import Iterator, Callable, Any |
40 | 43 | from unittest.mock import patch |
|
55 | 58 | from .utils import patch_env_vars |
56 | 59 |
|
57 | 60 | TEXT_MODELS = [ |
58 | | - ("test.unavailable-text-model-v1", "NOT_AVAILABLE"), |
59 | | - ("test.available-text-model-v1", "AVAILABLE"), |
| 61 | + ("amazon.unavailable-text-model-v1", "NOT_AVAILABLE"), |
| 62 | + ("amazon.available-text-model-v1", "AVAILABLE"), |
60 | 63 | ] |
61 | 64 | EMBEDDING_MODELS = [ |
62 | | - ("test.unavailable-embedding-model-v1", "NOT_AVAILABLE"), |
63 | | - ("test.available-embedding-model-v1", "AVAILABLE"), |
| 65 | + ("amazon.unavailable-embedding-model-v1", "NOT_AVAILABLE"), |
| 66 | + ("amazon.available-embedding-model-v1", "AVAILABLE"), |
64 | 67 | ] |
65 | 68 | RERANKING_MODELS = [ |
66 | | - ("test.available-reranking-model-v1", "AVAILABLE"), |
| 69 | + ("amazon.available-reranking-model-v1", "AVAILABLE"), |
67 | 70 | ] |
68 | 71 |
|
69 | 72 |
|
@@ -152,7 +155,44 @@ def mock_make_api_call( |
152 | 155 | for model_id, _ in TEXT_MODELS + EMBEDDING_MODELS |
153 | 156 | ], |
154 | 157 | } |
155 | | - |
| 158 | + elif operation_name == "InvokeModel": |
| 159 | + texts: list[str] = json.loads(api_params["body"])["inputText"] |
| 160 | + return { |
| 161 | + "contentType": "application/json", |
| 162 | + # TODO: does this need to be botocore.response.StreamingBody? |
| 163 | + "body": io.BytesIO( |
| 164 | + json.dumps( |
| 165 | + { |
| 166 | + "texts": texts, |
| 167 | + "embeddings": [ |
| 168 | + [random.gauss(mu=0.0, sigma=0.1) for _ in range(16)] |
| 169 | + for _ in texts |
| 170 | + ], |
| 171 | + } |
| 172 | + ).encode() |
| 173 | + ), |
| 174 | + } |
| 175 | + elif operation_name == "Converse": |
| 176 | + return { |
| 177 | + "output": { |
| 178 | + "message": { |
| 179 | + "role": "assistant", |
| 180 | + "content": [{"text": "\n\nTest response."}], |
| 181 | + } |
| 182 | + }, |
| 183 | + "stopReason": "end_turn", |
| 184 | + # "usage": {"inputTokens": 21, "outputTokens": 75, "totalTokens": 96}, |
| 185 | + # "metrics": {"latencyMs": 827}, |
| 186 | + } |
| 187 | + elif operation_name == "Rerank": |
| 188 | + return { |
| 189 | + "results": [ |
| 190 | + # TODO: Is the document store checked prior to this? Do I need to mock that too? |
| 191 | + {"index": 0, "relevanceScore": random.random()}, |
| 192 | + {"index": 1, "relevanceScore": random.random()}, |
| 193 | + {"index": 2, "relevanceScore": random.random()}, |
| 194 | + ] |
| 195 | + } |
156 | 196 | else: |
157 | 197 | # passthrough |
158 | 198 | return make_api_call(self, operation_name, api_params) |
@@ -187,31 +227,45 @@ def test_bedrock_models(client: TestClient) -> None: |
187 | 227 | assert response.status_code == 200 |
188 | 228 | assert response.json() == "Bedrock" |
189 | 229 |
|
190 | | - response = client.get("/llm-service/models/embeddings") |
191 | | - assert response.status_code == 200 |
192 | | - assert [model["model_id"] for model in response.json()] == [ |
| 230 | + available_embedding_models = [ |
193 | 231 | model_id |
194 | 232 | for model_id, availability in EMBEDDING_MODELS |
195 | 233 | if availability == "AVAILABLE" |
196 | 234 | ] |
197 | | - |
198 | | - response = client.get("/llm-service/models/llm") |
| 235 | + response = client.get("/llm-service/models/embeddings") |
199 | 236 | assert response.status_code == 200 |
200 | | - assert [model["model_id"] for model in response.json()] == [ |
| 237 | + assert [ |
| 238 | + model["model_id"] for model in response.json() |
| 239 | + ] == available_embedding_models |
| 240 | + # for model_id in available_embedding_models: |
| 241 | + # response = client.get(f"/llm-service/models/embedding/{model_id}/test") |
| 242 | + # assert response.status_code == 200 # TODO |
| 243 | + |
| 244 | + available_text_models = [ |
201 | 245 | model_id |
202 | 246 | for model_id, availability in TEXT_MODELS |
203 | 247 | if availability == "AVAILABLE" |
204 | 248 | ] |
205 | | - |
206 | | - response = client.get("/llm-service/models/reranking") |
| 249 | + response = client.get("/llm-service/models/llm") |
207 | 250 | assert response.status_code == 200 |
208 | | - assert [model["model_id"] for model in response.json()] == [ |
| 251 | + assert [model["model_id"] for model in response.json()] == available_text_models |
| 252 | + for model_id in available_text_models: |
| 253 | + response = client.get(f"/llm-service/models/llm/{model_id}/test") |
| 254 | + assert response.status_code == 200 |
| 255 | + |
| 256 | + available_reranking_models = [ |
209 | 257 | model_id |
210 | 258 | for model_id, availability in RERANKING_MODELS |
211 | 259 | if availability == "AVAILABLE" |
212 | 260 | ] |
213 | | - |
214 | | - # response = client.get("/llm-service/models/embedding/cohere.embed-english-v3/test") |
| 261 | + response = client.get("/llm-service/models/reranking") |
| 262 | + assert response.status_code == 200 |
| 263 | + assert [ |
| 264 | + model["model_id"] for model in response.json() |
| 265 | + ] == available_reranking_models |
| 266 | + # for model_id in available_reranking_models: |
| 267 | + # response = client.get(f"/llm-service/models/reranking/{model_id}/test") |
| 268 | + # assert response.status_code == 200 # TODO |
215 | 269 |
|
216 | 270 |
|
217 | 271 | def test_bedrock_sessions(client: TestClient) -> None: |
|
0 commit comments