|
42 | 42 | from llama_index.core.llms import LLM |
43 | 43 | from llama_index.core.postprocessor.types import BaseNodePostprocessor |
44 | 44 |
|
45 | | -from app.config import settings |
46 | | -from .._model_source import ModelSource |
| 45 | +from app.config import ModelSource |
47 | 46 | from ...caii.types import ModelResponse |
48 | 47 |
|
49 | 48 |
|
50 | | -class ModelProvider(abc.ABC): |
| 49 | +class _ModelProvider(abc.ABC): |
51 | 50 | @classmethod |
52 | | - def is_enabled(cls) -> bool: |
53 | | - """Return whether this model provider is enabled, based on the presence of required env vars.""" |
| 51 | + def env_vars_are_set(cls) -> bool: |
| 52 | + """Return whether this model provider's env vars have set values.""" |
54 | 53 | return all(map(os.environ.get, cls.get_env_var_names())) |
55 | 54 |
|
56 | 55 | @staticmethod |
57 | | - def get_provider_class() -> type["ModelProvider"]: |
58 | | - """Return the ModelProvider subclass for the given provider name.""" |
59 | | - from . import ( |
60 | | - AzureModelProvider, |
61 | | - CAIIModelProvider, |
62 | | - OpenAiModelProvider, |
63 | | - BedrockModelProvider, |
64 | | - ) |
65 | | - |
66 | | - model_provider = settings.model_provider |
67 | | - if model_provider == "Azure": |
68 | | - return AzureModelProvider |
69 | | - elif model_provider == "CAII": |
70 | | - return CAIIModelProvider |
71 | | - elif model_provider == "OpenAI": |
72 | | - return OpenAiModelProvider |
73 | | - elif model_provider == "Bedrock": |
74 | | - return BedrockModelProvider |
| 56 | + @abc.abstractmethod |
| 57 | + def get_env_var_names() -> set[str]: |
| 58 | + """Return the names of the env vars required by this model provider.""" |
| 59 | + raise NotImplementedError |
75 | 60 |
|
76 | | - # Fallback to priority order if no specific provider is set |
77 | | - if AzureModelProvider.is_enabled(): |
78 | | - return AzureModelProvider |
79 | | - elif OpenAiModelProvider.is_enabled(): |
80 | | - return OpenAiModelProvider |
81 | | - elif BedrockModelProvider.is_enabled(): |
82 | | - return BedrockModelProvider |
83 | | - return CAIIModelProvider |
| 61 | + @staticmethod |
| 62 | + @abc.abstractmethod |
| 63 | + def get_model_source() -> ModelSource: |
| 64 | + """Return the name of this model provider""" |
| 65 | + raise NotImplementedError |
84 | 66 |
|
85 | 67 | @staticmethod |
86 | 68 | @abc.abstractmethod |
87 | | - def get_env_var_names() -> set[str]: |
88 | | - """Return the names of the env vars required by this model provider.""" |
| 69 | + def get_priority() -> int: |
| 70 | + """Return the priority of this model provider relative to the others. |
| 71 | +
|
| 72 | + 1 is the highest priority. |
| 73 | +
|
| 74 | + """ |
89 | 75 | raise NotImplementedError |
90 | 76 |
|
91 | 77 | @staticmethod |
@@ -123,8 +109,3 @@ def get_embedding_model(name: str) -> BaseEmbedding: |
123 | 109 | def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor: |
124 | 110 | """Return reranking model with `name`.""" |
125 | 111 | raise NotImplementedError |
126 | | - |
127 | | - @staticmethod |
128 | | - @abc.abstractmethod |
129 | | - def get_model_source() -> ModelSource: |
130 | | - raise NotImplementedError |
0 commit comments