Skip to content

Commit 2731f37

Browse files
committed
Refactor, modularize
1 parent da9c049 commit 2731f37

File tree

3 files changed

+99
-51
lines changed

3 files changed

+99
-51
lines changed

llm-service/app/tests/provider_mocks/bedrock.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
# DATA.
3737
#
3838
import itertools
39+
from contextlib import AbstractContextManager
3940
from typing import Generator
4041
from unittest.mock import patch
4142
from urllib.parse import urljoin
@@ -47,9 +48,9 @@
4748

4849
from app.config import settings
4950
from app.services.caii.types import ModelResponse
50-
from app.services.models import ModelProvider
5151
from app.services.models.providers import BedrockModelProvider
5252
from app.services.models.providers._model_provider import get_all_env_var_names
53+
from .utils import patch_env_vars
5354

5455
TEXT_MODELS = [
5556
("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
@@ -64,14 +65,7 @@
6465
]
6566

6667

67-
@pytest.fixture
68-
def mock_bedrock(monkeypatch) -> Generator[None, None, None]:
69-
for name in BedrockModelProvider.get_env_var_names():
70-
monkeypatch.setenv(name, "test")
71-
for name in get_all_env_var_names() - BedrockModelProvider.get_env_var_names():
72-
monkeypatch.delenv(name, raising=False)
73-
74-
# mock calls made directly through `requests`
68+
def _patch_requests() -> AbstractContextManager:
7569
bedrock_url_base = f"https://bedrock.{settings.aws_default_region}.amazonaws.com/"
7670
r_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
7771
for model_id, availability in TEXT_MODELS + EMBEDDING_MODELS:
@@ -92,7 +86,10 @@ def mock_bedrock(monkeypatch) -> Generator[None, None, None]:
9286
},
9387
)
9488

95-
# mock calls made through `boto3`
89+
return r_mock
90+
91+
92+
def _patch_boto3() -> AbstractContextManager:
9693
make_api_call = botocore.client.BaseClient._make_api_call
9794

9895
def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
@@ -151,29 +148,28 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
151148
# passthrough
152149
return make_api_call(self, operation_name, api_params)
153150

154-
# mock reranking models, which are hard-coded in our app
155-
def list_reranking_models() -> list[ModelResponse]:
156-
return [
157-
ModelResponse(model_id=model_id, name=model_id.upper())
158-
for model_id, _ in RERANKING_MODELS
159-
]
160-
161-
with (
162-
r_mock,
163-
patch(
164-
"botocore.client.BaseClient._make_api_call",
165-
new=mock_make_api_call,
166-
),
167-
patch(
168-
"app.services.models.providers.BedrockModelProvider.list_reranking_models",
169-
new=list_reranking_models,
170-
),
171-
patch( # work around a llama-index filter we have in list_llm_models()
172-
"app.services.models.providers.bedrock.BEDROCK_MODELS",
173-
new=BEDROCK_MODELS | {model_id: 128000 for model_id, _ in TEXT_MODELS},
174-
),
175-
):
176-
yield
151+
return patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
152+
153+
154+
@pytest.fixture
155+
def mock_bedrock(monkeypatch) -> Generator[None, None, None]:
156+
with patch_env_vars(BedrockModelProvider):
157+
with (
158+
_patch_requests(),
159+
_patch_boto3(),
160+
patch( # mock reranking models, which are hard-coded in our app
161+
"app.services.models.providers.BedrockModelProvider.list_reranking_models",
162+
new=lambda: [
163+
ModelResponse(model_id=model_id, name=model_id.upper())
164+
for model_id, _ in RERANKING_MODELS
165+
],
166+
),
167+
patch( # work around a llama-index filter we have in list_llm_models()
168+
"app.services.models.providers.bedrock.BEDROCK_MODELS",
169+
new=BEDROCK_MODELS | {model_id: 128000 for model_id, _ in TEXT_MODELS},
170+
),
171+
):
172+
yield
177173

178174

179175
# TODO: move this test function to a discoverable place
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#
2+
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
3+
# (C) Cloudera, Inc. 2025
4+
# All rights reserved.
5+
#
6+
# Applicable Open Source License: Apache 2.0
7+
#
8+
# NOTE: Cloudera open source products are modular software products
9+
# made up of hundreds of individual components, each of which was
10+
# individually copyrighted. Each Cloudera open source product is a
11+
# collective work under U.S. Copyright Law. Your license to use the
12+
# collective work is as provided in your written agreement with
13+
# Cloudera. Used apart from the collective work, this file is
14+
# licensed for your use pursuant to the open source license
15+
# identified above.
16+
#
17+
# This code is provided to you pursuant a written agreement with
18+
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
19+
# this code. If you do not have a written agreement with Cloudera nor
20+
# with an authorized and properly licensed third party, you do not
21+
# have any rights to access nor to use this code.
22+
#
23+
# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
24+
# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
25+
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
26+
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
27+
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
28+
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
29+
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
30+
# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
31+
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
32+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
33+
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
34+
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
35+
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
36+
# DATA.
37+
#
38+
import contextlib
39+
from typing import Iterator
40+
41+
from pytest import MonkeyPatch
42+
43+
from app.config import MODEL_PROVIDER_ENV_VAR_NAME
44+
from app.services.models.providers._model_provider import (
45+
_ModelProvider,
46+
get_all_env_var_names,
47+
)
48+
49+
50+
@contextlib.contextmanager
51+
def patch_env_vars(ModelProviderSubcls: type[_ModelProvider]) -> Iterator[None]:
52+
"""Set and unset environment variables for the given model provider."""
53+
with MonkeyPatch.context() as monkeypatch:
54+
for name in get_all_env_var_names():
55+
monkeypatch.delenv(name, raising=False)
56+
for name in ModelProviderSubcls.get_env_var_names():
57+
monkeypatch.setenv(name, "test")
58+
monkeypatch.setenv(
59+
MODEL_PROVIDER_ENV_VAR_NAME,
60+
ModelProviderSubcls.get_model_source(),
61+
)
62+
yield

llm-service/app/tests/services/test_models.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,37 +35,27 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38+
from typing import Iterator
39+
3840
import pytest
3941

40-
from app.config import MODEL_PROVIDER_ENV_VAR_NAME
4142
from app.services import models
4243
from app.services.caii import caii
4344
from app.services.caii.types import ListEndpointEntry
4445
from app.services.models.providers import BedrockModelProvider
45-
from app.services.models.providers._model_provider import (
46-
_ModelProvider,
47-
get_all_env_var_names,
48-
)
46+
from app.services.models.providers._model_provider import _ModelProvider
47+
from app.tests.provider_mocks.utils import patch_env_vars
4948

5049

5150
@pytest.fixture(params=_ModelProvider.__subclasses__())
5251
def EnabledModelProvider(
5352
request: pytest.FixtureRequest,
5453
monkeypatch: pytest.MonkeyPatch,
55-
) -> type[_ModelProvider]:
56-
"""Sets and unsets environment variables for the given model provider."""
54+
) -> Iterator[type[_ModelProvider]]:
55+
"""Parametrize a test to run for each supported model provider."""
5756
ModelProviderSubcls: type[_ModelProvider] = request.param
58-
59-
for name in get_all_env_var_names():
60-
monkeypatch.delenv(name, raising=False)
61-
for name in ModelProviderSubcls.get_env_var_names():
62-
monkeypatch.setenv(name, "test")
63-
monkeypatch.setenv(
64-
MODEL_PROVIDER_ENV_VAR_NAME,
65-
ModelProviderSubcls.get_model_source(),
66-
)
67-
68-
return ModelProviderSubcls
57+
with patch_env_vars(ModelProviderSubcls):
58+
yield ModelProviderSubcls
6959

7060

7161
class TestListAvailableModels:

0 commit comments

Comments
 (0)