Skip to content

Commit 93d33ac

Browse files
committed
Move get_all_env_var_names() to _model_provider
1 parent deba855 commit 93d33ac

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

llm-service/app/services/models/providers/_model_provider.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
# DATA.
3737
#
3838
import abc
39+
import itertools
3940
import os
4041

4142
from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -109,3 +110,12 @@ def get_embedding_model(name: str) -> BaseEmbedding:
109110
def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
110111
"""Return reranking model with `name`."""
111112
raise NotImplementedError
113+
114+
115+
def get_all_env_var_names() -> set[str]:
116+
"""Return the names of all the env vars required by all model providers."""
117+
return set(
118+
itertools.chain.from_iterable(
119+
subcls.get_env_var_names() for subcls in _ModelProvider.__subclasses__()
120+
)
121+
)

llm-service/app/tests/provider_mocks/bedrock.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from app.services.caii.types import ModelResponse
5050
from app.services.models import ModelProvider
5151
from app.services.models.providers import BedrockModelProvider
52+
from app.services.models.providers._model_provider import get_all_env_var_names
5253

5354
TEXT_MODELS = [
5455
("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
@@ -175,15 +176,6 @@ def list_reranking_models() -> list[ModelResponse]:
175176
yield
176177

177178

178-
def get_all_env_var_names() -> set[str]:
179-
"""Return the names of all the env vars required by all model providers."""
180-
return set(
181-
itertools.chain.from_iterable(
182-
subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__()
183-
)
184-
)
185-
186-
187179
# TODO: move this test function to a discoverable place
188180
def test_bedrock(mock_bedrock, client) -> None:
189181
response = client.get("/llm-service/models/model_source")

llm-service/app/tests/services/test_models.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,10 @@
4343
from app.services.caii import caii
4444
from app.services.caii.types import ListEndpointEntry
4545
from app.services.models.providers import BedrockModelProvider
46-
from app.services.models.providers._model_provider import _ModelProvider
47-
48-
49-
def get_all_env_var_names() -> set[str]:
50-
"""Return the names of all the env vars required by all model providers."""
51-
return set(
52-
itertools.chain.from_iterable(
53-
subcls.get_env_var_names() for subcls in _ModelProvider.__subclasses__()
54-
)
55-
)
46+
from app.services.models.providers._model_provider import (
47+
_ModelProvider,
48+
get_all_env_var_names,
49+
)
5650

5751

5852
@pytest.fixture(params=_ModelProvider.__subclasses__())

0 commit comments

Comments
 (0)