@@ -144,6 +144,32 @@ async def update_endpoint(
144144
145145 dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
146146
147+ # If the auth type has not changed or no authentication is needed,
148+ # we can update the models
149+ if (
150+ founddbe .auth_type == endpoint .auth_type
151+ or endpoint .auth_type == apimodelsv1 .ProviderAuthType .none
152+ ):
153+ try :
154+ authm = await self ._db_reader .get_auth_material_by_provider_id (str (endpoint .id ))
155+
156+ models = await self ._find_models_for_provider (
157+ endpoint , authm .auth_type , authm .auth_blob , prov
158+ )
159+
160+ await self ._update_models_for_provider (dbendpoint , endpoint , prov , models )
161+
162+ # a model might have been deleted, let's repopulate the cache
163+ await self ._ws_crud .repopulate_mux_cache ()
164+ except Exception as err :
165+ # This is a non-fatal error. The endpoint might have changed
166+ # And the user will need to push a new API key anyway.
167+ logger .error (
168+ "Unable to update models for provider" ,
169+ provider = endpoint .name ,
170+ err = str (err ),
171+ )
172+
147173 return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
148174
149175 async def configure_auth_material (
@@ -164,12 +190,9 @@ async def configure_auth_material(
164190 provider_registry = get_provider_registry ()
165191 prov = endpoint .get_from_registry (provider_registry )
166192
167- models = []
168- if config .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
169- try :
170- models = prov .models (endpoint = endpoint .endpoint , api_key = config .api_key )
171- except Exception as err :
172- raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
193+ models = await self ._find_models_for_provider (
194+ endpoint , config .auth_type , config .api_key , prov
195+ )
173196
174197 await self ._db_writer .push_provider_auth_material (
175198 dbmodels .ProviderAuthMaterial (
@@ -179,7 +202,32 @@ async def configure_auth_material(
179202 )
180203 )
181204
182- models_set = set (models )
205+ await self ._update_models_for_provider (dbendpoint , endpoint , models )
206+
207+ # a model might have been deleted, let's repopulate the cache
208+ await self ._ws_crud .repopulate_mux_cache ()
209+
210+ async def _find_models_for_provider (
211+ self ,
212+ endpoint : apimodelsv1 .ProviderEndpoint ,
213+ auth_type : apimodelsv1 .ProviderAuthType ,
214+ api_key : str ,
215+ prov : BaseProvider ,
216+ ) -> List [str ]:
217+ if auth_type != apimodelsv1 .ProviderAuthType .passthrough :
218+ try :
219+ return prov .models (endpoint = endpoint .endpoint , api_key = api_key )
220+ except Exception as err :
221+ raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
222+ return []
223+
224+ async def _update_models_for_provider (
225+ self ,
226+ dbendpoint : dbmodels .ProviderEndpoint ,
227+ endpoint : apimodelsv1 .ProviderEndpoint ,
228+ found_models : List [str ],
229+ ) -> None :
230+ models_set = set (found_models )
183231
184232 # Get the models from the provider
185233 models_in_db = await self ._db_reader .get_provider_models_by_provider_id (str (endpoint .id ))
@@ -202,9 +250,6 @@ async def configure_auth_material(
202250 model ,
203251 )
204252
205- # a model might have been deleted, let's repopulate the cache
206- await self ._ws_crud .repopulate_mux_cache ()
207-
208253 async def delete_endpoint (self , provider_id : UUID ):
209254 """Delete an endpoint."""
210255
0 commit comments