diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs
index 36301fd..83b4bbf 100644
--- a/src/pipelines/pipeline.rs
+++ b/src/pipelines/pipeline.rs
@@ -78,33 +78,84 @@ pub async fn chat_completions(
 ) -> Result {
     let mut tracer = OtelTracer::start("chat", &payload);
 
-    for model_key in model_keys {
-        let model = model_registry.get(&model_key).unwrap();
-
-        if payload.model == model.model_type {
-            let response = model
-                .chat_completions(payload.clone())
-                .await
-                .inspect_err(|e| {
-                    eprintln!("Chat completion error for model {}: {:?}", model_key, e);
-                })?;
-
-            if let ChatCompletionResponse::NonStream(completion) = response {
-                tracer.log_success(&completion);
-                return Ok(Json(completion).into_response());
+    // Candidates keep the order of `model_keys`; keys missing from the registry are skipped.
+    let matching_models: Vec<_> = model_keys
+        .iter()
+        .filter_map(|key| {
+            let model = model_registry.get(key)?;
+            if payload.model == model.model_type {
+                Some((key.clone(), model))
+            } else {
+                None
             }
+        })
+        .collect();
 
-            if let ChatCompletionResponse::Stream(stream) = response {
-                return Ok(Sse::new(trace_and_stream(tracer, stream))
-                    .keep_alive(KeepAlive::default())
-                    .into_response());
+    if matching_models.is_empty() {
+        tracer.log_error("No matching model found".to_string());
+        eprintln!("No matching model found for: {}", payload.model);
+        return Err(StatusCode::NOT_FOUND);
+    }
+
+    let mut last_error = None;
+
+    for (model_key, model) in matching_models {
+        match model.chat_completions(payload.clone()).await {
+            Ok(response) => match response {
+                ChatCompletionResponse::NonStream(completion) => {
+                    tracer.log_success(&completion);
+                    return Ok(Json(completion).into_response());
+                }
+                ChatCompletionResponse::Stream(stream) => {
+                    return Ok(Sse::new(trace_and_stream(tracer, stream))
+                        .keep_alive(KeepAlive::default())
+                        .into_response());
+                }
+            },
+            Err(status_code) => {
+                eprintln!(
+                    "Chat completion error for model {}: {:?}",
+                    model_key, status_code
+                );
+
+                if !is_transient_error(status_code) {
+                    // A permanent failure ends the request; record it on the trace first.
+                    tracer.log_error(format!(
+                        "Model {} failed with error: {}",
+                        model_key, status_code
+                    ));
+                    return Err(status_code);
+                }
+
+                eprintln!(
+                    "Transient error for model {}, trying next model...",
+                    model_key
+                );
+                last_error = Some(status_code);
             }
         }
     }
 
-    tracer.log_error("No matching model found".to_string());
-    eprintln!("No matching model found for: {}", payload.model);
-    Err(StatusCode::NOT_FOUND)
+    // Reaching this point means every candidate failed transiently: `matching_models` is
+    // non-empty and each iteration either returned or stored its status in `last_error`.
+    let error = last_error.expect("a non-returning iteration always sets last_error");
+    tracer.log_error(format!("All models failed with error: {}", error));
+    Err(error)
+}
+
+/// Reports whether `status_code` is a failure worth retrying on another model.
+///
+/// Only 408, 429, 502, 503 and 504 count as transient; for any other status the handlers
+/// in this file stop and return it to the caller without trying the remaining models.
+fn is_transient_error(status_code: StatusCode) -> bool {
+    matches!(
+        status_code,
+        StatusCode::TOO_MANY_REQUESTS | // 429
+        StatusCode::REQUEST_TIMEOUT | // 408
+        StatusCode::SERVICE_UNAVAILABLE | // 503
+        StatusCode::BAD_GATEWAY | // 502
+        StatusCode::GATEWAY_TIMEOUT // 504
+    )
 }
 
 pub async fn completions(
@@ -114,21 +165,62 @@ pub async fn completions(
 ) -> impl IntoResponse {
     let mut tracer = OtelTracer::start("completion", &payload);
 
-    for model_key in model_keys {
-        let model = model_registry.get(&model_key).unwrap();
+    // Candidates keep the order of `model_keys`; keys missing from the registry are skipped.
+    let matching_models: Vec<_> = model_keys
+        .iter()
+        .filter_map(|key| {
+            let model = model_registry.get(key)?;
+            if payload.model == model.model_type {
+                Some((key.clone(), model))
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    if matching_models.is_empty() {
+        tracer.log_error("No matching model found".to_string());
+        eprintln!("No matching model found for: {}", payload.model);
+        return Err(StatusCode::NOT_FOUND);
+    }
+
+    let mut last_error = None;
+
+    for (model_key, model) in matching_models {
+        match model.completions(payload.clone()).await {
+            Ok(response) => {
+                tracer.log_success(&response);
+                return Ok(Json(response));
+            }
+            Err(status_code) => {
+                eprintln!(
+                    "Completion error for model {}: {:?}",
+                    model_key, status_code
+                );
 
-        if payload.model == model.model_type {
-            let response = model.completions(payload.clone()).await.inspect_err(|e| {
-                eprintln!("Completion error for model {}: {:?}", model_key, e);
-            })?;
-            tracer.log_success(&response);
-            return Ok(Json(response));
+                if !is_transient_error(status_code) {
+                    // A permanent failure ends the request; record it on the trace first.
+                    tracer.log_error(format!(
+                        "Model {} failed with error: {}",
+                        model_key, status_code
+                    ));
+                    return Err(status_code);
+                }
+
+                eprintln!(
+                    "Transient error for model {}, trying next model...",
+                    model_key
+                );
+                last_error = Some(status_code);
+            }
         }
     }
 
-    tracer.log_error("No matching model found".to_string());
-    eprintln!("No matching model found for: {}", payload.model);
-    Err(StatusCode::NOT_FOUND)
+    // Reaching this point means every candidate failed transiently: `matching_models` is
+    // non-empty and each iteration either returned or stored its status in `last_error`.
+    let error = last_error.expect("a non-returning iteration always sets last_error");
+    tracer.log_error(format!("All models failed with error: {}", error));
+    Err(error)
 }
 
 pub async fn embeddings(
@@ -138,19 +230,60 @@ pub async fn embeddings(
 ) -> impl IntoResponse {
     let mut tracer = OtelTracer::start("embeddings", &payload);
 
-    for model_key in model_keys {
-        let model = model_registry.get(&model_key).unwrap();
+    // Candidates keep the order of `model_keys`; keys missing from the registry are skipped.
+    let matching_models: Vec<_> = model_keys
+        .iter()
+        .filter_map(|key| {
+            let model = model_registry.get(key)?;
+            if payload.model == model.model_type {
+                Some((key.clone(), model))
+            } else {
+                None
+            }
+        })
+        .collect();
 
-        if payload.model == model.model_type {
-            let response = model.embeddings(payload.clone()).await.inspect_err(|e| {
-                eprintln!("Embeddings error for model {}: {:?}", model_key, e);
-            })?;
-            tracer.log_success(&response);
-            return Ok(Json(response));
+    if matching_models.is_empty() {
+        tracer.log_error("No matching model found".to_string());
+        eprintln!("No matching model found for: {}", payload.model);
+        return Err(StatusCode::NOT_FOUND);
+    }
+
+    let mut last_error = None;
+
+    for (model_key, model) in matching_models {
+        match model.embeddings(payload.clone()).await {
+            Ok(response) => {
+                tracer.log_success(&response);
+                return Ok(Json(response));
+            }
+            Err(status_code) => {
+                eprintln!(
+                    "Embeddings error for model {}: {:?}",
+                    model_key, status_code
+                );
+
+                if !is_transient_error(status_code) {
+                    // A permanent failure ends the request; record it on the trace first.
+                    tracer.log_error(format!(
+                        "Model {} failed with error: {}",
+                        model_key, status_code
+                    ));
+                    return Err(status_code);
+                }
+
+                eprintln!(
+                    "Transient error for model {}, trying next model...",
+                    model_key
+                );
+                last_error = Some(status_code);
+            }
         }
     }
 
-    tracer.log_error("No matching model found".to_string());
-    eprintln!("No matching model found for: {}", payload.model);
-    Err(StatusCode::NOT_FOUND)
+    // Reaching this point means every candidate failed transiently: `matching_models` is
+    // non-empty and each iteration either returned or stored its status in `last_error`.
+    let error = last_error.expect("a non-returning iteration always sets last_error");
+    tracer.log_error(format!("All models failed with error: {}", error));
+    Err(error)
 }