diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index be3fab28..9bbb8103 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -36,7 +36,7 @@ Usage, UserMessage, ) -from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE +from api.setting import AWS_REGION, CUSTOM_MODEL_LIST, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE logger = logging.getLogger(__name__) @@ -101,6 +101,10 @@ def list_bedrock_models() -> dict: if not stream_supported or status not in ["ACTIVE", "LEGACY"]: continue + # if the user provides a custom model list, filter only those models + if CUSTOM_MODEL_LIST and model_id not in CUSTOM_MODEL_LIST: + continue + inference_types = model.get("inferenceTypesSupported", []) input_modalities = model["inputModalities"] # Add on-demand model list diff --git a/src/api/setting.py b/src/api/setting.py index e090300a..1decfd8c 100644 --- a/src/api/setting.py +++ b/src/api/setting.py @@ -15,4 +15,5 @@ AWS_REGION = os.environ.get("AWS_REGION", "us-west-2") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0") DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3") +CUSTOM_MODEL_LIST = os.environ.get("CUSTOM_MODEL_LIST", "").split(",") ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"