Commit 0a6b9ce

Clean up logic for determining the active model provider (#309)
* Parametrize ModelProvider fixture directly
* Simplify clearing env vars
* Move get_model_source() to the tops of classes
* Factor out get_provider_class()
* Rename is_enabled() to env_vars_are_set()
* Privatize ModelProvider base class
* Rework get_provider_class()
* Remove redundant ModelProviderType
* Clean up logs
1 parent 8d58bc5 commit 0a6b9ce

File tree

18 files changed: +156 -176 lines

llm-service/app/ai/indexing/summary_indexer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -73,10 +73,9 @@
 from .base import BaseTextIndexer
 from .readers.base_reader import ReaderConfig, ChunksResult
 from ..vector_stores.vector_store_factory import VectorStoreFactory
-from ...config import settings
+from ...config import settings, ModelSource
 from ...services.metadata_apis import data_sources_metadata_api
-from ...services.models.providers import ModelProvider
-from ...services.models import ModelSource
+from ...services.models.providers import get_provider_class
 
 logger = logging.getLogger(__name__)
 
@@ -133,9 +132,7 @@ def __index_configuration(
         embed_summaries: bool = True,
     ) -> Dict[str, Any]:
         prompt_helper: Optional[PromptHelper] = None
-        model_source: ModelSource = (
-            ModelProvider.get_provider_class().get_model_source()
-        )
+        model_source: ModelSource = get_provider_class().get_model_source()
         if model_source == "CAII":
             # if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
             prompt_helper = PromptHelper(context_window=3000)
```

llm-service/app/config.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -46,14 +46,21 @@
 
 import logging
 import os.path
+from enum import Enum
 from typing import cast, Optional, Literal
 
 
 SummaryStorageProviderType = Literal["Local", "S3"]
 ChatStoreProviderType = Literal["Local", "S3"]
 VectorDbProviderType = Literal["QDRANT", "OPENSEARCH"]
 MetadataDbProviderType = Literal["H2", "PostgreSQL"]
-ModelProviderType = Literal["Azure", "CAII", "OpenAI", "Bedrock"]
+
+
+class ModelSource(str, Enum):
+    AZURE = "Azure"
+    OPENAI = "OpenAI"
+    BEDROCK = "Bedrock"
+    CAII = "CAII"
 
 
 class _Settings:
@@ -185,14 +192,15 @@ def openai_api_base(self) -> Optional[str]:
         return os.environ.get("OPENAI_API_BASE")
 
     @property
-    def model_provider(self) -> Optional[ModelProviderType]:
+    def model_provider(self) -> Optional[ModelSource]:
         """The preferred model provider to use.
         Options: 'AZURE', 'CAII', 'OPENAI', 'BEDROCK'
         If not set, will use the first available provider in priority order."""
         provider = os.environ.get("MODEL_PROVIDER")
-        if provider and provider in ["Azure", "CAII", "OpenAI", "Bedrock"]:
-            return cast(ModelProviderType, provider)
-        return None
+        try:
+            return ModelSource(provider)
+        except ValueError:
+            return None
 
 
 settings = _Settings()
```
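
Worth noting about the new enum: because `ModelSource` subclasses `str`, its members compare equal to their raw string values, and constructing one from an unset or unrecognized `MODEL_PROVIDER` raises `ValueError`, which the rewritten `model_provider` property maps to `None`. A quick self-contained sketch of the behavior the change relies on:

```python
from enum import Enum


class ModelSource(str, Enum):  # mirrors app.config.ModelSource
    AZURE = "Azure"
    OPENAI = "OpenAI"
    BEDROCK = "Bedrock"
    CAII = "CAII"


# Lookup by value returns the member; str inheritance keeps existing
# comparisons like `model_source == "CAII"` (see summary_indexer.py) working.
assert ModelSource("Azure") is ModelSource.AZURE
assert ModelSource.CAII == "CAII"

# An unset env var (None) or an unknown value raises ValueError, which the
# model_provider property catches and turns into None.
for raw in (None, "Gemini"):
    try:
        ModelSource(raw)
    except ValueError:
        print(f"{raw!r} is not a valid ModelSource; falling back to None")
```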

llm-service/app/routers/index/models/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -40,7 +40,7 @@
 from fastapi import APIRouter
 
 import app.services.models
-import app.services.models._model_source
+from app.config import ModelSource
 from .... import exceptions
 from ....services import models
 from ....services.caii.caii import describe_endpoint, build_model_response
@@ -71,7 +71,7 @@ def get_reranking_models() -> List[ModelResponse]:
     "/model_source", summary="Model source enabled - Bedrock, CAII, OpenAI or Azure"
 )
 @exceptions.propagates
-def get_model() -> app.services.models._model_source.ModelSource:
+def get_model() -> ModelSource:
     return app.services.models.get_model_source()
 
 
```
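Since the endpoint's return annotation is now the enum itself, FastAPI serializes it by value. A minimal self-contained sketch of that behavior (a stand-in app, not the service's actual wiring; `TestClient` requires the `httpx` package):

```python
from enum import Enum

from fastapi import FastAPI
from fastapi.testclient import TestClient


class ModelSource(str, Enum):  # mirrors app.config.ModelSource
    AZURE = "Azure"
    CAII = "CAII"


app = FastAPI()


@app.get("/model_source")
def get_model() -> ModelSource:
    return ModelSource.CAII


# A str-backed enum is rendered as its plain string value in the JSON body.
client = TestClient(app)
assert client.get("/model_source").json() == "CAII"
```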

llm-service/app/services/amp_metadata/__init__.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -50,7 +50,7 @@
     ChatStoreProviderType,
     VectorDbProviderType,
     MetadataDbProviderType,
-    ModelProviderType,
+    ModelSource,
 )
 from app.services.models.providers import (
     CAIIModelProvider,
@@ -136,7 +136,7 @@ class ProjectConfig(BaseModel):
     chat_store_provider: ChatStoreProviderType
     vector_db_provider: VectorDbProviderType
     metadata_db_provider: MetadataDbProviderType
-    model_provider: Optional[ModelProviderType] = None
+    model_provider: Optional[ModelSource] = None
     aws_config: AwsConfig
     azure_config: AzureConfig
     caii_config: CaiiConfig
@@ -216,13 +216,13 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
             f"Preferred provider {preferred_provider} is properly configured. \n"
         )
         if preferred_provider == "Bedrock":
-            valid_model_config_exists = BedrockModelProvider.is_enabled()
+            valid_model_config_exists = BedrockModelProvider.env_vars_are_set()
         elif preferred_provider == "Azure":
-            valid_model_config_exists = AzureModelProvider.is_enabled()
+            valid_model_config_exists = AzureModelProvider.env_vars_are_set()
         elif preferred_provider == "OpenAI":
-            valid_model_config_exists = OpenAiModelProvider.is_enabled()
+            valid_model_config_exists = OpenAiModelProvider.env_vars_are_set()
         elif preferred_provider == "CAII":
-            valid_model_config_exists = CAIIModelProvider.is_enabled()
+            valid_model_config_exists = CAIIModelProvider.env_vars_are_set()
         return ValidationResult(
             valid=valid_model_config_exists,
             message=valid_message if valid_model_config_exists else message,
@@ -276,7 +276,7 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
 
     if message == "":
         # check to see if CAII models are available via discovery
-        if CAIIModelProvider.is_enabled():
+        if CAIIModelProvider.env_vars_are_set():
            message = "CAII models are available."
            valid_model_config_exists = True
        else:
@@ -388,7 +388,7 @@ def build_configuration(
     validate_config = validate(frozenset(env.items()))
 
     model_provider = (
-        TypeAdapter(ModelProviderType).validate_python(env.get("MODEL_PROVIDER"))
+        TypeAdapter(ModelSource).validate_python(env.get("MODEL_PROVIDER"))
         if env.get("MODEL_PROVIDER")
         else None
     )
```
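
`TypeAdapter(ModelSource)` gives `build_configuration` the same guardrail the old `Literal` type provided, but yields the enum member. A short self-contained sketch (assuming pydantic v2, which the `TypeAdapter` import implies):

```python
from enum import Enum

from pydantic import TypeAdapter, ValidationError


class ModelSource(str, Enum):  # mirrors app.config.ModelSource
    AZURE = "Azure"
    OPENAI = "OpenAI"
    BEDROCK = "Bedrock"
    CAII = "CAII"


adapter = TypeAdapter(ModelSource)

# Pydantic v2 validates enums by value, so the raw env string becomes a member.
assert adapter.validate_python("Azure") is ModelSource.AZURE

# Unknown values raise instead of silently passing through.
try:
    adapter.validate_python("Gemini")
except ValidationError:
    print("MODEL_PROVIDER must be one of: Azure, OpenAI, Bedrock, CAII")
```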

llm-service/app/services/models/__init__.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -37,12 +37,12 @@
 #
 from .embedding import Embedding
 from .llm import LLM
-from .providers import ModelProvider
+from .providers import get_provider_class
 from .reranking import Reranking
-from ._model_source import ModelSource
+from ...config import ModelSource
 
-__all__ = ["Embedding", "LLM", "Reranking", "ModelSource"]
+__all__ = ["Embedding", "LLM", "Reranking", "get_model_source"]
 
 
 def get_model_source() -> ModelSource:
-    return ModelProvider.get_provider_class().get_model_source()
+    return get_provider_class().get_model_source()
```

llm-service/app/services/models/_model_source.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

llm-service/app/services/models/embedding.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -41,7 +41,7 @@
 from llama_index.core.base.embeddings.base import BaseEmbedding
 
 from . import _model_type, _noop
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
 from ..caii.types import ModelResponse
 
 
@@ -51,15 +51,15 @@ def get(cls, model_name: Optional[str] = None) -> BaseEmbedding:
         if model_name is None:
             model_name = cls.list_available()[0].model_id
 
-        return ModelProvider.get_provider_class().get_embedding_model(model_name)
+        return get_provider_class().get_embedding_model(model_name)
 
     @staticmethod
     def get_noop() -> BaseEmbedding:
         return _noop.DummyEmbeddingModel()
 
     @staticmethod
     def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_embedding_models()
+        return get_provider_class().list_embedding_models()
 
     @classmethod
     def test(cls, model_name: str) -> str:
```

llm-service/app/services/models/llm.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -41,7 +41,7 @@
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 
 from . import _model_type, _noop
-from .providers._model_provider import ModelProvider
+from .providers import get_provider_class
 from ..caii.types import ModelResponse
 
 
@@ -51,15 +51,15 @@ def get(cls, model_name: Optional[str] = None) -> llms.LLM:
         if not model_name:
             model_name = cls.list_available()[0].model_id
 
-        return ModelProvider.get_provider_class().get_llm_model(model_name)
+        return get_provider_class().get_llm_model(model_name)
 
     @staticmethod
     def get_noop() -> llms.LLM:
         return _noop.DummyLlm()
 
     @staticmethod
     def list_available() -> list[ModelResponse]:
-        return ModelProvider.get_provider_class().list_llm_models()
+        return get_provider_class().list_llm_models()
 
     @classmethod
     def test(cls, model_name: str) -> Literal["ok"]:
```

llm-service/app/services/models/providers/__init__.py

Lines changed: 37 additions & 2 deletions
```diff
@@ -35,16 +35,51 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
+import logging
+
+from app.config import settings
 from .azure import AzureModelProvider
 from .bedrock import BedrockModelProvider
 from .caii import CAIIModelProvider
 from .openai import OpenAiModelProvider
-from ._model_provider import ModelProvider
+from ._model_provider import _ModelProvider
+
+logger = logging.getLogger(__name__)
 
 __all__ = [
     "AzureModelProvider",
     "BedrockModelProvider",
     "CAIIModelProvider",
     "OpenAiModelProvider",
-    "ModelProvider",
+    "get_provider_class",
 ]
+
+
+def get_provider_class() -> type[_ModelProvider]:
+    """Return the ModelProvider subclass for the given provider name."""
+    model_providers: list[type[_ModelProvider]] = sorted(
+        _ModelProvider.__subclasses__(),
+        key=lambda ModelProviderSubcls: ModelProviderSubcls.get_priority(),
+    )
+
+    model_provider = settings.model_provider
+    for ModelProviderSubcls in model_providers:
+        if model_provider == ModelProviderSubcls.get_model_source():
+            logger.info(
+                'using model provider "%s" based on `MODEL_PROVIDER` env var',
+                ModelProviderSubcls.get_model_source().value,
+            )
+            return ModelProviderSubcls
+
+    # Fallback if no specific provider is set
+    for ModelProviderSubcls in model_providers:
+        if ModelProviderSubcls.env_vars_are_set():
+            logger.info(
+                'falling back to model provider "%s" based on env vars %s',
+                ModelProviderSubcls.get_model_source().value,
+                ModelProviderSubcls.get_env_var_names(),
+            )
+            return ModelProviderSubcls
+
+    logger.info('falling back to model provider "CAII"')
+    return CAIIModelProvider
```
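```

Discovery now runs over `_ModelProvider.__subclasses__()` sorted by `get_priority()`, so adding a provider means defining a subclass rather than extending an if/elif chain. A hypothetical sketch of the contract a new provider would fulfill (the `Acme` names, env vars, and `ModelSource.ACME` member are invented for illustration, and the abstract model-factory methods are omitted):

```python
# Hypothetical new provider, for illustration only: the Acme names, env vars,
# and ModelSource.ACME member do not exist in the codebase.
from app.config import ModelSource
from app.services.models.providers._model_provider import _ModelProvider


class AcmeModelProvider(_ModelProvider):
    @staticmethod
    def get_env_var_names() -> set[str]:
        # env_vars_are_set() returns True only when all of these are set.
        return {"ACME_API_KEY", "ACME_ENDPOINT"}

    @staticmethod
    def get_model_source() -> ModelSource:
        return ModelSource.ACME  # would require adding a member to the enum

    @staticmethod
    def get_priority() -> int:
        return 5  # 1 is highest; lower-priority providers are tried later

    # The abstract model-factory methods (get_llm_model, get_embedding_model,
    # get_reranking_model, list_* ...) are omitted here for brevity.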

llm-service/app/services/models/providers/_model_provider.py

Lines changed: 19 additions & 38 deletions
```diff
@@ -42,50 +42,36 @@
 from llama_index.core.llms import LLM
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 
-from app.config import settings
-from .._model_source import ModelSource
+from app.config import ModelSource
 from ...caii.types import ModelResponse
 
 
-class ModelProvider(abc.ABC):
+class _ModelProvider(abc.ABC):
     @classmethod
-    def is_enabled(cls) -> bool:
-        """Return whether this model provider is enabled, based on the presence of required env vars."""
+    def env_vars_are_set(cls) -> bool:
+        """Return whether this model provider's env vars have set values."""
         return all(map(os.environ.get, cls.get_env_var_names()))
 
     @staticmethod
-    def get_provider_class() -> type["ModelProvider"]:
-        """Return the ModelProvider subclass for the given provider name."""
-        from . import (
-            AzureModelProvider,
-            CAIIModelProvider,
-            OpenAiModelProvider,
-            BedrockModelProvider,
-        )
-
-        model_provider = settings.model_provider
-        if model_provider == "Azure":
-            return AzureModelProvider
-        elif model_provider == "CAII":
-            return CAIIModelProvider
-        elif model_provider == "OpenAI":
-            return OpenAiModelProvider
-        elif model_provider == "Bedrock":
-            return BedrockModelProvider
+    @abc.abstractmethod
+    def get_env_var_names() -> set[str]:
+        """Return the names of the env vars required by this model provider."""
+        raise NotImplementedError
 
-        # Fallback to priority order if no specific provider is set
-        if AzureModelProvider.is_enabled():
-            return AzureModelProvider
-        elif OpenAiModelProvider.is_enabled():
-            return OpenAiModelProvider
-        elif BedrockModelProvider.is_enabled():
-            return BedrockModelProvider
-        return CAIIModelProvider
+    @staticmethod
+    @abc.abstractmethod
+    def get_model_source() -> ModelSource:
+        """Return the name of this model provider"""
+        raise NotImplementedError
+
+    @staticmethod
+    @abc.abstractmethod
+    def get_priority() -> int:
+        """Return the priority of this model provider relative to the others.
+
+        1 is the highest priority.
+
+        """
+        raise NotImplementedError
 
     @staticmethod
@@ -123,8 +109,3 @@ def get_embedding_model(name: str) -> BaseEmbedding:
     def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
         """Return reranking model with `name`."""
         raise NotImplementedError
-
-    @staticmethod
-    @abc.abstractmethod
-    def get_model_source() -> ModelSource:
-        raise NotImplementedError
```
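```

One subtlety the rework leans on: `__subclasses__()` only sees subclasses whose defining modules have already been imported, which is why `providers/__init__.py` imports all four providers before `get_provider_class()` sorts them. A self-contained sketch of the discovery mechanism:

```python
import abc


class _Base(abc.ABC):
    @staticmethod
    @abc.abstractmethod
    def get_priority() -> int:
        raise NotImplementedError


class High(_Base):
    @staticmethod
    def get_priority() -> int:
        return 1  # highest priority, tried first


class Low(_Base):
    @staticmethod
    def get_priority() -> int:
        return 2


# Only subclasses that have been defined (i.e., their modules imported) show
# up here; an unimported provider module would be silently skipped.
ordered = sorted(_Base.__subclasses__(), key=lambda cls: cls.get_priority())
print([cls.__name__ for cls in ordered])  # ['High', 'Low']
```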
