3636# DATA.
3737#
3838import itertools
39+ from contextlib import AbstractContextManager
3940from typing import Generator
4041from unittest .mock import patch
4142from urllib .parse import urljoin
4748
4849from app .config import settings
4950from app .services .caii .types import ModelResponse
50- from app .services .models import ModelProvider
5151from app .services .models .providers import BedrockModelProvider
5252from app .services .models .providers ._model_provider import get_all_env_var_names
53+ from .utils import patch_env_vars
5354
5455TEXT_MODELS = [
5556 ("test.unavailable-text-model-v1" , "NOT_AVAILABLE" ),
6465]
6566
6667
67- @pytest .fixture
68- def mock_bedrock (monkeypatch ) -> Generator [None , None , None ]:
69- for name in BedrockModelProvider .get_env_var_names ():
70- monkeypatch .setenv (name , "test" )
71- for name in get_all_env_var_names () - BedrockModelProvider .get_env_var_names ():
72- monkeypatch .delenv (name , raising = False )
73-
74- # mock calls made directly through `requests`
68+ def _patch_requests () -> AbstractContextManager :
7569 bedrock_url_base = f"https://bedrock.{ settings .aws_default_region } .amazonaws.com/"
7670 r_mock = responses .RequestsMock (assert_all_requests_are_fired = False )
7771 for model_id , availability in TEXT_MODELS + EMBEDDING_MODELS :
@@ -92,7 +86,10 @@ def mock_bedrock(monkeypatch) -> Generator[None, None, None]:
9286 },
9387 )
9488
95- # mock calls made through `boto3`
89+ return r_mock
90+
91+
92+ def _patch_boto3 () -> AbstractContextManager :
9693 make_api_call = botocore .client .BaseClient ._make_api_call
9794
9895 def mock_make_api_call (self , operation_name : str , api_params : dict [str , str ]):
@@ -151,29 +148,28 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
151148 # passthrough
152149 return make_api_call (self , operation_name , api_params )
153150
154- # mock reranking models, which are hard-coded in our app
155- def list_reranking_models () -> list [ModelResponse ]:
156- return [
157- ModelResponse (model_id = model_id , name = model_id .upper ())
158- for model_id , _ in RERANKING_MODELS
159- ]
160-
161- with (
162- r_mock ,
163- patch (
164- "botocore.client.BaseClient._make_api_call" ,
165- new = mock_make_api_call ,
166- ),
167- patch (
168- "app.services.models.providers.BedrockModelProvider.list_reranking_models" ,
169- new = list_reranking_models ,
170- ),
171- patch ( # work around a llama-index filter we have in list_llm_models()
172- "app.services.models.providers.bedrock.BEDROCK_MODELS" ,
173- new = BEDROCK_MODELS | {model_id : 128000 for model_id , _ in TEXT_MODELS },
174- ),
175- ):
176- yield
151+ return patch ("botocore.client.BaseClient._make_api_call" , new = mock_make_api_call )
152+
153+
154+ @pytest .fixture
155+ def mock_bedrock (monkeypatch ) -> Generator [None , None , None ]:
156+ with patch_env_vars (BedrockModelProvider ):
157+ with (
158+ _patch_requests (),
159+ _patch_boto3 (),
160+ patch ( # mock reranking models, which are hard-coded in our app
161+ "app.services.models.providers.BedrockModelProvider.list_reranking_models" ,
162+ new = lambda : [
163+ ModelResponse (model_id = model_id , name = model_id .upper ())
164+ for model_id , _ in RERANKING_MODELS
165+ ],
166+ ),
167+ patch ( # work around a llama-index filter we have in list_llm_models()
168+ "app.services.models.providers.bedrock.BEDROCK_MODELS" ,
169+ new = BEDROCK_MODELS | {model_id : 128000 for model_id , _ in TEXT_MODELS },
170+ ),
171+ ):
172+ yield
177173
178174
179175# TODO: move this test function to a discoverable place
0 commit comments