File tree Expand file tree Collapse file tree 3 files changed +15
-19
lines changed
services/models/providers Expand file tree Collapse file tree 3 files changed +15
-19
lines changed Original file line number Diff line number Diff line change 3636# DATA.
3737#
3838import abc
39+ import itertools
3940import os
4041
4142from llama_index .core .base .embeddings .base import BaseEmbedding
@@ -109,3 +110,12 @@ def get_embedding_model(name: str) -> BaseEmbedding:
109110 def get_reranking_model (name : str , top_n : int ) -> BaseNodePostprocessor :
110111 """Return reranking model with `name`."""
111112 raise NotImplementedError
113+
114+
115+ def get_all_env_var_names () -> set [str ]:
116+ """Return the names of all the env vars required by all model providers."""
117+ return set (
118+ itertools .chain .from_iterable (
119+ subcls .get_env_var_names () for subcls in _ModelProvider .__subclasses__ ()
120+ )
121+ )
Original file line number Diff line number Diff line change 4949from app .services .caii .types import ModelResponse
5050from app .services .models import ModelProvider
5151from app .services .models .providers import BedrockModelProvider
52+ from app .services .models .providers ._model_provider import get_all_env_var_names
5253
5354TEXT_MODELS = [
5455 ("test.unavailable-text-model-v1" , "NOT_AVAILABLE" ),
@@ -175,15 +176,6 @@ def list_reranking_models() -> list[ModelResponse]:
175176 yield
176177
177178
178- def get_all_env_var_names () -> set [str ]:
179- """Return the names of all the env vars required by all model providers."""
180- return set (
181- itertools .chain .from_iterable (
182- subcls .get_env_var_names () for subcls in ModelProvider .__subclasses__ ()
183- )
184- )
185-
186-
187179# TODO: move this test function to a discoverable place
188180def test_bedrock (mock_bedrock , client ) -> None :
189181 response = client .get ("/llm-service/models/model_source" )
Original file line number Diff line number Diff line change 4343from app .services .caii import caii
4444from app .services .caii .types import ListEndpointEntry
4545from app .services .models .providers import BedrockModelProvider
46- from app .services .models .providers ._model_provider import _ModelProvider
47-
48-
49- def get_all_env_var_names () -> set [str ]:
50- """Return the names of all the env vars required by all model providers."""
51- return set (
52- itertools .chain .from_iterable (
53- subcls .get_env_var_names () for subcls in _ModelProvider .__subclasses__ ()
54- )
55- )
46+ from app .services .models .providers ._model_provider import (
47+ _ModelProvider ,
48+ get_all_env_var_names ,
49+ )
5650
5751
5852@pytest .fixture (params = _ModelProvider .__subclasses__ ())
You can’t perform that action at this time.
0 commit comments