Skip to content

Commit 55d207b

Browse files
committed
Start mocking Bedrock model-calling endpoints
1 parent de1d580 commit 55d207b

File tree

2 files changed: 81 additions and 27 deletions.

llm-service/app/tests/conftest.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -175,16 +175,16 @@ def get_datasource_metadata(data_source_id: int) -> RagDataSource:
175175
)
176176

177177

178-
@pytest.fixture(autouse=True)
179-
def embedding_model(monkeypatch: pytest.MonkeyPatch) -> None:
180-
model = DummyEmbeddingModel()
181-
monkeypatch.setattr(models.Embedding, "get", lambda cls, model_name=None: model)
182-
183-
184-
@pytest.fixture(autouse=True)
185-
def llm(monkeypatch: pytest.MonkeyPatch) -> None:
186-
model = models.LLM.get_noop()
187-
monkeypatch.setattr(models.LLM, "get", lambda cls, model_name=None: model)
178+
# @pytest.fixture(autouse=True)
179+
# def embedding_model(monkeypatch: pytest.MonkeyPatch) -> None:
180+
# model = DummyEmbeddingModel()
181+
# monkeypatch.setattr(models.Embedding, "get", lambda cls, model_name=None: model)
182+
#
183+
#
184+
# @pytest.fixture(autouse=True)
185+
# def llm(monkeypatch: pytest.MonkeyPatch) -> None:
186+
# model = models.LLM.get_noop()
187+
# monkeypatch.setattr(models.LLM, "get", lambda cls, model_name=None: model)
188188

189189

190190
@pytest.fixture

llm-service/app/tests/model_provider_mocks/bedrock.py

Lines changed: 71 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,9 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38+
import io
39+
import json
40+
import random
3841
from contextlib import AbstractContextManager
3942
from typing import Iterator, Callable, Any
4043
from unittest.mock import patch
@@ -55,15 +58,15 @@
5558
from .utils import patch_env_vars
5659

5760
TEXT_MODELS = [
58-
("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
59-
("test.available-text-model-v1", "AVAILABLE"),
61+
("amazon.unavailable-text-model-v1", "NOT_AVAILABLE"),
62+
("amazon.available-text-model-v1", "AVAILABLE"),
6063
]
6164
EMBEDDING_MODELS = [
62-
("test.unavailable-embedding-model-v1", "NOT_AVAILABLE"),
63-
("test.available-embedding-model-v1", "AVAILABLE"),
65+
("amazon.unavailable-embedding-model-v1", "NOT_AVAILABLE"),
66+
("amazon.available-embedding-model-v1", "AVAILABLE"),
6467
]
6568
RERANKING_MODELS = [
66-
("test.available-reranking-model-v1", "AVAILABLE"),
69+
("amazon.available-reranking-model-v1", "AVAILABLE"),
6770
]
6871

6972

@@ -152,7 +155,44 @@ def mock_make_api_call(
152155
for model_id, _ in TEXT_MODELS + EMBEDDING_MODELS
153156
],
154157
}
155-
158+
elif operation_name == "InvokeModel":
159+
texts: list[str] = json.loads(api_params["body"])["inputText"]
160+
return {
161+
"contentType": "application/json",
162+
# TODO: does this need to be botocore.response.StreamingBody?
163+
"body": io.BytesIO(
164+
json.dumps(
165+
{
166+
"texts": texts,
167+
"embeddings": [
168+
[random.gauss(mu=0.0, sigma=0.1) for _ in range(16)]
169+
for _ in texts
170+
],
171+
}
172+
).encode()
173+
),
174+
}
175+
elif operation_name == "Converse":
176+
return {
177+
"output": {
178+
"message": {
179+
"role": "assistant",
180+
"content": [{"text": "\n\nTest response."}],
181+
}
182+
},
183+
"stopReason": "end_turn",
184+
# "usage": {"inputTokens": 21, "outputTokens": 75, "totalTokens": 96},
185+
# "metrics": {"latencyMs": 827},
186+
}
187+
elif operation_name == "Rerank":
188+
return {
189+
"results": [
190+
# TODO: Is the document store checked prior to this? Do I need to mock that too?
191+
{"index": 0, "relevanceScore": random.random()},
192+
{"index": 1, "relevanceScore": random.random()},
193+
{"index": 2, "relevanceScore": random.random()},
194+
]
195+
}
156196
else:
157197
# passthrough
158198
return make_api_call(self, operation_name, api_params)
@@ -187,31 +227,45 @@ def test_bedrock_models(client: TestClient) -> None:
187227
assert response.status_code == 200
188228
assert response.json() == "Bedrock"
189229

190-
response = client.get("/llm-service/models/embeddings")
191-
assert response.status_code == 200
192-
assert [model["model_id"] for model in response.json()] == [
230+
available_embedding_models = [
193231
model_id
194232
for model_id, availability in EMBEDDING_MODELS
195233
if availability == "AVAILABLE"
196234
]
197-
198-
response = client.get("/llm-service/models/llm")
235+
response = client.get("/llm-service/models/embeddings")
199236
assert response.status_code == 200
200-
assert [model["model_id"] for model in response.json()] == [
237+
assert [
238+
model["model_id"] for model in response.json()
239+
] == available_embedding_models
240+
# for model_id in available_embedding_models:
241+
# response = client.get(f"/llm-service/models/embedding/{model_id}/test")
242+
# assert response.status_code == 200 # TODO
243+
244+
available_text_models = [
201245
model_id
202246
for model_id, availability in TEXT_MODELS
203247
if availability == "AVAILABLE"
204248
]
205-
206-
response = client.get("/llm-service/models/reranking")
249+
response = client.get("/llm-service/models/llm")
207250
assert response.status_code == 200
208-
assert [model["model_id"] for model in response.json()] == [
251+
assert [model["model_id"] for model in response.json()] == available_text_models
252+
for model_id in available_text_models:
253+
response = client.get(f"/llm-service/models/llm/{model_id}/test")
254+
assert response.status_code == 200
255+
256+
available_reranking_models = [
209257
model_id
210258
for model_id, availability in RERANKING_MODELS
211259
if availability == "AVAILABLE"
212260
]
213-
214-
# response = client.get("/llm-service/models/embedding/cohere.embed-english-v3/test")
261+
response = client.get("/llm-service/models/reranking")
262+
assert response.status_code == 200
263+
assert [
264+
model["model_id"] for model in response.json()
265+
] == available_reranking_models
266+
# for model_id in available_reranking_models:
267+
# response = client.get(f"/llm-service/models/reranking/{model_id}/test")
268+
# assert response.status_code == 200 # TODO
215269

216270

217271
def test_bedrock_sessions(client: TestClient) -> None:

0 commit comments

Comments
 (0)