diff --git a/llm-service/app/ai/indexing/summary_indexer.py b/llm-service/app/ai/indexing/summary_indexer.py
index d89b1d51c..83a18d974 100644
--- a/llm-service/app/ai/indexing/summary_indexer.py
+++ b/llm-service/app/ai/indexing/summary_indexer.py
@@ -73,10 +73,9 @@
 from .base import BaseTextIndexer
 from .readers.base_reader import ReaderConfig, ChunksResult
 from ..vector_stores.vector_store_factory import VectorStoreFactory
-from ...config import settings
+from ...config import settings, ModelSource
 from ...services.metadata_apis import data_sources_metadata_api
-from ...services.models.providers import ModelProvider
-from ...services.models import ModelSource
+from ...services.models.providers import get_provider_class
 
 logger = logging.getLogger(__name__)
@@ -133,9 +132,7 @@ def __index_configuration(
     embed_summaries: bool = True,
 ) -> Dict[str, Any]:
     prompt_helper: Optional[PromptHelper] = None
-    model_source: ModelSource = (
-        ModelProvider.get_provider_class().get_model_source()
-    )
+    model_source: ModelSource = get_provider_class().get_model_source()
     if model_source == "CAII":
         # if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
         prompt_helper = PromptHelper(context_window=3000)
diff --git a/llm-service/app/config.py b/llm-service/app/config.py
index 85aa60430..ff882223b 100644
--- a/llm-service/app/config.py
+++ b/llm-service/app/config.py
@@ -46,6 +46,7 @@
 import logging
 import os.path
+from enum import Enum
 from typing import cast, Optional, Literal
 
@@ -53,7 +54,13 @@
 ChatStoreProviderType = Literal["Local", "S3"]
 VectorDbProviderType = Literal["QDRANT", "OPENSEARCH"]
 MetadataDbProviderType = Literal["H2", "PostgreSQL"]
-ModelProviderType = Literal["Azure", "CAII", "OpenAI", "Bedrock"]
+
+
+class ModelSource(str, Enum):
+    AZURE = "Azure"
+    OPENAI = "OpenAI"
+    BEDROCK = "Bedrock"
+    CAII = "CAII"
 
 
 class _Settings:
@@ -185,14 +192,15 @@ def openai_api_base(self) -> Optional[str]:
         return os.environ.get("OPENAI_API_BASE")
 
     @property
-    def model_provider(self) -> Optional[ModelProviderType]:
+    def model_provider(self) -> Optional[ModelSource]:
         """The preferred model provider to use. Options: 'AZURE', 'CAII', 'OPENAI', 'BEDROCK'
 
         If not set, will use the first available provider in priority order."""
         provider = os.environ.get("MODEL_PROVIDER")
-        if provider and provider in ["Azure", "CAII", "OpenAI", "Bedrock"]:
-            return cast(ModelProviderType, provider)
-        return None
+        try:
+            return ModelSource(provider)
+        except ValueError:
+            return None
 
 
 settings = _Settings()
diff --git a/llm-service/app/routers/index/models/__init__.py b/llm-service/app/routers/index/models/__init__.py
index ddddb2a0b..5e8185830 100644
--- a/llm-service/app/routers/index/models/__init__.py
+++ b/llm-service/app/routers/index/models/__init__.py
@@ -40,7 +40,7 @@
 from fastapi import APIRouter
 
 import app.services.models
-import app.services.models._model_source
+from app.config import ModelSource
 from .... import exceptions
 from ....services import models
 from ....services.caii.caii import describe_endpoint, build_model_response
@@ -71,7 +71,7 @@ def get_reranking_models() -> List[ModelResponse]:
     "/model_source", summary="Model source enabled - Bedrock, CAII, OpenAI or Azure"
 )
 @exceptions.propagates
-def get_model() -> app.services.models._model_source.ModelSource:
+def get_model() -> ModelSource:
     return app.services.models.get_model_source()
diff --git a/llm-service/app/services/amp_metadata/__init__.py b/llm-service/app/services/amp_metadata/__init__.py
index 3baa8fb2b..0b8b86277 100644
--- a/llm-service/app/services/amp_metadata/__init__.py
+++ b/llm-service/app/services/amp_metadata/__init__.py
@@ -50,7 +50,7 @@
     ChatStoreProviderType,
     VectorDbProviderType,
     MetadataDbProviderType,
-    ModelProviderType,
+    ModelSource,
 )
 from app.services.models.providers import (
     CAIIModelProvider,
@@ -136,7 +136,7 @@ class ProjectConfig(BaseModel):
     chat_store_provider: ChatStoreProviderType
     vector_db_provider: VectorDbProviderType
     metadata_db_provider: MetadataDbProviderType
-    model_provider: Optional[ModelProviderType] = None
+    model_provider: Optional[ModelSource] = None
     aws_config: AwsConfig
     azure_config: AzureConfig
     caii_config: CaiiConfig
@@ -216,13 +216,13 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
             f"Preferred provider {preferred_provider} is properly configured. \n"
         )
         if preferred_provider == "Bedrock":
-            valid_model_config_exists = BedrockModelProvider.is_enabled()
+            valid_model_config_exists = BedrockModelProvider.env_vars_are_set()
         elif preferred_provider == "Azure":
-            valid_model_config_exists = AzureModelProvider.is_enabled()
+            valid_model_config_exists = AzureModelProvider.env_vars_are_set()
         elif preferred_provider == "OpenAI":
-            valid_model_config_exists = OpenAiModelProvider.is_enabled()
+            valid_model_config_exists = OpenAiModelProvider.env_vars_are_set()
         elif preferred_provider == "CAII":
-            valid_model_config_exists = CAIIModelProvider.is_enabled()
+            valid_model_config_exists = CAIIModelProvider.env_vars_are_set()
         return ValidationResult(
             valid=valid_model_config_exists,
             message=valid_message if valid_model_config_exists else message,
@@ -276,7 +276,7 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
     if message == "":
         # check to see if CAII models are available via discovery
-        if CAIIModelProvider.is_enabled():
+        if CAIIModelProvider.env_vars_are_set():
             message = "CAII models are available."
             valid_model_config_exists = True
         else:
@@ -388,7 +388,7 @@ def build_configuration(
     validate_config = validate(frozenset(env.items()))
     model_provider = (
-        TypeAdapter(ModelProviderType).validate_python(env.get("MODEL_PROVIDER"))
+        TypeAdapter(ModelSource).validate_python(env.get("MODEL_PROVIDER"))
         if env.get("MODEL_PROVIDER")
         else None
     )
diff --git a/llm-service/app/services/models/__init__.py b/llm-service/app/services/models/__init__.py
index da248cc22..d6d6e0211 100644
--- a/llm-service/app/services/models/__init__.py
+++ b/llm-service/app/services/models/__init__.py
@@ -37,12 +37,12 @@
 #
 from .embedding import Embedding
 from .llm import LLM
-from .providers import ModelProvider
+from .providers import get_provider_class
 from .reranking import Reranking
-from ._model_source import ModelSource
+from ...config import ModelSource
 
-__all__ = ["Embedding", "LLM", "Reranking", "ModelSource"]
+__all__ = ["Embedding", "LLM", "Reranking", "get_model_source"]
 
 
 def get_model_source() -> ModelSource:
-    return ModelProvider.get_provider_class().get_model_source()
+    return get_provider_class().get_model_source()
diff --git a/llm-service/app/services/models/_model_source.py b/llm-service/app/services/models/_model_source.py
deleted file mode 100644
index 625662ea5..000000000
--- a/llm-service/app/services/models/_model_source.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#
-# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
-# (C) Cloudera, Inc. 2025
-# All rights reserved.
-#
-# Applicable Open Source License: Apache 2.0
-#
-# NOTE: Cloudera open source products are modular software products
-# made up of hundreds of individual components, each of which was
-# individually copyrighted. Each Cloudera open source product is a
-# collective work under U.S. Copyright Law. Your license to use the
-# collective work is as provided in your written agreement with
-# Cloudera. Used apart from the collective work, this file is
-# licensed for your use pursuant to the open source license
-# identified above.
-#
-# This code is provided to you pursuant a written agreement with
-# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
-# this code. If you do not have a written agreement with Cloudera nor
-# with an authorized and properly licensed third party, you do not
-# have any rights to access nor to use this code.
-#
-# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
-# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
-# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
-# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
-# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
-# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
-# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
-# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
-# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
-# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
-# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
-# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
-# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
-# DATA.
-#
-
-from enum import Enum
-
-
-class ModelSource(str, Enum):
-    BEDROCK = "Bedrock"
-    CAII = "CAII"
-    AZURE = "Azure"
-    OPENAI = "OpenAI"
-
-
diff --git a/llm-service/app/services/models/embedding.py b/llm-service/app/services/models/embedding.py
index 9dd8f6b94..551308699 100644
--- a/llm-service/app/services/models/embedding.py
+++ b/llm-service/app/services/models/embedding.py
@@ -41,7 +41,7 @@
 from llama_index.core.base.embeddings.base import BaseEmbedding
 
 from . import _model_type, _noop
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
 from ..caii.types import ModelResponse
 
 
@@ -51,7 +51,7 @@ def get(cls, model_name: Optional[str] = None) -> BaseEmbedding:
         if model_name is None:
             model_name = cls.list_available()[0].model_id
 
-        return ModelProvider.get_provider_class().get_embedding_model(model_name)
+        return get_provider_class().get_embedding_model(model_name)
 
     @staticmethod
     def get_noop() -> BaseEmbedding:
@@ -59,7 +59,7 @@ def get_noop() -> BaseEmbedding:
 
     @staticmethod
     def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_embedding_models()
+        return get_provider_class().list_embedding_models()
 
     @classmethod
     def test(cls, model_name: str) -> str:
diff --git a/llm-service/app/services/models/llm.py b/llm-service/app/services/models/llm.py
index e8283ac32..15f126a63 100644
--- a/llm-service/app/services/models/llm.py
+++ b/llm-service/app/services/models/llm.py
@@ -41,7 +41,7 @@
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 
 from . import _model_type, _noop
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
 from ..caii.types import ModelResponse
 
 
@@ -51,7 +51,7 @@ def get(cls, model_name: Optional[str] = None) -> llms.LLM:
         if not model_name:
             model_name = cls.list_available()[0].model_id
 
-        return ModelProvider.get_provider_class().get_llm_model(model_name)
+        return get_provider_class().get_llm_model(model_name)
 
     @staticmethod
     def get_noop() -> llms.LLM:
@@ -59,7 +59,7 @@ def get_noop() -> llms.LLM:
 
     @staticmethod
     def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_llm_models()
+        return get_provider_class().list_llm_models()
 
     @classmethod
     def test(cls, model_name: str) -> Literal["ok"]:
diff --git a/llm-service/app/services/models/providers/__init__.py b/llm-service/app/services/models/providers/__init__.py
index 47ecc8c78..7dfb6d34c 100644
--- a/llm-service/app/services/models/providers/__init__.py
+++ b/llm-service/app/services/models/providers/__init__.py
@@ -35,16 +35,51 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
+import logging
+
+from app.config import settings
 from .azure import AzureModelProvider
 from .bedrock import BedrockModelProvider
 from .caii import CAIIModelProvider
 from .openai import OpenAiModelProvider
-from ._model_provider import ModelProvider
+from ._model_provider import _ModelProvider
+
+logger = logging.getLogger(__name__)
 
 __all__ = [
     "AzureModelProvider",
     "BedrockModelProvider",
     "CAIIModelProvider",
     "OpenAiModelProvider",
-    "ModelProvider",
+    "get_provider_class",
 ]
+
+
+def get_provider_class() -> type[_ModelProvider]:
+    """Return the model provider class to use, based on settings and provider priority."""
+    model_providers: list[type[_ModelProvider]] = sorted(
+        _ModelProvider.__subclasses__(),
+        key=lambda ModelProviderSubcls: ModelProviderSubcls.get_priority(),
+    )
+
+    model_provider = settings.model_provider
+    for ModelProviderSubcls in model_providers:
+        if model_provider == ModelProviderSubcls.get_model_source():
+            logger.info(
+                'using model provider "%s" based on `MODEL_PROVIDER` env var',
+                ModelProviderSubcls.get_model_source().value,
+            )
+            return ModelProviderSubcls
+
+    # Fallback if no specific provider is set
+    for ModelProviderSubcls in model_providers:
+        if ModelProviderSubcls.env_vars_are_set():
+            logger.info(
+                'falling back to model provider "%s" based on env vars %s',
+                ModelProviderSubcls.get_model_source().value,
+                ModelProviderSubcls.get_env_var_names(),
+            )
+            return ModelProviderSubcls
+
+    logger.info('falling back to model provider "CAII"')
+    return CAIIModelProvider
diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py
index 954d93b2f..e591c5d40 100644
--- a/llm-service/app/services/models/providers/_model_provider.py
+++ b/llm-service/app/services/models/providers/_model_provider.py
@@ -42,50 +42,36 @@
 from llama_index.core.llms import LLM
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 
-from app.config import settings
-from .._model_source import ModelSource
+from app.config import ModelSource
 from ...caii.types import ModelResponse
 
 
-class ModelProvider(abc.ABC):
+class _ModelProvider(abc.ABC):
     @classmethod
-    def is_enabled(cls) -> bool:
-        """Return whether this model provider is enabled, based on the presence of required env vars."""
+    def env_vars_are_set(cls) -> bool:
+        """Return whether this model provider's env vars have set values."""
         return all(map(os.environ.get, cls.get_env_var_names()))
 
     @staticmethod
-    def get_provider_class() -> type["ModelProvider"]:
-        """Return the ModelProvider subclass for the given provider name."""
-        from . import (
-            AzureModelProvider,
-            CAIIModelProvider,
-            OpenAiModelProvider,
-            BedrockModelProvider,
-        )
-
-        model_provider = settings.model_provider
-        if model_provider == "Azure":
-            return AzureModelProvider
-        elif model_provider == "CAII":
-            return CAIIModelProvider
-        elif model_provider == "OpenAI":
-            return OpenAiModelProvider
-        elif model_provider == "Bedrock":
-            return BedrockModelProvider
+    @abc.abstractmethod
+    def get_env_var_names() -> set[str]:
+        """Return the names of the env vars required by this model provider."""
+        raise NotImplementedError
 
-        # Fallback to priority order if no specific provider is set
-        if AzureModelProvider.is_enabled():
-            return AzureModelProvider
-        elif OpenAiModelProvider.is_enabled():
-            return OpenAiModelProvider
-        elif BedrockModelProvider.is_enabled():
-            return BedrockModelProvider
-        return CAIIModelProvider
+    @staticmethod
+    @abc.abstractmethod
+    def get_model_source() -> ModelSource:
+        """Return the ModelSource identifying this model provider."""
+        raise NotImplementedError
 
     @staticmethod
     @abc.abstractmethod
-    def get_env_var_names() -> set[str]:
-        """Return the names of the env vars required by this model provider."""
+    def get_priority() -> int:
+        """Return the priority of this model provider relative to the others.
+
+        1 is the highest priority.
+
+        """
         raise NotImplementedError
 
     @staticmethod
     @abc.abstractmethod
@@ -123,8 +109,3 @@ def get_embedding_model(name: str) -> BaseEmbedding:
     def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
         """Return reranking model with `name`."""
         raise NotImplementedError
-
-    @staticmethod
-    @abc.abstractmethod
-    def get_model_source() -> ModelSource:
-        raise NotImplementedError
diff --git a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py
index c079e1aa3..da222b5f6 100644
--- a/llm-service/app/services/models/providers/azure.py
+++ b/llm-service/app/services/models/providers/azure.py
@@ -38,19 +38,26 @@
 from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
 from llama_index.llms.azure_openai import AzureOpenAI
 
-from ._model_provider import ModelProvider
-from .._model_source import ModelSource
+from ._model_provider import _ModelProvider
 from ...caii.types import ModelResponse
 from ...llama_utils import completion_to_prompt, messages_to_prompt
 from ...query.simple_reranker import SimpleReranker
-from ....config import settings
+from ....config import settings, ModelSource
 
 
-class AzureModelProvider(ModelProvider):
+class AzureModelProvider(_ModelProvider):
     @staticmethod
     def get_env_var_names() -> set[str]:
         return {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"}
 
+    @staticmethod
+    def get_model_source() -> ModelSource:
+        return ModelSource.AZURE
+
+    @staticmethod
+    def get_priority() -> int:
+        return 1
+
     @staticmethod
     def list_llm_models() -> list[ModelResponse]:
         return [
@@ -105,10 +112,6 @@ def get_embedding_model(name: str) -> AzureOpenAIEmbedding:
     def get_reranking_model(name: str, top_n: int) -> SimpleReranker:
         return SimpleReranker(top_n=top_n)
 
-    @staticmethod
-    def get_model_source() -> ModelSource:
-        return ModelSource.AZURE
-
 
 # ensure interface is implemented
 _ = AzureModelProvider()
diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py
index f74ab7f70..241137cfc 100644
--- a/llm-service/app/services/models/providers/bedrock.py
+++ b/llm-service/app/services/models/providers/bedrock.py
@@ -51,9 +51,8 @@
 from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank
 from pydantic import TypeAdapter
 
-from app.config import settings
-from ._model_provider import ModelProvider
-from .._model_source import ModelSource
+from app.config import settings, ModelSource
+from ._model_provider import _ModelProvider
 from ...caii.types import ModelResponse
 from ...llama_utils import completion_to_prompt, messages_to_prompt
 from ...utils import raise_for_http_error, timed_lru_cache
@@ -73,11 +72,19 @@
 }
 
 
-class BedrockModelProvider(ModelProvider):
+class BedrockModelProvider(_ModelProvider):
     @staticmethod
     def get_env_var_names() -> set[str]:
         return {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"}
 
+    @staticmethod
+    def get_model_source() -> ModelSource:
+        return ModelSource.BEDROCK
+
+    @staticmethod
+    def get_priority() -> int:
+        return 3
+
     @staticmethod
     def get_foundation_models(
         modality: Optional[BedrockModality] = None,
@@ -309,10 +316,6 @@ def get_embedding_model(name: str) -> BedrockEmbedding:
     def get_reranking_model(name: str, top_n: int) -> AWSBedrockRerank:
         return AWSBedrockRerank(rerank_model_name=name, top_n=top_n)
 
-    @staticmethod
-    def get_model_source() -> ModelSource:
-        return ModelSource.BEDROCK
-
 
 # ensure interface is implemented
 _ = BedrockModelProvider()
diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py
index c5d37c5f6..1f0de10e0 100644
--- a/llm-service/app/services/models/providers/caii.py
+++ b/llm-service/app/services/models/providers/caii.py
@@ -42,8 +42,8 @@
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from packaging.version import Version
 
-from ._model_provider import ModelProvider
-from .._model_source import ModelSource
+from ._model_provider import _ModelProvider
+from app.config import ModelSource
 from ...caii.caii import (
     get_caii_llm_models,
     get_caii_embedding_models,
@@ -51,7 +51,8 @@
     get_llm as get_caii_llm_model,
     get_embedding_model as get_caii_embedding_model,
     get_reranking_model as get_caii_reranking_model,
-    describe_endpoint, get_models_with_task,
+    describe_endpoint,
+    get_models_with_task,
 )
 from ...caii.types import ModelResponse
 from ...caii.utils import get_cml_version_from_sense_bootstrap
@@ -59,11 +60,19 @@
 from ...utils import timed_lru_cache
 
 
-class CAIIModelProvider(ModelProvider):
+class CAIIModelProvider(_ModelProvider):
     @staticmethod
     def get_env_var_names() -> set[str]:
         return {"CAII_DOMAIN"}
 
+    @staticmethod
+    def get_model_source() -> ModelSource:
+        return ModelSource.CAII
+
+    @staticmethod
+    def get_priority() -> int:
+        return 4
+
     @staticmethod
     @timed_lru_cache(maxsize=1, seconds=300)
     def list_llm_models() -> list[ModelResponse]:
@@ -100,21 +109,17 @@ def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
         return get_caii_reranking_model(name, top_n)
 
     @classmethod
-    def is_enabled(cls) -> bool:
+    def env_vars_are_set(cls) -> bool:
         version: Optional[str] = get_cml_version_from_sense_bootstrap()
         if not version:
-            return super().is_enabled()
+            return super().env_vars_are_set()
         cml_version = Version(version)
         if cml_version >= Version("2.0.50-b68"):
             available_models = get_models_with_task("TEXT_GENERATION")
             if available_models:
                 return True
-        return super().is_enabled()
-
-    @staticmethod
-    def get_model_source() -> ModelSource:
-        return ModelSource.CAII
+        return super().env_vars_are_set()
 
 
 # ensure interface is implemented
diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py
index 62d258790..0f67099fe 100644
--- a/llm-service/app/services/models/providers/openai.py
+++ b/llm-service/app/services/models/providers/openai.py
@@ -43,18 +43,25 @@
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.llms.openai import OpenAI
 
-from ._model_provider import ModelProvider
-from .._model_source import ModelSource
+from ._model_provider import _ModelProvider
 from ...caii.types import ModelResponse
 from ...llama_utils import completion_to_prompt, messages_to_prompt
-from ....config import settings
+from ....config import settings, ModelSource
 
 
-class OpenAiModelProvider(ModelProvider):
+class OpenAiModelProvider(_ModelProvider):
     @staticmethod
     def get_env_var_names() -> set[str]:
         return {"OPENAI_API_KEY"}
 
+    @staticmethod
+    def get_model_source() -> ModelSource:
+        return ModelSource.OPENAI
+
+    @staticmethod
+    def get_priority() -> int:
+        return 2
+
     @staticmethod
     def list_llm_models() -> list[ModelResponse]:
         return [
@@ -119,10 +126,6 @@ def get_embedding_model(name: str) -> OpenAIEmbedding:
     def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
         raise NotImplementedError("No reranking models available")
 
-    @staticmethod
-    def get_model_source() -> ModelSource:
-        return ModelSource.OPENAI
-
 
 # ensure interface is implemented
 _ = OpenAiModelProvider()
diff --git a/llm-service/app/services/models/reranking.py b/llm-service/app/services/models/reranking.py
index 95cc6f870..8aad6959d 100644
--- a/llm-service/app/services/models/reranking.py
+++ b/llm-service/app/services/models/reranking.py
@@ -42,7 +42,7 @@
 from llama_index.core.schema import NodeWithScore, TextNode
 
 from . import _model_type
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
 from ..caii.types import ModelResponse
 from ..query.simple_reranker import SimpleReranker
 
@@ -57,9 +57,7 @@ def get(
         if not model_name:
             return SimpleReranker(top_n=top_n)
 
-        return ModelProvider.get_provider_class().get_reranking_model(
-            name=model_name, top_n=top_n
-        )
+        return get_provider_class().get_reranking_model(name=model_name, top_n=top_n)
 
     @staticmethod
     def get_noop() -> BaseNodePostprocessor:
@@ -67,7 +65,7 @@ def get_noop() -> BaseNodePostprocessor:
 
     @staticmethod
     def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_reranking_models()
+        return get_provider_class().list_reranking_models()
 
     @classmethod
     def test(cls, model_name: str) -> str:
diff --git a/llm-service/app/services/query/agents/tool_calling_querier.py b/llm-service/app/services/query/agents/tool_calling_querier.py
index 954f1877d..6ad49b08c 100644
--- a/llm-service/app/services/query/agents/tool_calling_querier.py
+++ b/llm-service/app/services/query/agents/tool_calling_querier.py
@@ -409,7 +409,7 @@ async def agen() -> AsyncGenerator[ChatResponse, None]:
 
             # if delta is empty and response is empty,
             # it is a start to a tool call stream
-            if BedrockModelProvider.is_enabled():
+            if BedrockModelProvider.env_vars_are_set():
                 delta = event.delta or ""
                 if (
                     isinstance(event.raw, dict)
diff --git a/llm-service/app/services/query/querier.py b/llm-service/app/services/query/querier.py
index 299e67a77..bf706341c 100644
--- a/llm-service/app/services/query/querier.py
+++ b/llm-service/app/services/query/querier.py
@@ -53,7 +53,7 @@
 from .flexible_retriever import FlexibleRetriever
 from .multi_retriever import MultiSourceRetriever
 from ..metadata_apis.session_metadata_api import Session
-from ..models._model_source import ModelSource
+from ...config import ModelSource
 from ..models import get_model_source
 
 if TYPE_CHECKING:
diff --git a/llm-service/app/tests/services/test_models.py b/llm-service/app/tests/services/test_models.py
index 1649c8eba..2397f18f5 100644
--- a/llm-service/app/tests/services/test_models.py
+++ b/llm-service/app/tests/services/test_models.py
@@ -43,39 +43,34 @@
 from app.services.caii import caii
 from app.services.caii.types import ListEndpointEntry
 from app.services.models.providers import BedrockModelProvider
-from app.services.models.providers._model_provider import ModelProvider
+from app.services.models.providers._model_provider import _ModelProvider
 
 
 def get_all_env_var_names() -> set[str]:
     """Return the names of all the env vars required by all model providers."""
     return set(
         itertools.chain.from_iterable(
-            subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__()
+            subcls.get_env_var_names() for subcls in _ModelProvider.__subclasses__()
         )
     )
 
 
-@pytest.fixture()
+@pytest.fixture(params=_ModelProvider.__subclasses__())
 def EnabledModelProvider(
     request: pytest.FixtureRequest,
     monkeypatch: pytest.MonkeyPatch,
-) -> type[ModelProvider]:
+) -> type[_ModelProvider]:
     """Sets and unsets environment variables for the given model provider."""
-    ModelProviderSubcls: type[ModelProvider] = request.param
+    ModelProviderSubcls: type[_ModelProvider] = request.param
+    for name in get_all_env_var_names():
+        monkeypatch.delenv(name, raising=False)
     for name in ModelProviderSubcls.get_env_var_names():
         monkeypatch.setenv(name, "test")
-    for name in get_all_env_var_names() - ModelProviderSubcls.get_env_var_names():
-        monkeypatch.delenv(name, raising=False)
     return ModelProviderSubcls
 
 
-@pytest.mark.parametrize(
-    "EnabledModelProvider",
-    ModelProvider.__subclasses__(),
-    indirect=True,
-)
 class TestListAvailableModels:
     @pytest.fixture(autouse=True)
     def caii_get_models(self, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -103,18 +98,18 @@ def get_foundation_models(self, monkeypatch: pytest.MonkeyPatch) -> None:
             BedrockModelProvider, "get_foundation_models", lambda modality: []
         )
 
-    def test_embedding(self, EnabledModelProvider: type[ModelProvider]) -> None:
+    def test_embedding(self, EnabledModelProvider: type[_ModelProvider]) -> None:
         """Verify models.Embedding.list_available() only returns models from the enabled model provider."""
         assert (
             models.Embedding.list_available()
             == EnabledModelProvider.list_embedding_models()
         )
 
-    def test_llm(self, EnabledModelProvider: type[ModelProvider]) -> None:
+    def test_llm(self, EnabledModelProvider: type[_ModelProvider]) -> None:
         """Verify models.LLM.list_available() only returns models from the enabled model provider."""
         assert models.LLM.list_available() == EnabledModelProvider.list_llm_models()
 
-    def test_reranking(self, EnabledModelProvider: type[ModelProvider]) -> None:
+    def test_reranking(self, EnabledModelProvider: type[_ModelProvider]) -> None:
         """Verify models.Reranking.list_available() only returns models from the enabled model provider."""
         assert (
             models.Reranking.list_available()